Commit 5f6a9bdb authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: Compute mse in TrainRegression

No related merge requests found
Showing with 20 additions and 5 deletions
+20 -5
...@@ -73,6 +73,22 @@ protected: ...@@ -73,6 +73,22 @@ protected:
Superclass::DoUpdateParameters(); Superclass::DoUpdateParameters();
} }
double ComputeMSE(TargetListSampleType* list1, TargetListSampleType* list2)
{
assert(list1->Size() == list2->Size());
double mse = 0.;
for (TargetListSampleType::InstanceIdentifier i=0; i<list1->Size() ; ++i)
{
auto elem1 = list1->GetMeasurementVector(i);
auto elem2 = list2->GetMeasurementVector(i);
mse += (elem1[0] - elem2[0]) * (elem1[0] - elem2[0]);
}
mse /= static_cast<double>(list1->Size());
return mse;
}
void DoExecute() override void DoExecute() override
{ {
m_FeaturesInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); m_FeaturesInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
...@@ -84,12 +100,11 @@ protected: ...@@ -84,12 +100,11 @@ protected:
Superclass::DoExecute(); Superclass::DoExecute();
/* otbAppLogINFO("Computing training performances");
std::cout << m_PredictedList << std::endl;
std::cout << m_ClassificationSamplesWithLabel.labeledListSample << std::endl;
*/
auto mse = ComputeMSE(m_ClassificationSamplesWithLabel.labeledListSample.GetPointer(), m_PredictedList.GetPointer() );
otbAppLogINFO("Mean Square Error = "<<mse);
} }
private: private:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment