// Cédric Traizet authored 40e15141
/*
* Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbTrainVectorBase_hxx
#define otbTrainVectorBase_hxx
#include "otbTrainVectorBase.h"
namespace otb
{
namespace Wrapper
{
/**
 * Declare the parameters shared by the vector-based training applications:
 * input/output group ("io"), layer index, feature field list ("feat"),
 * validation group ("valid"), class field ("cfield") and verbose flag ("v").
 * Documentation example values are registered afterwards, then the
 * superclass initialisation and AddRANDParameter() are invoked.
 */
template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoInit()
{
  // Common Parameters for all Learning Application
  this->AddParameter( ParameterType_Group, "io", "Input and output data" );
  this->SetParameterDescription( "io",
    "This group of parameters allows setting input and output data." );

  this->AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" );
  this->SetParameterDescription( "io.vd",
    "Input geometries used for training (note: all geometries from the layer will be used)" );

  // Statistics file is optional.
  this->AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" );
  this->MandatoryOff( "io.stats" );
  this->SetParameterDescription( "io.stats",
    "XML file containing mean and variance of each feature." );

  this->AddParameter( ParameterType_OutputFilename, "io.out", "Output model" );
  this->SetParameterDescription( "io.out",
    "Output file containing the model estimated (.txt format)." );

  // Optional layer index, defaults to the first layer.
  this->AddParameter( ParameterType_Int, "layer", "Layer Index" );
  this->SetParameterDescription( "layer",
    "Index of the layer to use in the input vector file." );
  this->MandatoryOff( "layer" );
  this->SetDefaultParameterInt( "layer", 0 );

  this->AddParameter(ParameterType_ListView, "feat", "Field names for training features");
  this->SetParameterDescription("feat",
    "List of field names in the input vector data to be used as features for training.");

  // Add validation data used to compute confusion matrix or contingency table
  this->AddParameter( ParameterType_Group, "valid", "Validation data" );
  this->SetParameterDescription( "valid",
    "This group of parameters defines validation data." );

  this->AddParameter( ParameterType_InputVectorDataList, "valid.vd",
    "Validation Vector Data" );
  this->SetParameterDescription( "valid.vd", "Geometries used for validation "
    "(must contain the same fields used for training, all geometries from the layer will be used)" );
  this->MandatoryOff( "valid.vd" );

  this->AddParameter( ParameterType_Int, "valid.layer", "Layer Index" );
  this->SetParameterDescription( "valid.layer",
    "Index of the layer to use in the validation vector file." );
  this->MandatoryOff( "valid.layer" );
  this->SetDefaultParameterInt( "valid.layer", 0 );

  // Add class field if we used validation
  this->AddParameter( ParameterType_ListView, "cfield",
    "Field containing the class integer label for supervision" );
  this->SetParameterDescription( "cfield",
    "Field containing the class id for supervision. "
    "The values in this field shall be cast into integers. "
    "Only geometries with this field available will be taken into account." );
  // Only one class field may be selected at a time.
  this->SetListViewSingleSelectionMode( "cfield", true );

  // Verbose flag, set to 1 here.
  this->AddParameter(ParameterType_Bool, "v", "Verbose mode");
  this->SetParameterDescription("v", "Verbose mode, display the contingency table result.");
  this->SetParameterInt("v", 1);

  // Doc example parameter settings
  this->SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
  this->SetDocExampleParameterValue( "io.stats", "meanVar.xml" );
  this->SetDocExampleParameterValue( "io.out", "svmModel.svm" );
  this->SetDocExampleParameterValue( "feat", "perimeter area width" );
  this->SetDocExampleParameterValue( "cfield", "predicted" );

  // Add parameters for the classifier choice
  Superclass::DoInit();

  this->AddRANDParameter();
}
/**
 * Refresh the "feat" and "cfield" choice lists from the fields found on the
 * first feature of the first "io.vd" file (layer given by "layer").
 *
 * Integer, 64-bit integer and real fields are offered as features; the same
 * types plus string fields are offered as class field. Choice keys are the
 * field name with the characters rejected by IsNotAlphaNum removed and the
 * remainder lower-cased.
 *
 * Does nothing when "io.vd" has no value or an empty file list. When the
 * selected layer yields no feature, both choice lists are left empty.
 */
template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoUpdateParameters()
{
  // if vector data is present and updated then reload fields
  if( !this->HasValue( "io.vd" ) )
  {
    return;
  }

  const std::vector<std::string> vectorFileList = this->GetParameterStringList( "io.vd" );
  if( vectorFileList.empty() )
  {
    // Guard the [0] access below.
    return;
  }

  ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read );
  ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) );
  ogr::Feature feature = layer.ogr().GetNextFeature();

  this->ClearChoices( "feat" );
  this->ClearChoices( "cfield" );

  // Same null test as the sample-extraction code: no feature means there is
  // no field definition to read, so stop before dereferencing it.
  if( feature.addr() == 0 )
  {
    return;
  }

  const int fieldCount = feature.ogr().GetFieldCount();
  for( int iField = 0; iField < fieldCount; ++iField )
  {
    auto * fieldDefn = feature.ogr().GetFieldDefnRef( iField );
    const std::string item = fieldDefn->GetNameRef();

    // Build the choice key: strip rejected characters, then lower-case.
    std::string key = item;
    key.erase( std::remove_if( key.begin(), key.end(), IsNotAlphaNum ), key.end() );
    std::transform( key.begin(), key.end(), key.begin(), tolower );

    const OGRFieldType fieldType = fieldDefn->GetType();
    const bool isNumeric =
      fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal;

    if( isNumeric )
    {
      this->AddChoice( "feat." + key, item );
    }
    if( isNumeric || fieldType == OFTString )
    {
      this->AddChoice( "cfield." + key, item );
    }
  }
}
/**
 * Run the application: resolve the selected features, load the statistics,
 * extract the samples, train the model written to "io.out", then classify
 * the classification samples with that model and keep the predictions in
 * m_PredictedList.
 */
template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoExecute()
{
  m_FeaturesInfo.SetFieldNames( this->GetChoiceNames( "feat" ), this->GetSelectedItems( "feat" ) );

  // Check input parameters
  const bool noFeatureSelected = m_FeaturesInfo.m_SelectedIdx.empty();
  if( noFeatureSelected )
  {
    otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
  }

  const ShiftScaleParameters statistics = GetStatistics( m_FeaturesInfo.m_NbFeatures );
  ExtractAllSamples( statistics );

  // The model file is written by Train() and read back by Classify().
  const std::string modelFile = this->GetParameterString( "io.out" );

  this->Train( m_TrainingSamplesWithLabel.listSample,
               m_TrainingSamplesWithLabel.labeledListSample,
               modelFile );

  m_PredictedList = this->Classify( m_ClassificationSamplesWithLabel.listSample, modelFile );
}
/**
 * Fill m_TrainingSamplesWithLabel and m_ClassificationSamplesWithLabel.
 *
 * \param measurement statistics forwarded to both extraction methods.
 */
template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::ExtractAllSamples(const ShiftScaleParameters &measurement)
{
  // Training samples come first: the classification extraction implemented
  // in this file may fall back on m_TrainingSamplesWithLabel.
  this->m_TrainingSamplesWithLabel = this->ExtractTrainingSamplesWithLabel( measurement );

  this->m_ClassificationSamplesWithLabel = this->ExtractClassificationSamplesWithLabel( measurement );
}
/**
 * Extract the training samples from the "io.vd" files, using the layer
 * index stored in the "layer" parameter.
 *
 * \param measurement statistics forwarded to ExtractSamplesWithLabel().
 * \return the samples and labels read from the training vector data.
 */
template <class TInputValue, class TOutputValue>
typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase<TInputValue, TOutputValue>
::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement)
{
  SamplesWithLabel trainingSamples = this->ExtractSamplesWithLabel( "io.vd", "layer", measurement );
  return trainingSamples;
}
/**
 * Select the samples used for the classification (performance estimation)
 * step.
 *
 * - Classifier category other than Supervised: the training samples are
 *   returned.
 * - Supervised: samples are read from "valid.vd" / "valid.layer"; when that
 *   set holds no label, a warning is logged and the training samples are
 *   used instead.
 *
 * \param measurement statistics forwarded to ExtractSamplesWithLabel().
 * \return the samples and labels to classify.
 */
template <class TInputValue, class TOutputValue>
typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase<TInputValue, TOutputValue>
::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement)
{
  if( this->GetClassifierCategory() != Superclass::Supervised )
  {
    return m_TrainingSamplesWithLabel;
  }

  SamplesWithLabel tmpSamplesWithLabel;
  SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement );

  //Test the input validation set size
  if( validationSamplesWithLabel.labeledListSample->Size() != 0 )
  {
    tmpSamplesWithLabel.listSample = validationSamplesWithLabel.listSample;
    tmpSamplesWithLabel.labeledListSample = validationSamplesWithLabel.labeledListSample;
  }
  else
  {
    otbAppLogWARNING(
      "The validation set is empty. The performance estimation is done using the input training set in this case." );
    tmpSamplesWithLabel.listSample = m_TrainingSamplesWithLabel.listSample;
    tmpSamplesWithLabel.labeledListSample = m_TrainingSamplesWithLabel.labeledListSample;
  }

  return tmpSamplesWithLabel;
}
/**
 * Build the statistics used to normalise the samples.
 *
 * When "io.stats" has a value and is enabled, the "mean" and "stddev"
 * vectors are read from that XML file. Otherwise a mean of 0 and a standard
 * deviation of 1 are used for each of the nbFeatures features.
 *
 * \param nbFeatures size of the default vectors (only used without a
 *                   statistics file).
 * \return the mean / standard deviation measurement vectors.
 */
template <class TInputValue, class TOutputValue>
typename TrainVectorBase<TInputValue, TOutputValue>::ShiftScaleParameters
TrainVectorBase<TInputValue, TOutputValue>
::GetStatistics(unsigned int nbFeatures)
{
  ShiftScaleParameters stats = ShiftScaleParameters();

  const bool statsFileProvided =
    this->HasValue( "io.stats" ) && this->IsParameterEnabled( "io.stats" );

  if( !statsFileProvided )
  {
    // No statistics file: fall back on mean 0 / stddev 1.
    stats.meanMeasurementVector.SetSize( nbFeatures );
    stats.meanMeasurementVector.Fill( 0. );
    stats.stddevMeasurementVector.SetSize( nbFeatures );
    stats.stddevMeasurementVector.Fill( 1. );
    return stats;
  }

  typename StatisticsReader::Pointer reader = StatisticsReader::New();
  const std::string statsFileName = this->GetParameterString( "io.stats" );
  reader->SetFileName( statsFileName );
  stats.meanMeasurementVector = reader->GetStatisticVectorByName( "mean" );
  stats.stddevMeasurementVector = reader->GetStatisticVectorByName( "stddev" );
  return stats;
}
// Specialization for integer outputs (i.e. classification): the field is
// read directly as an int, which avoids a cast from double to integer.
template <>
inline int
TrainVectorBase<float, int>
::GetFeatureField(const ogr::Feature & feature, int fieldIndex)
{
  const int label = feature[fieldIndex].GetValue<int>();
  return label;
}
/**
 * Generic version: read the field at fieldIndex as a double and return it
 * converted to TOutputValue.
 */
template <class TInputValue, class TOutputValue>
inline TOutputValue
TrainVectorBase<TInputValue, TOutputValue>
::GetFeatureField(const ogr::Feature & feature, int fieldIndex)
{
  const double fieldValue = feature[fieldIndex].GetValue<double>();
  return fieldValue;
}
/**
 * Read samples and labels from every vector file listed in a parameter.
 *
 * For each file, the layer whose index is stored in parameterLayer is
 * opened. A layer that yields no feature is skipped with a warning. The
 * selected class field (when its name is not empty) and every selected
 * feature field must exist in the file, otherwise a fatal error is logged.
 * Each feature contributes one measurement vector (feature fields read as
 * double) and one target: the class field value when the field exists and
 * has been set, 0 otherwise.
 *
 * The collected samples go through a ShiftScaleFilterType configured with
 * measurement.meanMeasurementVector as shifts and
 * measurement.stddevMeasurementVector as scales.
 *
 * \param parameterName  key of the vector-data-list parameter to read.
 * \param parameterLayer key of the integer parameter holding the layer index.
 * \param measurement    shift / scale vectors applied to the samples.
 * \return the filtered samples with their labels; a default-constructed
 *         SamplesWithLabel when parameterName has no value or is disabled.
 */
template <class TInputValue, class TOutputValue>
typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase<TInputValue, TOutputValue>
::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer,
                          const ShiftScaleParameters &measurement)
{
  SamplesWithLabel samplesWithLabel;
  if( this->HasValue( parameterName ) && this->IsParameterEnabled( parameterName ) )
  {
    typename ListSampleType::Pointer input = ListSampleType::New();
    typename TargetListSampleType::Pointer target = TargetListSampleType::New();
    input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures );

    std::vector<std::string> fileList = this->GetParameterStringList( parameterName );
    for( unsigned int k = 0; k < fileList.size(); k++ )
    {
      otbAppLogINFO( "Reading vector file " << k + 1 << "/" << fileList.size() );

      // Read the layer index once per file; it is used to open the layer
      // and to report an empty layer.
      const int layerIndex = this->GetParameterInt( parameterLayer );

      ogr::DataSource::Pointer source = ogr::DataSource::New( fileList[k], ogr::DataSource::Modes::Read );
      ogr::Layer layer = source->GetLayer( static_cast<size_t>( layerIndex ) );
      ogr::Feature feature = layer.ogr().GetNextFeature();
      bool goesOn = feature.addr() != 0;
      if( !goesOn )
      {
        otbAppLogWARNING( "The layer " << layerIndex << " of " << fileList[k]
                          << " is empty, input is skipped." );
        continue;
      }

      // Check all needed fields are present :
      // - check class field if we use supervised classification or if class field name is not empty
      int cFieldIndex = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedCFieldName.c_str() );
      if( cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty() )
      {
        otbAppLogFATAL( "The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName
                        << ") has not been found in the vector file "
                        << fileList[k] );
      }

      // - check feature fields
      std::vector<int> featureFieldIndex( m_FeaturesInfo.m_NbFeatures, -1 );
      for( unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++ )
      {
        featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedNames[i].c_str() );
        if( featureFieldIndex[i] < 0 )
        {
          otbAppLogFATAL( "The field name for feature " << m_FeaturesInfo.m_SelectedNames[i]
                          << " has not been found in the vector file "
                          << fileList[k] );
        }
      }

      while( goesOn )
      {
        // Retrieve all the features for each field in the ogr layer.
        MeasurementType mv;
        mv.SetSize( m_FeaturesInfo.m_NbFeatures );
        for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx )
        {
          mv[idx] = feature[featureFieldIndex[idx]].GetValue<double>();
        }
        input->PushBack( mv );

        // A missing or unset class field yields the label 0.
        if( cFieldIndex >= 0 && ogr::Field( feature, cFieldIndex ).HasBeenSet() )
        {
          target->PushBack( GetFeatureField( feature, cFieldIndex ) );
        }
        else
        {
          target->PushBack( 0. );
        }

        feature = layer.ogr().GetNextFeature();
        goesOn = feature.addr() != 0;
      }
    }

    typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New();
    shiftScaleFilter->SetInput( input );
    shiftScaleFilter->SetShifts( measurement.meanMeasurementVector );
    shiftScaleFilter->SetScales( measurement.stddevMeasurementVector );
    shiftScaleFilter->Update();

    samplesWithLabel.listSample = shiftScaleFilter->GetOutput();
    samplesWithLabel.labeledListSample = target;
    samplesWithLabel.listSample->DisconnectPipeline();
  }
  return samplesWithLabel;
}
}
}
#endif