Commit db9711e4 authored by remi cresson's avatar remi cresson
Browse files

FIX: validation step not performed if not user parameters are set

parent d9f71201
...@@ -468,32 +468,32 @@ public: ...@@ -468,32 +468,32 @@ public:
AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch)); AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
m_TrainModelFilter->Update(); m_TrainModelFilter->Update();
if (do_validation) // Validate the model
if (epoch % GetParameterInt("validation.step") == 0)
{ {
// Validate the model // 1. Evaluate the metrics against the learning data
if (epoch % GetParameterInt("validation.step") == 0)
{
// 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{ {
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]); m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
} }
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData); m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
// As we use the learning data here, it's rational to use the same option as streaming during training // As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming")); m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// Update // Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update(); m_ValidateModelFilter->Update();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
} }
if (do_validation)
{
// 2. Evaluate the metrics against the validation data // 2. Evaluate the metrics against the validation data
// Here we just change the input sources and references // Here we just change the input sources and references
...@@ -514,7 +514,7 @@ public: ...@@ -514,7 +514,7 @@ public:
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
} }
} // Step is OK to perform validation } // Step is OK to perform validation
} } // Do the validation against the validation data
} // Next epoch } // Next epoch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment