diff --git a/Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.1.txt b/Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.1.txt new file mode 100644 index 0000000000000000000000000000000000000000..20e589ba2a9407e406e8511baf12a44bc63c2a43 --- /dev/null +++ b/Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.1.txt @@ -0,0 +1 @@ +io.mse: 0.003289417131 diff --git a/Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.txt b/Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.txt new file mode 100644 index 0000000000000000000000000000000000000000..aec10cefd80236745402a3c392cbc9590a666de6 --- /dev/null +++ b/Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.txt @@ -0,0 +1 @@ +io.mse: 0.001359587419 diff --git a/Modules/Applications/AppClassification/app/CMakeLists.txt b/Modules/Applications/AppClassification/app/CMakeLists.txt index 50ac00ea1619c068da99774d0010ee532a5b7fd1..6fc42a4ee2c73a284a379f9763b170fac9b50093 100644 --- a/Modules/Applications/AppClassification/app/CMakeLists.txt +++ b/Modules/Applications/AppClassification/app/CMakeLists.txt @@ -65,6 +65,11 @@ otb_create_application( SOURCES otbTrainRegression.cxx LINK_LIBRARIES ${${otb-module}_LIBRARIES}) +otb_create_application( + NAME TrainVectorRegression + SOURCES otbTrainVectorRegression.cxx + LINK_LIBRARIES ${${otb-module}_LIBRARIES}) + otb_create_application( NAME PredictRegression SOURCES otbPredictRegression.cxx diff --git a/Modules/Applications/AppClassification/app/otbTrainRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainRegression.cxx index c4ac059fe4bd9fcd3ab758f94c4122999b42caf2..a93c8b5439d65e670fa92e52ee889933d45f3d61 100644 --- a/Modules/Applications/AppClassification/app/otbTrainRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainRegression.cxx @@ -271,8 +271,7 @@ void ParseCSVPredictors(std::string path, ListSampleType* outputList) elem.Fill(0.0); for (unsigned int i=0 ; i<nbCols ; ++i) { - iss.str(words[i]); - iss >> elem[i]; + elem[i] = std::stod(words[i]); } outputList->PushBack(elem); } diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index 5d10533448e7a08e6141963f98ccc73d15eea063..fbd04d4a4b5c0fa213d5266d3588b1f8de7202cb 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -29,11 +29,11 @@ namespace otb namespace Wrapper { -class TrainVectorClassifier : public TrainVectorBase +class TrainVectorClassifier : public TrainVectorBase<float, int> { public: typedef TrainVectorClassifier Self; - typedef TrainVectorBase Superclass; + typedef TrainVectorBase<float, int> Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; itkNewMacro( Self ) @@ -66,13 +66,20 @@ protected: "Learning (2.3.1 and later), and Shark ML The output of this application " "is a text model file, whose format corresponds to the ML model type " "chosen. There is no image nor vector data output."); - SetDocLimitations(""); + SetDocLimitations("None"); SetDocAuthors( "OTB Team" ); SetDocSeeAlso( " " ); SetOfficialDocLink(); Superclass::DoInit(); + + // Add a new parameter to compute confusion matrix / contingency table + this->AddParameter(ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix or contingency table"); + this->SetParameterDescription("io.confmatout", + "Output file containing the confusion matrix or contingency table (.csv format)." + "The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output."); + this->MandatoryOff("io.confmatout"); } void DoUpdateParameters() override diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx new file mode 100644 index 0000000000000000000000000000000000000000..c98e69ee0c94b0b1882a4bd0ffab112e6020e53d --- /dev/null +++ b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "otbTrainVectorBase.h" + +namespace otb +{ +namespace Wrapper +{ + +class TrainVectorRegression : public TrainVectorBase<float, float> +{ +public: + typedef TrainVectorRegression Self; + typedef TrainVectorBase<float, float> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + itkNewMacro(Self) itkTypeMacro(Self, Superclass) + + typedef Superclass::SampleType SampleType; + typedef Superclass::ListSampleType ListSampleType; + typedef Superclass::TargetListSampleType TargetListSampleType; + +protected: + TrainVectorRegression() + { + this->m_RegressionFlag = true; + } + + void DoInit() override + { + SetName("TrainVectorRegression"); + SetDescription( + "Train a regression algorithm based on geometries with " + "list of features to consider and a predictor."); + + SetDocLongDescription( + "This application trains a regression algorithm based on " + "a predictor geometries and a list of features to consider for " + "regression.\nThis application is based on LibSVM, OpenCV Machine " + "Learning (2.3.1 and later), and Shark ML The output of this application " + "is a text model file, whose format corresponds to the ML model type " + "chosen. There is no image or vector data output."); + + SetDocLimitations("None"); + SetDocAuthors("OTB Team"); + SetDocSeeAlso("TrainVectorClassifier"); + + SetOfficialDocLink(); + + Superclass::DoInit(); + + AddParameter(ParameterType_Float, "io.mse", "Mean Square Error"); + SetParameterDescription("io.mse", "Mean square error computed with the validation predictors"); + SetParameterRole("io.mse", Role_Output); + this->MandatoryOff("io.mse"); + } + + void DoUpdateParameters() override + { + Superclass::DoUpdateParameters(); + } + + double ComputeMSE(const TargetListSampleType& list1, const TargetListSampleType& list2) + { + assert(list1.Size() == list2.Size()); + double mse = 0.; + for (TargetListSampleType::InstanceIdentifier i = 0; i < list1.Size(); ++i) + { + auto elem1 = list1.GetMeasurementVector(i); + auto elem2 = list2.GetMeasurementVector(i); + + mse += (elem1[0] - elem2[0]) * (elem1[0] - elem2[0]); + } + mse /= static_cast<double>(list1.Size()); + return mse; + } + + + void DoExecute() override + { + m_FeaturesInfo.SetClassFieldNames(GetChoiceNames("cfield"), GetSelectedItems("cfield")); + + if (m_FeaturesInfo.m_SelectedCFieldIdx.empty() && GetClassifierCategory() == Supervised) + { + otbAppLogFATAL(<< "No field has been selected for data labelling!"); + } + + Superclass::DoExecute(); + + otbAppLogINFO("Computing training performances"); + + auto mse = ComputeMSE(*m_ClassificationSamplesWithLabel.labeledListSample, *m_PredictedList); + + otbAppLogINFO("Mean Square Error = " << mse); + this->SetParameterFloat("io.mse", mse); + } + +private: +}; +} +} + +OTB_APPLICATION_EXPORT(otb::Wrapper::TrainVectorRegression) diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index 37869068157660e37b5063fd69303be97023df2f..bc5c716aef98324bcd14882b451fc2636dbfe9cc 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -49,21 +49,22 @@ bool IsNotAlphaNum(char c) return !std::isalnum( c ); } -class TrainVectorBase : public LearningApplicationBase<float, int> +template <class TInputValue, class TOutputValue> +class TrainVectorBase : public LearningApplicationBase<TInputValue, TOutputValue> { public: /** Standard class typedefs. */ typedef TrainVectorBase Self; - typedef LearningApplicationBase<float, int> Superclass; + typedef LearningApplicationBase<TInputValue, TOutputValue> Superclass; typedef itk::SmartPointer <Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkTypeMacro(Self, Superclass); - typedef Superclass::SampleType SampleType; - typedef Superclass::ListSampleType ListSampleType; - typedef Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::SampleType SampleType; + typedef typename Superclass::ListSampleType ListSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef double ValueType; typedef itk::VariableLengthVector <ValueType> MeasurementType; @@ -86,8 +87,8 @@ protected: class SamplesWithLabel { public: - ListSampleType::Pointer listSample; - TargetListSampleType::Pointer labeledListSample; + typename ListSampleType::Pointer listSample; + typename TargetListSampleType::Pointer labeledListSample; SamplesWithLabel() { listSample = ListSampleType::New(); @@ -178,13 +179,18 @@ protected: SamplesWithLabel m_TrainingSamplesWithLabel; SamplesWithLabel m_ClassificationSamplesWithLabel; - TargetListSampleType::Pointer m_PredictedList; + typename TargetListSampleType::Pointer m_PredictedList; FeaturesInfo m_FeaturesInfo; void DoInit() override; void DoUpdateParameters() override; void DoExecute() override; +private: + /** + * Get the field of the input feature corresponding to the input field + */ + inline TOutputValue GetFeatureField(const ogr::Feature& feature, int field); }; } diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx index fb5e582dae4dc6073b77a5a2f1f2659831e1b3e9..c1e4f88f1cf9f428a038da63d99161e26daa9f66 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx @@ -27,100 +27,98 @@ namespace otb namespace Wrapper { -void TrainVectorBase::DoInit() +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::DoInit() { // Common Parameters for all Learning Application - AddParameter( ParameterType_Group, "io", "Input and output data" ); - SetParameterDescription( "io", + this->AddParameter( ParameterType_Group, "io", "Input and output data" ); + this->SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); - AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" ); - SetParameterDescription( "io.vd", + this->AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" ); + this->SetParameterDescription( "io.vd", "Input geometries used for training (note: all geometries from the layer will be used)" ); - AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" ); - MandatoryOff( "io.stats" ); - SetParameterDescription( "io.stats", + this->AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" ); + this->MandatoryOff( "io.stats" ); + this->SetParameterDescription( "io.stats", "XML file containing mean and variance of each feature." ); - AddParameter( ParameterType_OutputFilename, "io.out", "Output model" ); - SetParameterDescription( "io.out", + this->AddParameter( ParameterType_OutputFilename, "io.out", "Output model" ); + this->SetParameterDescription( "io.out", "Output file containing the model estimated (.txt format)." ); - AddParameter( ParameterType_Int, "layer", "Layer Index" ); - SetParameterDescription( "layer", + this->AddParameter( ParameterType_Int, "layer", "Layer Index" ); + this->SetParameterDescription( "layer", "Index of the layer to use in the input vector file." ); - MandatoryOff( "layer" ); - SetDefaultParameterInt( "layer", 0 ); + this->MandatoryOff( "layer" ); + this->SetDefaultParameterInt( "layer", 0 ); - AddParameter(ParameterType_ListView, "feat", "Field names for training features"); - SetParameterDescription("feat", + this->AddParameter(ParameterType_ListView, "feat", "Field names for training features"); + this->SetParameterDescription("feat", "List of field names in the input vector data to be used as features for training."); // Add validation data used to compute confusion matrix or contingency table - AddParameter( ParameterType_Group, "valid", "Validation data" ); - SetParameterDescription( "valid", + this->AddParameter( ParameterType_Group, "valid", "Validation data" ); + this->SetParameterDescription( "valid", "This group of parameters defines validation data." ); - AddParameter( ParameterType_InputVectorDataList, "valid.vd", + this->AddParameter( ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data" ); - SetParameterDescription( "valid.vd", "Geometries used for validation " + this->SetParameterDescription( "valid.vd", "Geometries used for validation " "(must contain the same fields used for training, all geometries from the layer will be used)" ); - MandatoryOff( "valid.vd" ); + this->MandatoryOff( "valid.vd" ); - AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); - SetParameterDescription( "valid.layer", + this->AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); + this->SetParameterDescription( "valid.layer", "Index of the layer to use in the validation vector file." ); - MandatoryOff( "valid.layer" ); - SetDefaultParameterInt( "valid.layer", 0 ); + this->MandatoryOff( "valid.layer" ); + this->SetDefaultParameterInt( "valid.layer", 0 ); // Add class field if we used validation - AddParameter( ParameterType_ListView, "cfield", + this->AddParameter( ParameterType_ListView, "cfield", "Field containing the class integer label for supervision" ); - SetParameterDescription( "cfield", + this->SetParameterDescription( "cfield", "Field containing the class id for supervision. " "The values in this field shall be cast into integers. " "Only geometries with this field available will be taken into account." ); - SetListViewSingleSelectionMode( "cfield", true ); + this->SetListViewSingleSelectionMode( "cfield", true ); - // Add a new parameter to compute confusion matrix / contingency table - AddParameter( ParameterType_OutputFilename, "io.confmatout", - "Output confusion matrix or contingency table" ); - SetParameterDescription( "io.confmatout", - "Output file containing the confusion matrix or contingency table (.csv format)." - "The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output." ); - MandatoryOff( "io.confmatout" ); - - AddParameter(ParameterType_Bool, "v", "Verbose mode"); - SetParameterDescription("v", "Verbose mode, display the contingency table result."); - SetParameterInt("v", 1); + this->AddParameter(ParameterType_Bool, "v", "Verbose mode"); + this->SetParameterDescription("v", "Verbose mode, display the contingency table result."); + this->SetParameterInt("v", 1); // Doc example parameter settings - SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); - SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); - SetDocExampleParameterValue( "io.out", "svmModel.svm" ); - SetDocExampleParameterValue( "feat", "perimeter area width" ); - SetDocExampleParameterValue( "cfield", "predicted" ); + this->SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); + this->SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); + this->SetDocExampleParameterValue( "io.out", "svmModel.svm" ); + this->SetDocExampleParameterValue( "feat", "perimeter area width" ); + this->SetDocExampleParameterValue( "cfield", "predicted" ); // Add parameters for the classifier choice Superclass::DoInit(); - AddRANDParameter(); + this->AddRANDParameter(); } -void TrainVectorBase::DoUpdateParameters() +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::DoUpdateParameters() { // if vector data is present and updated then reload fields - if( HasValue( "io.vd" ) ) + if( this->HasValue( "io.vd" ) ) { - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + std::vector<std::string> vectorFileList = this->GetParameterStringList( "io.vd" ); ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read ); ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) ); ogr::Feature feature = layer.ogr().GetNextFeature(); - ClearChoices( "feat" ); - ClearChoices( "cfield" ); + this->ClearChoices( "feat" ); + this->ClearChoices( "cfield" ); for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ ) { @@ -134,20 +132,23 @@ void TrainVectorBase::DoUpdateParameters() if( fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal ) { std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); - AddChoice( tmpKey, item ); + this->AddChoice( tmpKey, item ); } - if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 ) + if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal ) { std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); - AddChoice( tmpKey, item ); + this->AddChoice( tmpKey, item ); } } } } -void TrainVectorBase::DoExecute() +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::DoExecute() { - m_FeaturesInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); + m_FeaturesInfo.SetFieldNames( this->GetChoiceNames( "feat" ), this->GetSelectedItems( "feat" )); // Check input parameters if( m_FeaturesInfo.m_SelectedIdx.empty() ) @@ -158,29 +159,35 @@ void TrainVectorBase::DoExecute() ShiftScaleParameters measurement = GetStatistics( m_FeaturesInfo.m_NbFeatures ); ExtractAllSamples( measurement ); - this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) ); + this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString( "io.out" ) ); m_PredictedList = - this->Classify( m_ClassificationSamplesWithLabel.listSample, GetParameterString( "io.out" ) ); + this->Classify( m_ClassificationSamplesWithLabel.listSample, this->GetParameterString( "io.out" ) ); } - -void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement) +template <class TInputValue, class TOutputValue> +void +TrainVectorBase<TInputValue, TOutputValue> +::ExtractAllSamples(const ShiftScaleParameters &measurement) { m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement); m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement); } -TrainVectorBase::SamplesWithLabel -TrainVectorBase::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel +TrainVectorBase<TInputValue, TOutputValue> +::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) { return ExtractSamplesWithLabel( "io.vd", "layer", measurement); } -TrainVectorBase::SamplesWithLabel -TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel +TrainVectorBase<TInputValue, TOutputValue> +::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) { - if(GetClassifierCategory() == Supervised) + if(this->GetClassifierCategory() == Superclass::Supervised) { SamplesWithLabel tmpSamplesWithLabel; SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement ); @@ -206,15 +213,16 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter } } - -TrainVectorBase::ShiftScaleParameters -TrainVectorBase::GetStatistics(unsigned int nbFeatures) +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::ShiftScaleParameters +TrainVectorBase<TInputValue, TOutputValue> +::GetStatistics(unsigned int nbFeatures) { ShiftScaleParameters measurement = ShiftScaleParameters(); - if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) ) + if( this->HasValue( "io.stats" ) && this->IsParameterEnabled( "io.stats" ) ) { - StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); - std::string XMLfile = GetParameterString( "io.stats" ); + typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); + std::string XMLfile = this->GetParameterString( "io.stats" ); statisticsReader->SetFileName( XMLfile ); measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" ); measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" ); @@ -229,16 +237,34 @@ TrainVectorBase::GetStatistics(unsigned int nbFeatures) return measurement; } +// Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer +template <> +inline int +TrainVectorBase<float, int> +::GetFeatureField(const ogr::Feature & feature, int fieldIndex) +{ + return(feature[fieldIndex].GetValue<int>()); +} + +template <class TInputValue, class TOutputValue> +inline TOutputValue +TrainVectorBase<TInputValue, TOutputValue> +::GetFeatureField(const ogr::Feature & feature, int fieldIndex) +{ + return(feature[fieldIndex].GetValue<double>()); +} -TrainVectorBase::SamplesWithLabel -TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, +template <class TInputValue, class TOutputValue> +typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel +TrainVectorBase<TInputValue, TOutputValue> +::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement) { SamplesWithLabel samplesWithLabel; - if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) ) + if( this->HasValue( parameterName ) && this->IsParameterEnabled( parameterName ) ) { - ListSampleType::Pointer input = ListSampleType::New(); - TargetListSampleType::Pointer target = TargetListSampleType::New(); + typename ListSampleType::Pointer input = ListSampleType::New(); + typename TargetListSampleType::Pointer target = TargetListSampleType::New(); input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures ); std::vector<std::string> fileList = this->GetParameterStringList( parameterName ); @@ -251,7 +277,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string bool goesOn = feature.addr() != 0; if( !goesOn ) { - otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << fileList[k] + otbAppLogWARNING( "The layer " << this->GetParameterInt( parameterLayer ) << " of " << fileList[k] << " is empty, input is skipped." ); continue; } @@ -284,14 +310,14 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string MeasurementType mv; mv.SetSize( m_FeaturesInfo.m_NbFeatures ); for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx ) - mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] ); + mv[idx] = feature[featureFieldIndex[idx]].GetValue<double>(); input->PushBack( mv ); if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet()) - target->PushBack( feature.ogr().GetFieldAsInteger( cFieldIndex ) ); + target->PushBack(GetFeatureField(feature,cFieldIndex)); else - target->PushBack( 0 ); + target->PushBack( 0. ); feature = layer.ogr().GetNextFeature(); goesOn = feature.addr() != 0; @@ -300,7 +326,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string - ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); + typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); shiftScaleFilter->SetInput( input ); shiftScaleFilter->SetShifts( measurement.meanMeasurementVector ); shiftScaleFilter->SetScales( measurement.stddevMeasurementVector ); diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 2379deb501c3c6233fb1a1a6b34e8fe99a78c6e0..15286ce99c373666820979ebc1f73d7c848169f6 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -837,6 +837,22 @@ if(OTB_USE_OPENCV) ${TEMP}/apTvClTrainVectorClassifierModel.rf) endif() +#----------- TrainVectorRegression TESTS ---------------- +if(OTB_USE_OPENCV) + otb_test_application(NAME apTvClTrainVectorRegression + APP TrainVectorRegression + OPTIONS -io.vd ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite + -feat value_0 value_1 value_2 value_3 + -cfield class + -classifier rf + -io.out ${TEMP}/apTvClTrainVectorRegressionModel.rf + -io.mse ${TEMP}/apTvClTrainVectorRegressionModel.txt + TESTENVOPTIONS ${TEMP}/apTvClTrainVectorRegressionModel.txt + VALID ${ascii_comparison} + ${OTBAPP_BASELINE_FILES}/apTvClTrainVectorRegressionModel.txt + ${TEMP}/apTvClTrainVectorRegressionModel.txt) +endif() + #----------- TrainVectorClassifier unsupervised TESTS ---------------- if(OTB_USE_SHARK) otb_test_application(NAME apTvClTrainVectorUnsupervised