diff --git a/Modules/Applications/AppClassification/app/otbTrainRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainRegression.cxx index 9012b49ab223f25ae083c8ae0b512074d43db98c..a93c8b5439d65e670fa92e52ee889933d45f3d61 100644 --- a/Modules/Applications/AppClassification/app/otbTrainRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainRegression.cxx @@ -271,7 +271,7 @@ void ParseCSVPredictors(std::string path, ListSampleType* outputList) elem.Fill(0.0); for (unsigned int i=0 ; i<nbCols ; ++i) { - elem[i] = std::stod(words[i]); + elem[i] = std::stod(words[i]); } outputList->PushBack(elem); } diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index c014b21b03d4d232e59b7d0be3f5a0bdebb878d0..fbd04d4a4b5c0fa213d5266d3588b1f8de7202cb 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -73,15 +73,13 @@ protected: SetOfficialDocLink(); Superclass::DoInit(); - - // Add a new parameter to compute confusion matrix / contingency table - this->AddParameter( ParameterType_OutputFilename, "io.confmatout", - "Output confusion matrix or contingency table" ); - this->SetParameterDescription( "io.confmatout", - "Output file containing the confusion matrix or contingency table (.csv format)." - "The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output." ); - this->MandatoryOff( "io.confmatout" ); + // Add a new parameter to compute confusion matrix / contingency table + this->AddParameter(ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix or contingency table"); + this->SetParameterDescription("io.confmatout", + "Output file containing the confusion matrix or contingency table (.csv format)." + "The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output."); + this->MandatoryOff("io.confmatout"); } void DoUpdateParameters() override diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx index c31ebd3ad3468195c4e792e4858fc4d64753c6fb..b5f599b629077d692724afe0debe2c4a24ba11f6 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx @@ -30,14 +30,13 @@ class TrainVectorRegression : public TrainVectorBase<float, float> public: typedef TrainVectorRegression Self; typedef TrainVectorBase<float, float> Superclass; - typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - - itkNewMacro( Self ) - itkTypeMacro( Self, Superclass ) - typedef Superclass::SampleType SampleType; - typedef Superclass::ListSampleType ListSampleType; + itkNewMacro(Self) itkTypeMacro(Self, Superclass) + + typedef Superclass::SampleType SampleType; + typedef Superclass::ListSampleType ListSampleType; typedef Superclass::TargetListSampleType TargetListSampleType; protected: @@ -45,79 +44,78 @@ protected: { this->m_RegressionFlag = true; } - + void DoInit() override { - SetName( "TrainVectorRegression" ); - SetDescription( "Train a regression algorithm based on geometries with " - "list of features to consider and a predictor." ); - - SetDocLongDescription( "This application trains a regression algorithm based on " - "a predictor geometries and a list of features to consider for " - "regression.\nThis application is based on LibSVM, OpenCV Machine " - "Learning (2.3.1 and later), and Shark ML The output of this application " - "is a text model file, whose format corresponds to the ML model type " - "chosen. There is no image nor vector data output."); + SetName("TrainVectorRegression"); + SetDescription( + "Train a regression algorithm based on geometries with " + "list of features to consider and a predictor."); + + SetDocLongDescription( + "This application trains a regression algorithm based on " + "a predictor geometries and a list of features to consider for " + "regression.\nThis application is based on LibSVM, OpenCV Machine " + "Learning (2.3.1 and later), and Shark ML The output of this application " + "is a text model file, whose format corresponds to the ML model type " + "chosen. There is no image nor vector data output."); SetDocLimitations("None"); - SetDocAuthors( "OTB Team" ); - SetDocSeeAlso( "TrainVectorClassifier" ); + SetDocAuthors("OTB Team"); + SetDocSeeAlso("TrainVectorClassifier"); SetOfficialDocLink(); Superclass::DoInit(); - - AddParameter( ParameterType_Float , "io.mse" , "Mean Square Error" ); - SetParameterDescription( "io.mse" , - "Mean square error computed with the validation predictors" ); - SetParameterRole( "io.mse" , Role_Output ); - this->MandatoryOff( "io.mse" ); + AddParameter(ParameterType_Float, "io.mse", "Mean Square Error"); + SetParameterDescription("io.mse", "Mean square error computed with the validation predictors"); + SetParameterRole("io.mse", Role_Output); + this->MandatoryOff("io.mse"); } void DoUpdateParameters() override { Superclass::DoUpdateParameters(); } - + double ComputeMSE(const TargetListSampleType& list1, const TargetListSampleType& list2) { assert(list1.Size() == list2.Size()); double mse = 0.; - for (TargetListSampleType::InstanceIdentifier i=0; i<list1.Size() ; ++i) + for (TargetListSampleType::InstanceIdentifier i = 0; i < list1.Size(); ++i) { auto elem1 = list1.GetMeasurementVector(i); auto elem2 = list2.GetMeasurementVector(i); - + mse += (elem1[0] - elem2[0]) * (elem1[0] - elem2[0]); } mse /= static_cast<double>(list1.Size()); return mse; } - - + + void DoExecute() override { - m_FeaturesInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); + m_FeaturesInfo.SetClassFieldNames(GetChoiceNames("cfield"), GetSelectedItems("cfield")); - if( m_FeaturesInfo.m_SelectedCFieldIdx.empty() && GetClassifierCategory() == Supervised ) - { - otbAppLogFATAL( << "No field has been selected for data labelling!" ); - } + if (m_FeaturesInfo.m_SelectedCFieldIdx.empty() && GetClassifierCategory() == Supervised) + { + otbAppLogFATAL(<< "No field has been selected for data labelling!"); + } Superclass::DoExecute(); - + otbAppLogINFO("Computing training performances"); - - auto mse = ComputeMSE(*m_ClassificationSamplesWithLabel.labeledListSample, *m_PredictedList ); - otbAppLogINFO("Mean Square Error = "<<mse); - this->SetParameterFloat("io.mse",mse); + auto mse = ComputeMSE(*m_ClassificationSamplesWithLabel.labeledListSample, *m_PredictedList); + + otbAppLogINFO("Mean Square Error = " << mse); + this->SetParameterFloat("io.mse", mse); } - -private: +private: }; } } -OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorRegression ) +OTB_APPLICATION_EXPORT(otb::Wrapper::TrainVectorRegression) diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index 52675e935fd425deb13c91caad4e40a3d7412e3d..bc5c716aef98324bcd14882b451fc2636dbfe9cc 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -62,8 +62,8 @@ public: /** Standard macro */ itkTypeMacro(Self, Superclass); - typedef typename Superclass::SampleType SampleType; - typedef typename Superclass::ListSampleType ListSampleType; + typedef typename Superclass::SampleType SampleType; + typedef typename Superclass::ListSampleType ListSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef double ValueType; @@ -87,7 +87,7 @@ protected: class SamplesWithLabel { public: - typename ListSampleType::Pointer listSample; + typename ListSampleType::Pointer listSample; typename TargetListSampleType::Pointer labeledListSample; SamplesWithLabel() { @@ -190,8 +190,7 @@ private: /** * Get the field of the input feature corresponding to the input field */ - inline TOutputValue GetFeatureField(const ogr::Feature & feature, int field); - + inline TOutputValue GetFeatureField(const ogr::Feature& feature, int field); }; }