Commit db9711e4 authored by remi cresson's avatar remi cresson

FIX: validation step not performed if not user parameters are set

parent d9f71201
......@@ -468,32 +468,32 @@ public:
AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
m_TrainModelFilter->Update();
if (do_validation)
// Validate the model
if (epoch % GetParameterInt("validation.step") == 0)
{
// Validate the model
if (epoch % GetParameterInt("validation.step") == 0)
{
// 1. Evaluate the metrics against the learning data
// 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
// As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update();
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
if (do_validation)
{
// 2. Evaluate the metrics against the validation data
// Here we just change the input sources and references
......@@ -514,7 +514,7 @@ public:
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
} // Step is OK to perform validation
}
} // Do the validation against the validation data
} // Next epoch
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment