Commit e2968219 authored by remi cresson's avatar remi cresson
Browse files

FIX: validation step not performed if not user parameters are set

parent 831c400b
...@@ -468,32 +468,32 @@ public: ...@@ -468,32 +468,32 @@ public:
AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch)); AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
m_TrainModelFilter->Update(); m_TrainModelFilter->Update();
// Validate the model if (do_validation)
if (epoch % GetParameterInt("validation.step") == 0)
{ {
// 1. Evaluate the metrics against the learning data // Validate the model
if (epoch % GetParameterInt("validation.step") == 0)
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{ {
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]); // 1. Evaluate the metrics against the learning data
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
// As we use the learning data here, it's rational to use the same option as streaming during training for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming")); {
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
// Update // As we use the learning data here, it's rational to use the same option as streaming during training
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
m_ValidateModelFilter->Update();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) // Update
{ AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); m_ValidateModelFilter->Update();
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
} for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
if (do_validation)
{
// 2. Evaluate the metrics against the validation data // 2. Evaluate the metrics against the validation data
// Here we just change the input sources and references // Here we just change the input sources and references
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment