diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx index ac784a433a76f195102f9925340be5efb0e1ff0b..ce07691cb99eda66f09e8cfaf026427a78037f22 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx @@ -73,6 +73,22 @@ protected: Superclass::DoUpdateParameters(); } + double ComputeMSE(TargetListSampleType* list1, TargetListSampleType* list2) + { + assert(list1->Size() == list2->Size()); + double mse = 0.; + for (TargetListSampleType::InstanceIdentifier i=0; i<list1->Size() ; ++i) + { + auto elem1 = list1->GetMeasurementVector(i); + auto elem2 = list2->GetMeasurementVector(i); + + mse += (elem1[0] - elem2[0]) * (elem1[0] - elem2[0]); + } + mse /= static_cast<double>(list1->Size()); + return mse; + } + + void DoExecute() override { m_FeaturesInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); @@ -84,12 +100,11 @@ protected: Superclass::DoExecute(); - /* - std::cout << m_PredictedList << std::endl; - std::cout << m_ClassificationSamplesWithLabel.labeledListSample << std::endl; - */ - + otbAppLogINFO("Computing training performances"); + auto mse = ComputeMSE(m_ClassificationSamplesWithLabel.labeledListSample.GetPointer(), m_PredictedList.GetPointer() ); + + otbAppLogINFO("Mean Square Error = "<<mse); } private: