Commit 82c27e1f authored by Cresson Remi's avatar Cresson Remi
Browse files

ADD: streaming option for training/validation

parent ddf221b3
......@@ -198,17 +198,21 @@ public:
MandatoryOn ("training.targetnodes");
AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display");
MandatoryOff ("training.outputtensors");
AddParameter(ParameterType_Bool, "training.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
MandatoryOff ("training.usestreaming");
// Metrics
AddParameter(ParameterType_Group, "validation", "Validation parameters");
AddParameter(ParameterType_Group, "validation", "Validation parameters");
MandatoryOff ("validation");
AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
AddChoice ("validation.mode.none", "No validation step");
AddChoice ("validation.mode.class", "Classification metrics");
AddChoice ("validation.mode.rmse", "Root mean square error");
AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
AddChoice ("validation.mode.none", "No validation step");
AddChoice ("validation.mode.class", "Classification metrics");
AddChoice ("validation.mode.rmse", "Root mean square error");
AddParameter(ParameterType_StringList, "validation.userplaceholders",
"Additional single-valued placeholders for validation. Supported types: int, float, bool.");
MandatoryOff ("validation.userplaceholders");
AddParameter(ParameterType_Bool, "validation.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
MandatoryOff ("validation.usestreaming");
// Input/output images
AddAnInputImage();
......@@ -418,6 +422,7 @@ public:
m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes"));
m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders"));
m_TrainModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// Set inputs
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
......@@ -454,6 +459,9 @@ public:
m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
// AS we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
......@@ -485,6 +493,7 @@ public:
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment