Commit d9f71201 authored by remi cresson's avatar remi cresson
Browse files

FIX: validation step not performed if not user parameters are set

parent 58e508bd
...@@ -438,6 +438,7 @@ public: ...@@ -438,6 +438,7 @@ public:
} }
// Setup the validation filter // Setup the validation filter
const bool do_validation = HasUserValue("validation.mode");
if (GetParameterInt("validation.mode")==1) // class if (GetParameterInt("validation.mode")==1) // class
{ {
otbAppLogINFO("Set validation mode to classification validation"); otbAppLogINFO("Set validation mode to classification validation");
...@@ -467,50 +468,53 @@ public: ...@@ -467,50 +468,53 @@ public:
AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch)); AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
m_TrainModelFilter->Update(); m_TrainModelFilter->Update();
// Validate the model if (do_validation)
if (epoch % GetParameterInt("validation.step") == 0) {
// Validate the model
if (epoch % GetParameterInt("validation.step") == 0)
{ {
// 1. Evaluate the metrics against the learning data // 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{ {
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]); m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
} }
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData); m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
// As we use the learning data here, it's rational to use the same option as streaming during training // As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming")); m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// Update // Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update(); m_ValidateModelFilter->Update();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
} }
// 2. Evaluate the metrics against the validation data // 2. Evaluate the metrics against the validation data
// Here we just change the input sources and references // Here we just change the input sources and references
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
{ {
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]); m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
} }
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData); m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming")); m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
// Update // Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)"); AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
m_ValidateModelFilter->Update(); m_ValidateModelFilter->Update();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
} }
} // Step is OK to perform validation } // Step is OK to perform validation
}
} // Next epoch } // Next epoch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment