diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index c71885ded4440670f0dc511e104629f14581c7f5..3a7accaf5f0e1890a522c1a0c97de5a2d33e714c 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -29,11 +29,11 @@ namespace otb namespace Wrapper { -class TrainVectorClassifier : public TrainVectorBase +class TrainVectorClassifier : public TrainVectorBase<float, int> { public: typedef TrainVectorClassifier Self; - typedef TrainVectorBase Superclass; + typedef TrainVectorBase<float, int> Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; itkNewMacro( Self ) diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx index c94cd63262905f2d5e8cb9da59ceb234a6292e5f..91ec818be1321ec783381b3d0dad31e461ca099c 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx @@ -25,11 +25,11 @@ namespace otb namespace Wrapper { -class TrainVectorRegression : public TrainVectorBase +class TrainVectorRegression : public TrainVectorBase<float, int> { public: typedef TrainVectorRegression Self; - typedef TrainVectorBase Superclass; + typedef TrainVectorBase<float, int> Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; @@ -83,6 +83,13 @@ protected: } Superclass::DoExecute(); + + /* + std::cout << m_PredictedList << std::endl; + std::cout << m_ClassificationSamplesWithLabel.labeledListSample << std::endl; + */ + + } private: diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index 37869068157660e37b5063fd69303be97023df2f..477b798d58b550cba0d80347042c57a68cd01a2d 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -49,21 +49,22 @@ bool IsNotAlphaNum(char c) return !std::isalnum( c ); } -class TrainVectorBase : public LearningApplicationBase<float, int> +template <class TInputValue, class TOutputValue> +class TrainVectorBase : public LearningApplicationBase<TInputValue, TOutputValue> { public: /** Standard class typedefs. */ typedef TrainVectorBase Self; - typedef LearningApplicationBase<float, int> Superclass; + typedef LearningApplicationBase<TInputValue, TOutputValue> Superclass; typedef itk::SmartPointer <Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkTypeMacro(Self, Superclass); - typedef Superclass::SampleType SampleType; - typedef Superclass::ListSampleType ListSampleType; - typedef Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::SampleType SampleType; + typedef typename Superclass::ListSampleType ListSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef double ValueType; typedef itk::VariableLengthVector <ValueType> MeasurementType; @@ -86,8 +87,8 @@ protected: class SamplesWithLabel { public: - ListSampleType::Pointer listSample; - TargetListSampleType::Pointer labeledListSample; + typename ListSampleType::Pointer listSample; + typename TargetListSampleType::Pointer labeledListSample; SamplesWithLabel() { listSample = ListSampleType::New(); @@ -178,7 +179,7 @@ protected: SamplesWithLabel m_TrainingSamplesWithLabel; SamplesWithLabel m_ClassificationSamplesWithLabel; - TargetListSampleType::Pointer m_PredictedList; + typename TargetListSampleType::Pointer m_PredictedList; FeaturesInfo m_FeaturesInfo; void DoInit() override; diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx index fb5e582dae4dc6073b77a5a2f1f2659831e1b3e9..447608d0c58057cf73592c8023462ae78e4c6cc0 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx @@ -27,100 +27,106 @@ namespace otb namespace Wrapper { -void TrainVectorBase::DoInit() +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::DoInit() { // Common Parameters for all Learning Application - AddParameter( ParameterType_Group, "io", "Input and output data" ); - SetParameterDescription( "io", + this->AddParameter( ParameterType_Group, "io", "Input and output data" ); + this->SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); - AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" ); - SetParameterDescription( "io.vd", + this->AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" ); + this->SetParameterDescription( "io.vd", "Input geometries used for training (note: all geometries from the layer will be used)" ); - AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" ); - MandatoryOff( "io.stats" ); - SetParameterDescription( "io.stats", + this->AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" ); + this->MandatoryOff( "io.stats" ); + this->SetParameterDescription( "io.stats", "XML file containing mean and variance of each feature." ); - AddParameter( ParameterType_OutputFilename, "io.out", "Output model" ); - SetParameterDescription( "io.out", + this->AddParameter( ParameterType_OutputFilename, "io.out", "Output model" ); + this->SetParameterDescription( "io.out", "Output file containing the model estimated (.txt format)." ); - AddParameter( ParameterType_Int, "layer", "Layer Index" ); - SetParameterDescription( "layer", + this->AddParameter( ParameterType_Int, "layer", "Layer Index" ); + this->SetParameterDescription( "layer", "Index of the layer to use in the input vector file." ); - MandatoryOff( "layer" ); - SetDefaultParameterInt( "layer", 0 ); + this->MandatoryOff( "layer" ); + this->SetDefaultParameterInt( "layer", 0 ); - AddParameter(ParameterType_ListView, "feat", "Field names for training features"); - SetParameterDescription("feat", + this->AddParameter(ParameterType_ListView, "feat", "Field names for training features"); + this->SetParameterDescription("feat", "List of field names in the input vector data to be used as features for training."); // Add validation data used to compute confusion matrix or contingency table - AddParameter( ParameterType_Group, "valid", "Validation data" ); - SetParameterDescription( "valid", + this->AddParameter( ParameterType_Group, "valid", "Validation data" ); + this->SetParameterDescription( "valid", "This group of parameters defines validation data." ); - AddParameter( ParameterType_InputVectorDataList, "valid.vd", + this->AddParameter( ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data" ); - SetParameterDescription( "valid.vd", "Geometries used for validation " + this->SetParameterDescription( "valid.vd", "Geometries used for validation " "(must contain the same fields used for training, all geometries from the layer will be used)" ); - MandatoryOff( "valid.vd" ); + this->MandatoryOff( "valid.vd" ); - AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); - SetParameterDescription( "valid.layer", + this->AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); + this->SetParameterDescription( "valid.layer", "Index of the layer to use in the validation vector file." ); - MandatoryOff( "valid.layer" ); - SetDefaultParameterInt( "valid.layer", 0 ); + this->MandatoryOff( "valid.layer" ); + this->SetDefaultParameterInt( "valid.layer", 0 ); // Add class field if we used validation - AddParameter( ParameterType_ListView, "cfield", + this->AddParameter( ParameterType_ListView, "cfield", "Field containing the class integer label for supervision" ); - SetParameterDescription( "cfield", + this->SetParameterDescription( "cfield", "Field containing the class id for supervision. " "The values in this field shall be cast into integers. " "Only geometries with this field available will be taken into account." ); - SetListViewSingleSelectionMode( "cfield", true ); + this->SetListViewSingleSelectionMode( "cfield", true ); // Add a new parameter to compute confusion matrix / contingency table - AddParameter( ParameterType_OutputFilename, "io.confmatout", + this->AddParameter( ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix or contingency table" ); - SetParameterDescription( "io.confmatout", + this->SetParameterDescription( "io.confmatout", "Output file containing the confusion matrix or contingency table (.csv format)." "The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output." ); - MandatoryOff( "io.confmatout" ); + this->MandatoryOff( "io.confmatout" ); - AddParameter(ParameterType_Bool, "v", "Verbose mode"); - SetParameterDescription("v", "Verbose mode, display the contingency table result."); - SetParameterInt("v", 1); + this->AddParameter(ParameterType_Bool, "v", "Verbose mode"); + this->SetParameterDescription("v", "Verbose mode, display the contingency table result."); + this->SetParameterInt("v", 1); // Doc example parameter settings - SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); - SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); - SetDocExampleParameterValue( "io.out", "svmModel.svm" ); - SetDocExampleParameterValue( "feat", "perimeter area width" ); - SetDocExampleParameterValue( "cfield", "predicted" ); + this->SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); + this->SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); + this->SetDocExampleParameterValue( "io.out", "svmModel.svm" ); + this->SetDocExampleParameterValue( "feat", "perimeter area width" ); + this->SetDocExampleParameterValue( "cfield", "predicted" ); // Add parameters for the classifier choice Superclass::DoInit(); - AddRANDParameter(); + this->AddRANDParameter(); } -void TrainVectorBase::DoUpdateParameters() +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::DoUpdateParameters() { // if vector data is present and updated then reload fields - if( HasValue( "io.vd" ) ) + if( this->HasValue( "io.vd" ) ) { - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + std::vector<std::string> vectorFileList = this->GetParameterStringList( "io.vd" ); ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read ); ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) ); ogr::Feature feature = layer.ogr().GetNextFeature(); - ClearChoices( "feat" ); - ClearChoices( "cfield" ); + this->ClearChoices( "feat" ); + this->ClearChoices( "cfield" ); for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ ) { @@ -134,20 +140,23 @@ void TrainVectorBase::DoUpdateParameters() if( fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal ) { std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); - AddChoice( tmpKey, item ); + this->AddChoice( tmpKey, item ); } if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 ) { std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); - AddChoice( tmpKey, item ); + this->AddChoice( tmpKey, item ); } } } } -void TrainVectorBase::DoExecute() +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::DoExecute() { - m_FeaturesInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); + m_FeaturesInfo.SetFieldNames( this->GetChoiceNames( "feat" ), this->GetSelectedItems( "feat" )); // Check input parameters if( m_FeaturesInfo.m_SelectedIdx.empty() ) @@ -158,29 +167,35 @@ void TrainVectorBase::DoExecute() ShiftScaleParameters measurement = GetStatistics( m_FeaturesInfo.m_NbFeatures ); ExtractAllSamples( measurement ); - this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) ); + this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString( "io.out" ) ); m_PredictedList = - this->Classify( m_ClassificationSamplesWithLabel.listSample, GetParameterString( "io.out" ) ); + this->Classify( m_ClassificationSamplesWithLabel.listSample, this->GetParameterString( "io.out" ) ); } - -void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement) +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::ExtractAllSamples(const ShiftScaleParameters &measurement) { m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement); m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement); } -TrainVectorBase::SamplesWithLabel -TrainVectorBase::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel +TrainVectorBase<TInputValue, TOutputValue> +::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) { return ExtractSamplesWithLabel( "io.vd", "layer", measurement); } -TrainVectorBase::SamplesWithLabel -TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel +TrainVectorBase<TInputValue, TOutputValue> +::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) { - if(GetClassifierCategory() == Supervised) + if(this->GetClassifierCategory() == Superclass::Supervised) { SamplesWithLabel tmpSamplesWithLabel; SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement ); @@ -206,15 +221,16 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter } } - -TrainVectorBase::ShiftScaleParameters -TrainVectorBase::GetStatistics(unsigned int nbFeatures) +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::ShiftScaleParameters +TrainVectorBase<TInputValue, TOutputValue> +::GetStatistics(unsigned int nbFeatures) { ShiftScaleParameters measurement = ShiftScaleParameters(); - if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) ) + if( this->HasValue( "io.stats" ) && this->IsParameterEnabled( "io.stats" ) ) { - StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); - std::string XMLfile = GetParameterString( "io.stats" ); + typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); + std::string XMLfile = this->GetParameterString( "io.stats" ); statisticsReader->SetFileName( XMLfile ); measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" ); measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" ); @@ -229,16 +245,17 @@ TrainVectorBase::GetStatistics(unsigned int nbFeatures) return measurement; } - -TrainVectorBase::SamplesWithLabel -TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel +TrainVectorBase<TInputValue, TOutputValue> +::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement) { SamplesWithLabel samplesWithLabel; - if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) ) + if( this->HasValue( parameterName ) && this->IsParameterEnabled( parameterName ) ) { - ListSampleType::Pointer input = ListSampleType::New(); - TargetListSampleType::Pointer target = TargetListSampleType::New(); + typename ListSampleType::Pointer input = ListSampleType::New(); + typename TargetListSampleType::Pointer target = TargetListSampleType::New(); input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures ); std::vector<std::string> fileList = this->GetParameterStringList( parameterName ); @@ -251,7 +268,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string bool goesOn = feature.addr() != 0; if( !goesOn ) { - otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << fileList[k] + otbAppLogWARNING( "The layer " << this->GetParameterInt( parameterLayer ) << " of " << fileList[k] << " is empty, input is skipped." ); continue; } @@ -300,7 +317,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string - ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); + typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); shiftScaleFilter->SetInput( input ); shiftScaleFilter->SetShifts( measurement.meanMeasurementVector ); shiftScaleFilter->SetScales( measurement.stddevMeasurementVector );