Commit 108aab60 authored by Cresson Remi

ENH: perform validation every Nth epoch

parent 4d9c113a
@@ -141,23 +141,21 @@ public:
     // Documentation
     SetName("TensorflowModelServe");
 
-    SetDescription("Multisource deep learning classifier using Tensorflow. Change "
-        "the " + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of "
-        "sources.");
-    SetDocLongDescription("The application run a Tensorflow model over multiple data sources. "
-        "The number of input sources can be changed at runtime by setting the "
-        "system environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". "
-        "For each source, you have to set (1) the tensor placeholder name, as named in "
-        "the tensorflow model, (2) the patch size and (3) the image(s) source. "
-        "The output is a multiband image, stacking all outputs "
-        "tensors together: you have to specify the names of the output tensors, as "
-        "named in the tensorflow model (typically, an operator's output). The output "
-        "tensors values will be stacked in the same order as they appear in the "
-        "\"model.output\" parameter (you can use a space separator between names). "
-        "Last but not least, consider using extended filename to bypass the automatic "
-        "memory footprint calculator of the otb application engine, and set a good "
-        "splitting strategy (I would recommend using small square tiles) or use the "
-        "finetuning parameter group to impose your squared tiles sizes");
+    SetDescription("Multisource deep learning classifier using TensorFlow. Change the "
+        + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of sources.");
+    SetDocLongDescription("The application run a TensorFlow model over multiple data sources. "
+        "The number of input sources can be changed at runtime by setting the system "
+        "environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". For each source, you have to "
+        "set (1) the placeholder name, as named in the TensorFlow model, (2) the receptive "
+        "field and (3) the image(s) source. The output is a multiband image, stacking all "
+        "outputs tensors together: you have to specify (1) the names of the output tensors, as "
+        "named in the TensorFlow model (typically, an operator's output) and (2) the expression "
+        "field of each output tensor. The output tensors values will be stacked in the same "
+        "order as they appear in the \"model.output\" parameter (you can use a space separator "
+        "between names). You might consider to use extended filename to bypass the automatic "
+        "memory footprint calculator of the otb application engine, and set a good splitting "
+        "strategy (Square tiles is good for convolutional networks) or use the \"optim\" "
+        "parameter group to impose your squared tiles sizes");
     SetDocAuthors("Remi Cresson");
 
     // Input/output images
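The description above relies on an environment variable to set the number of input sources at runtime. As a toy sketch of that mechanism (plain C++, not the application code; the variable name "OTB_TF_NSOURCES" is an assumption about what tf::ENV_VAR_NAME_NSOURCES expands to):

#include <cstdlib>
#include <iostream>

int main()
{
  // Default number of sources when the variable is not set (assumed)
  unsigned int nSources = 1;
  if (const char* env = std::getenv("OTB_TF_NSOURCES"))
    nSources = static_cast<unsigned int>(std::atoi(env));
  std::cout << "Declaring " << nSources << " input source(s)" << std::endl;
  return 0;
}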
@@ -167,17 +165,21 @@ public:
     // Input model
     AddParameter(ParameterType_Group, "model", "model parameters");
-    AddParameter(ParameterType_Directory, "model.dir", "Tensorflow model_save directory");
+    AddParameter(ParameterType_Directory, "model.dir", "TensorFlow model_save directory");
     MandatoryOn ("model.dir");
+    SetParameterDescription ("model.dir", "The model directory should contains the model Google Protobuf (.pb) and variables");
+
     AddParameter(ParameterType_StringList, "model.userplaceholders", "Additional single-valued placeholders. Supported types: int, float, bool.");
     MandatoryOff ("model.userplaceholders");
+    SetParameterDescription ("model.userplaceholders", "Syntax to use is \"placeholder_1=value_1 ... placeholder_N=value_N\"");
     AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional");
     MandatoryOff ("model.fullyconv");
 
     // Output tensors parameters
     AddParameter(ParameterType_Group, "output", "Output tensors parameters");
-    AddParameter(ParameterType_Float, "output.spcscale", "The output spacing scale");
+    AddParameter(ParameterType_Float, "output.spcscale", "The output spacing scale, related to the first input");
     SetDefaultParameterFloat ("output.spcscale", 1.0);
+    SetParameterDescription ("output.spcscale", "The output image size/scale and spacing*scale where size and spacing corresponds to the first input");
     AddParameter(ParameterType_StringList, "output.names", "Names of the output tensors");
     MandatoryOn ("output.names");
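The committed description of "output.spcscale" is terse ("size/scale and spacing*scale"). A minimal standalone sketch of the relationship it describes, with made-up input values:

#include <iostream>

int main()
{
  const double inputSpacing = 10.0;  // spacing of the first input source (assumed, e.g. meters)
  const double inputSize = 1024.0;   // width of the first input source, in pixels (assumed)
  const double spcscale = 2.0;       // value given to "output.spcscale"

  // Output grid: spacing is multiplied by the scale, size is divided by it
  const double outputSpacing = inputSpacing * spcscale;  // 20.0
  const double outputSize = inputSize / spcscale;        // 512
  std::cout << "spacing: " << outputSpacing << ", size: " << outputSize << std::endl;
  return 0;
}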
@@ -195,6 +197,7 @@ public:
     AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time");
     AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
     MandatoryOff ("optim.disabletiling");
+    SetParameterDescription ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it");
     AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
     SetMinimumParameterIntValue ("optim.tilesize", 1);
     SetDefaultParameterInt ("optim.tilesize", 16);
...@@ -230,8 +233,8 @@ public: ...@@ -230,8 +233,8 @@ public:
bundle.m_PatchSize[1] = GetParameterInt(bundle.m_KeyPszY); bundle.m_PatchSize[1] = GetParameterInt(bundle.m_KeyPszY);
otbAppLogINFO("Source info :"); otbAppLogINFO("Source info :");
otbAppLogINFO("Field of view : " << bundle.m_PatchSize ); otbAppLogINFO("Receptive field : " << bundle.m_PatchSize );
otbAppLogINFO("Placeholder : " << bundle.m_Placeholder); otbAppLogINFO("Placeholder name : " << bundle.m_Placeholder);
} }
} }
@@ -273,7 +276,7 @@ public:
     // Fully convolutional mode on/off
     if (GetParameterInt("model.fullyconv")==1)
     {
-      otbAppLogINFO("The tensorflow model is used in fully convolutional mode");
+      otbAppLogINFO("The TensorFlow model is used in fully convolutional mode");
       m_TFFilter->SetFullyConvolutional(true);
     }
@@ -292,7 +295,7 @@ public:
       const unsigned int tileSize = GetParameterInt("optim.tilesize");
       otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
 
-      // Update the TF filter to get the output image size
+      // Update the TensorFlow filter output information to get the output image size
       m_TFFilter->UpdateOutputInformation();
 
       // Splitting using square tiles
@@ -301,7 +304,7 @@ public:
       unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>(
           double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels() ) / (tileSize * tileSize) );
 
-      // Use an itk::StreamingImageFilter to force the computation on tiles
+      // Use an itk::StreamingImageFilter to force the computation tile by tile
       m_StreamFilter = StreamingFilterType::New();
       m_StreamFilter->SetRegionSplitter(splitter);
       m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles);
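The number of stream divisions above is simply the ceiling of the output pixel count over the squared tile size. A standalone sketch of that arithmetic, assuming a hypothetical 2000 x 1500 output image and the default tile size:

#include <cmath>
#include <iostream>

int main()
{
  const unsigned int tileSize = 16;         // default "optim.tilesize"
  const double nbPixels = 2000.0 * 1500.0;  // hypothetical output image pixel count
  const unsigned int nbDesiredTiles =
    static_cast<unsigned int>(std::ceil(nbPixels / (tileSize * tileSize)));
  std::cout << nbDesiredTiles << " stream divisions" << std::endl;  // prints 11719
  return 0;
}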
@@ -313,7 +316,6 @@ public:
     {
       otbAppLogINFO("Tiling disabled");
       SetParameterOutputImage("out", m_TFFilter->GetOutput());
-
     }
   }
...
...@@ -191,7 +191,7 @@ public: ...@@ -191,7 +191,7 @@ public:
SetDefaultParameterInt ("training.batchsize", 100); SetDefaultParameterInt ("training.batchsize", 100);
AddParameter(ParameterType_Int, "training.epochs", "Number of epochs"); AddParameter(ParameterType_Int, "training.epochs", "Number of epochs");
SetMinimumParameterIntValue ("training.epochs", 1); SetMinimumParameterIntValue ("training.epochs", 1);
SetDefaultParameterInt ("training.epochs", 10); SetDefaultParameterInt ("training.epochs", 100);
AddParameter(ParameterType_StringList, "training.userplaceholders", AddParameter(ParameterType_StringList, "training.userplaceholders",
"Additional single-valued placeholders for training. Supported types: int, float, bool."); "Additional single-valued placeholders for training. Supported types: int, float, bool.");
MandatoryOff ("training.userplaceholders"); MandatoryOff ("training.userplaceholders");
@@ -205,6 +205,9 @@ public:
     // Metrics
     AddParameter(ParameterType_Group, "validation", "Validation parameters");
     MandatoryOff ("validation");
+    AddParameter(ParameterType_Int, "validation.step", "Perform the validation every Nth epochs");
+    SetMinimumParameterIntValue ("validation.step", 1);
+    SetDefaultParameterInt ("validation.step", 10);
     AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
     AddChoice ("validation.mode.none", "No validation step");
     AddChoice ("validation.mode.class", "Classification metrics");
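With the defaults introduced by this commit (training.epochs = 100, validation.step = 10), validation runs every tenth epoch. A tiny sketch of the resulting schedule:

#include <iostream>

int main()
{
  const int epochs = 100;  // default "training.epochs" after this commit
  const int step = 10;     // default "validation.step"
  for (int epoch = 1; epoch <= epochs; epoch++)
    if (epoch % step == 0)
      std::cout << "validation at epoch " << epoch << std::endl;
  return 0;  // prints epochs 10, 20, ..., 100 (10 validations in total)
}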
@@ -415,7 +418,7 @@ public:
     // Prepare inputs
     PrepareInputs();
 
-    // Setup filter
+    // Setup training filter
     m_TrainModelFilter = TrainModelFilterType::New();
     m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
     m_TrainModelFilter->SetSession(m_SavedModel.session.get());
@@ -434,21 +437,6 @@ public:
           m_InputSourcesForTraining[i]);
     }
 
-    // Train the model
-    for (int epoch = 0 ; epoch < GetParameterInt("training.epochs") ; epoch++)
-    {
-      AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch+1));
-      m_TrainModelFilter->Update();
-    }
-
-    // Check if we have to save variables to somewhere
-    if (HasValue("model.saveto"))
-    {
-      const std::string path = GetParameterAsString("model.saveto");
-      otbAppLogINFO("Saving model to " + path);
-      tf::SaveModel(path, m_SavedModel);
-    }
-
     // Setup the validation filter
     if (GetParameterInt("validation.mode")==1) // class
     {
@@ -459,60 +447,79 @@ public:
       m_ValidateModelFilter->SetSession(m_SavedModel.session.get());
       m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
       m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
-
-      // AS we use the learning data here, it's rational to use the same option as streaming during training
-      m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
-
-      // 1. Evaluate the metrics against the learning data
-      for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
-      {
-        m_ValidateModelFilter->PushBackInputTensorBundle(
-            m_InputPlaceholdersForValidation[i],
-            m_InputPatchesSizeForValidation[i],
-            m_InputSourcesForEvaluationAgainstLearningData[i]);
-      }
-      m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames);
-      m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
-      m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize);
-
-      // Update
-      AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
-      m_ValidateModelFilter->Update();
-
-      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
-      {
-        otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
-        PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
-      }
-
-      // 2. Evaluate the metrics against the validation data
-
-      // Here we just change the input sources and references
-      for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
-      {
-        m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
-      }
-      m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
-      m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
-
-      // Update
-      AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
-      m_ValidateModelFilter->Update();
-
-      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
-      {
-        otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
-        PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
-      }
-    }
-    else if (GetParameterInt("validation.mode")==2) // rmse)
-    {
-      otbAppLogINFO("Set validation mode to classification RMSE evaluation");
-
-      // TODO
-    }
+      m_ValidateModelFilter->SetInputPlaceholders(m_InputPlaceholdersForValidation);
+      m_ValidateModelFilter->SetInputReceptiveFields(m_InputPatchesSizeForValidation);
+      m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames);
+      m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize);
+    }
+    else if (GetParameterInt("validation.mode")==2) // rmse)
+    {
+      otbAppLogINFO("Set validation mode to classification RMSE evaluation");
+      otbAppLogFATAL("Not implemented yet !"); // XD
+
+      // TODO
+    }
+
+    // Epoch
+    for (int epoch = 1 ; epoch <= GetParameterInt("training.epochs") ; epoch++)
+    {
+      // Train the model
+      AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
+      m_TrainModelFilter->Update();
+
+      // Validate the model
+      if (epoch % GetParameterInt("validation.step") == 0)
+      {
+        // 1. Evaluate the metrics against the learning data
+        for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
+        {
+          m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
+        }
+        m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
+
+        // As we use the learning data here, it's rational to use the same option as streaming during training
+        m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
+
+        // Update
+        AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
+        m_ValidateModelFilter->Update();
+
+        for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
+        {
+          otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
+          PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
+        }
+
+        // 2. Evaluate the metrics against the validation data
+
+        // Here we just change the input sources and references
+        for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
+        {
+          m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
+        }
+        m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
+        m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
+
+        // Update
+        AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
+        m_ValidateModelFilter->Update();
+
+        for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
+        {
+          otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
+          PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
+        }
+      } // Step is OK to perform validation
+    } // Next epoch
+
+    // Check if we have to save variables to somewhere
+    if (HasValue("model.saveto"))
+    {
+      const std::string path = GetParameterAsString("model.saveto");
+      otbAppLogINFO("Saving model to " + path);
+      tf::SaveModel(path, m_SavedModel);
+    }
   }
 }
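To summarize the restructuring in this last hunk: training and validation are now interleaved in a single 1-based epoch loop, validation only fires every "validation.step" epochs, and model saving happens once after the loop rather than immediately after training. A condensed, runnable sketch of the new control flow (stand-in functions, not the application code; the save path is hypothetical):

#include <iostream>
#include <string>

// Stand-ins for the training filter, validation filter and tf::SaveModel
void trainOneEpoch() { std::cout << "train one epoch" << std::endl; }
void validateModel() { std::cout << "validate" << std::endl; }
void saveModel(const std::string& path) { std::cout << "save to " << path << std::endl; }

int main()
{
  const int epochs = 100;  // "training.epochs"
  const int step = 10;     // "validation.step"

  for (int epoch = 1; epoch <= epochs; epoch++)
  {
    trainOneEpoch();
    if (epoch % step == 0)  // validation is now interleaved with training
      validateModel();
  }
  saveModel("/path/to/model");  // hypothetical "model.saveto" value; saving now happens once, after the loop
  return 0;
}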