/*
 * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbTrainVectorBase_hxx
#define otbTrainVectorBase_hxx
#include "otbTrainVectorBase.h"
namespace otb
namespace Wrapper
template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoInit()
  // Common Parameters for all Learning Application
  this->AddParameter( ParameterType_Group, "io", "Input and output data" );
  this->SetParameterDescription( "io", 
    "This group of parameters allows setting input and output data." );
  this->AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" );
  this->SetParameterDescription( "io.vd",
    "Input geometries used for training (note: all geometries from the layer will be used)" );
  this->AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" );
  this->MandatoryOff( "io.stats" );
  this->SetParameterDescription( "io.stats", 
    "XML file containing mean and variance of each feature." );
  this->AddParameter( ParameterType_OutputFilename, "io.out", "Output model" );
  this->SetParameterDescription( "io.out", 
    "Output file containing the model estimated (.txt format)." );
  this->AddParameter( ParameterType_Int, "layer", "Layer Index" );
  this->SetParameterDescription( "layer", 
    "Index of the layer to use in the input vector file." );
  this->MandatoryOff( "layer" );
  this->SetDefaultParameterInt( "layer", 0 );
  this->AddParameter(ParameterType_ListView,  "feat", "Field names for training features");
  this->SetParameterDescription("feat",
    "List of field names in the input vector data to be used as features for training.");
  // Add validation data used to compute confusion matrix or contingency table
  this->AddParameter( ParameterType_Group, "valid", "Validation data" );
  this->SetParameterDescription( "valid", 
    "This group of parameters defines validation data." );
  this->AddParameter( ParameterType_InputVectorDataList, "valid.vd", 
    "Validation Vector Data" );
  this->SetParameterDescription( "valid.vd", "Geometries used for validation "
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
"(must contain the same fields used for training, all geometries from the layer will be used)" ); this->MandatoryOff( "valid.vd" ); this->AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); this->SetParameterDescription( "valid.layer", "Index of the layer to use in the validation vector file." ); this->MandatoryOff( "valid.layer" ); this->SetDefaultParameterInt( "valid.layer", 0 ); // Add class field if we used validation this->AddParameter( ParameterType_ListView, "cfield", "Field containing the class integer label for supervision" ); this->SetParameterDescription( "cfield", "Field containing the class id for supervision. " "The values in this field shall be cast into integers. " "Only geometries with this field available will be taken into account." ); this->SetListViewSingleSelectionMode( "cfield", true ); this->AddParameter(ParameterType_Bool, "v", "Verbose mode"); this->SetParameterDescription("v", "Verbose mode, display the contingency table result."); this->SetParameterInt("v", 1); // Doc example parameter settings this->SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); this->SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); this->SetDocExampleParameterValue( "io.out", "svmModel.svm" ); this->SetDocExampleParameterValue( "feat", "perimeter area width" ); this->SetDocExampleParameterValue( "cfield", "predicted" ); // Add parameters for the classifier choice Superclass::DoInit(); this->AddRANDParameter(); } template <class TInputValue, class TOutputValue> void TrainVectorBase<TInputValue, TOutputValue> ::DoUpdateParameters() { // if vector data is present and updated then reload fields if( this->HasValue( "io.vd" ) ) { std::vector<std::string> vectorFileList = this->GetParameterStringList( "io.vd" ); ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read ); ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) ); ogr::Feature feature = 
layer.ogr().GetNextFeature(); this->ClearChoices( "feat" ); this->ClearChoices( "cfield" ); for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ ) { std::string key, item = feature.ogr().GetFieldDefnRef( iField )->GetNameRef(); key = item; std::string::iterator end = std::remove_if( key.begin(), key.end(), IsNotAlphaNum ); std::transform( key.begin(), end, key.begin(), tolower ); OGRFieldType fieldType = feature.ogr().GetFieldDefnRef( iField )->GetType(); if( fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal ) { std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); this->AddChoice( tmpKey, item ); } if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal ) { std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); this->AddChoice( tmpKey, item );
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
} } } } template <class TInputValue, class TOutputValue> void TrainVectorBase<TInputValue, TOutputValue> ::DoExecute() { m_FeaturesInfo.SetFieldNames( this->GetChoiceNames( "feat" ), this->GetSelectedItems( "feat" )); // Check input parameters if( m_FeaturesInfo.m_SelectedIdx.empty() ) { otbAppLogFATAL( << "No features have been selected to train the classifier on!" ); } ShiftScaleParameters measurement = GetStatistics( m_FeaturesInfo.m_NbFeatures ); ExtractAllSamples( measurement ); this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString( "io.out" ) ); m_PredictedList = this->Classify( m_ClassificationSamplesWithLabel.listSample, this->GetParameterString( "io.out" ) ); } template <class TInputValue, class TOutputValue> void TrainVectorBase<TInputValue, TOutputValue> ::ExtractAllSamples(const ShiftScaleParameters &measurement) { m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement); m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement); } template <class TInputValue, class TOutputValue> typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel TrainVectorBase<TInputValue, TOutputValue> ::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) { return ExtractSamplesWithLabel( "io.vd", "layer", measurement); } template <class TInputValue, class TOutputValue> typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel TrainVectorBase<TInputValue, TOutputValue> ::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) { if(this->GetClassifierCategory() == Superclass::Supervised) { SamplesWithLabel tmpSamplesWithLabel; SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement ); //Test the input validation set size if( validationSamplesWithLabel.labeledListSample->Size() != 0 ) { tmpSamplesWithLabel.listSample = 
validationSamplesWithLabel.listSample; tmpSamplesWithLabel.labeledListSample = validationSamplesWithLabel.labeledListSample; } else { otbAppLogWARNING( "The validation set is empty. The performance estimation is done using the input training set in this case." ); tmpSamplesWithLabel.listSample = m_TrainingSamplesWithLabel.listSample; tmpSamplesWithLabel.labeledListSample = m_TrainingSamplesWithLabel.labeledListSample; } return tmpSamplesWithLabel; } else
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
{ return m_TrainingSamplesWithLabel; } } template <class TInputValue, class TOutputValue> typename TrainVectorBase<TInputValue, TOutputValue>::ShiftScaleParameters TrainVectorBase<TInputValue, TOutputValue> ::GetStatistics(unsigned int nbFeatures) { ShiftScaleParameters measurement = ShiftScaleParameters(); if( this->HasValue( "io.stats" ) && this->IsParameterEnabled( "io.stats" ) ) { typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); std::string XMLfile = this->GetParameterString( "io.stats" ); statisticsReader->SetFileName( XMLfile ); measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" ); measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" ); } else { measurement.meanMeasurementVector.SetSize( nbFeatures ); measurement.meanMeasurementVector.Fill( 0. ); measurement.stddevMeasurementVector.SetSize( nbFeatures ); measurement.stddevMeasurementVector.Fill( 1. ); } return measurement; } // Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer template <> inline int TrainVectorBase<float, int> ::GetFeatureField(const ogr::Feature & feature, int fieldIndex) { return(feature[fieldIndex].GetValue<int>()); } template <class TInputValue, class TOutputValue> inline TOutputValue TrainVectorBase<TInputValue, TOutputValue> ::GetFeatureField(const ogr::Feature & feature, int fieldIndex) { return(feature[fieldIndex].GetValue<double>()); } template <class TInputValue, class TOutputValue> typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel TrainVectorBase<TInputValue, TOutputValue> ::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement) { SamplesWithLabel samplesWithLabel; if( this->HasValue( parameterName ) && this->IsParameterEnabled( parameterName ) ) { typename ListSampleType::Pointer input = ListSampleType::New(); typename 
TargetListSampleType::Pointer target = TargetListSampleType::New(); input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures ); std::vector<std::string> fileList = this->GetParameterStringList( parameterName ); for( unsigned int k = 0; k < fileList.size(); k++ ) { otbAppLogINFO( "Reading vector file " << k + 1 << "/" << fileList.size() ); ogr::DataSource::Pointer source = ogr::DataSource::New( fileList[k], ogr::DataSource::Modes::Read ); ogr::Layer layer = source->GetLayer( static_cast<size_t>(this->GetParameterInt( parameterLayer )) ); ogr::Feature feature = layer.ogr().GetNextFeature(); bool goesOn = feature.addr() != 0; if( !goesOn ) { otbAppLogWARNING( "The layer " << this->GetParameterInt( parameterLayer ) << " of " << fileList[k]
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
<< " is empty, input is skipped." ); continue; } // Check all needed fields are present : // - check class field if we use supervised classification or if class field name is not empty int cFieldIndex = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedCFieldName.c_str() ); if( cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty()) { otbAppLogFATAL( "The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName << ") has not been found in the vector file " << fileList[k] ); } // - check feature fields std::vector<int> featureFieldIndex( m_FeaturesInfo.m_NbFeatures, -1 ); for( unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++ ) { featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedNames[i].c_str() ); if( featureFieldIndex[i] < 0 ) otbAppLogFATAL( "The field name for feature " << m_FeaturesInfo.m_SelectedNames[i] << " has not been found in the vector file " << fileList[k] ); } while( goesOn ) { // Retrieve all the features for each field in the ogr layer. MeasurementType mv; mv.SetSize( m_FeaturesInfo.m_NbFeatures ); for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx ) mv[idx] = feature[featureFieldIndex[idx]].GetValue<double>(); input->PushBack( mv ); if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet()) target->PushBack(GetFeatureField(feature,cFieldIndex)); else target->PushBack( 0. ); feature = layer.ogr().GetNextFeature(); goesOn = feature.addr() != 0; } } typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); shiftScaleFilter->SetInput( input ); shiftScaleFilter->SetShifts( measurement.meanMeasurementVector ); shiftScaleFilter->SetScales( measurement.stddevMeasurementVector ); shiftScaleFilter->Update(); samplesWithLabel.listSample = shiftScaleFilter->GetOutput(); samplesWithLabel.labeledListSample = target; samplesWithLabel.listSample->DisconnectPipeline(); } return samplesWithLabel; } } } #endif