Commit 70004b8e authored by Cresson Remi

Merge branch 'develop'

parents 5c0ec1a9 4d9c113a
@@ -92,9 +92,11 @@ private:
     ShareParameter("deepmodel", "tfmodel.model",
       "Deep net model parameters", "Deep net model parameters");
     ShareParameter("output", "tfmodel.output",
-      "Deep net outputs parameters", "Deep net outputs parameters");
-    ShareParameter("finetuning", "tfmodel.finetuning",
-      "Deep net fine tuning parameters","Deep net fine tuning parameters");
+      "Deep net outputs parameters",
+      "Deep net outputs parameters");
+    ShareParameter("optim", "tfmodel.optim",
+      "This group of parameters allows optimization of processing time",
+      "This group of parameters allows optimization of processing time");
 
     // Classify shared parameters
     ShareParameter("model" , "classif.model" , "Model file" , "Model file" );
...
@@ -106,14 +106,14 @@ public:
     // Parameter group keys
     ss_key_in << ss_key_group.str() << ".il";
-    ss_key_dims_x << ss_key_group.str() << ".fovx";
-    ss_key_dims_y << ss_key_group.str() << ".fovy";
+    ss_key_dims_x << ss_key_group.str() << ".rfieldx";
+    ss_key_dims_y << ss_key_group.str() << ".rfieldy";
     ss_key_ph << ss_key_group.str() << ".placeholder";
 
     // Parameter group descriptions
     ss_desc_in << "Input image (or list to stack) for source #" << inputNumber;
-    ss_desc_dims_x << "Field of view width for source #" << inputNumber;
-    ss_desc_dims_y << "Field of view height for source #" << inputNumber;
+    ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber;
+    ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
     ss_desc_ph << "Name of the input placeholder for source #" << inputNumber;
 
     // Populate group
@@ -182,22 +182,22 @@ public:
     MandatoryOn ("output.names");
 
     // Output Field of Expression
-    AddParameter(ParameterType_Int, "output.foex", "The output field of expression (x)");
-    SetMinimumParameterIntValue ("output.foex", 1);
-    SetDefaultParameterInt ("output.foex", 1);
-    MandatoryOn ("output.foex");
-    AddParameter(ParameterType_Int, "output.foey", "The output field of expression (y)");
-    SetMinimumParameterIntValue ("output.foey", 1);
-    SetDefaultParameterInt ("output.foey", 1);
-    MandatoryOn ("output.foey");
+    AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)");
+    SetMinimumParameterIntValue ("output.efieldx", 1);
+    SetDefaultParameterInt ("output.efieldx", 1);
+    MandatoryOn ("output.efieldx");
+    AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)");
+    SetMinimumParameterIntValue ("output.efieldy", 1);
+    SetDefaultParameterInt ("output.efieldy", 1);
+    MandatoryOn ("output.efieldy");
 
     // Fine tuning
-    AddParameter(ParameterType_Group, "finetuning" , "Fine tuning performance or consistency parameters");
-    AddParameter(ParameterType_Bool, "finetuning.disabletiling", "Disable tiling");
-    MandatoryOff ("finetuning.disabletiling");
-    AddParameter(ParameterType_Int, "finetuning.tilesize", "Tile width used to stream the filter output");
-    SetMinimumParameterIntValue ("finetuning.tilesize", 1);
-    SetDefaultParameterInt ("finetuning.tilesize", 16);
+    AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time");
+    AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
+    MandatoryOff ("optim.disabletiling");
+    AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
+    SetMinimumParameterIntValue ("optim.tilesize", 1);
+    SetDefaultParameterInt ("optim.tilesize", 16);
 
     // Output image
     AddParameter(ParameterType_OutputImage, "out", "output image");
@@ -205,8 +205,8 @@ public:
     // Example
     SetDocExampleParameterValue("source1.il", "spot6pms.tif");
     SetDocExampleParameterValue("source1.placeholder", "x1");
-    SetDocExampleParameterValue("source1.fovx", "16");
-    SetDocExampleParameterValue("source1.fovy", "16");
+    SetDocExampleParameterValue("source1.rfieldx", "16");
+    SetDocExampleParameterValue("source1.rfieldy", "16");
     SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
     SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0");
     SetDocExampleParameterValue("output.names", "out_predict1 out_proba1");
@@ -248,16 +248,16 @@ public:
     m_TFFilter = TFModelFilterType::New();
     m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
     m_TFFilter->SetSession(m_SavedModel.session.get());
-    m_TFFilter->SetOutputTensorsNames(GetParameterStringList("output.names"));
+    m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
     m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
     otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
 
     // Get user placeholders
-    TFModelFilterType::DictListType dict;
     TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
+    TFModelFilterType::DictType dict;
     for (auto& exp: expressions)
     {
-      TFModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
+      TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
       dict.push_back(entry);
       otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
@@ -267,7 +267,7 @@ public:
     // Input sources
     for (auto& bundle: m_Bundles)
     {
-      m_TFFilter->PushBackInputBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
+      m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
     }
 
     // Fully convolutional mode on/off
@@ -281,15 +281,15 @@ public:
     FloatVectorImageType::SizeType foe;
     foe[0] = GetParameterInt("output.foex");
     foe[1] = GetParameterInt("output.foey");
-    m_TFFilter->SetOutputFOESize(foe);
-    otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputFOESize());
+    m_TFFilter->SetOutputExpressionFields({foe});
+    otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
 
     // Streaming
-    if (GetParameterInt("finetuning.disabletiling")!=1)
+    if (GetParameterInt("optim.disabletiling")!=1)
     {
       // Get the tile size
-      const unsigned int tileSize = GetParameterInt("finetuning.tilesize");
+      const unsigned int tileSize = GetParameterInt("optim.tilesize");
       otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
 
       // Update the TF filter to get the output image size
...
@@ -53,7 +53,7 @@ public:
   itkNewMacro(Self);
   itkTypeMacro(TensorflowModelTrain, Application);
 
-  /** Typedefs for tensorflow */
+  /** Typedefs for TensorFlow */
   typedef otb::TensorflowMultisourceModelTrain<FloatVectorImageType> TrainModelFilterType;
   typedef otb::TensorflowMultisourceModelValidate<FloatVectorImageType> ValidateModelFilterType;
   typedef otb::TensorflowSource<FloatVectorImageType> TFSource;
@@ -75,8 +75,8 @@ public:
     // Parameters keys
     std::string m_KeyInForTrain;     // Key of input image list (training)
     std::string m_KeyInForValid;     // Key of input image list (validation)
-    std::string m_KeyPHNameForTrain; // Key for placeholder name in the tensorflow model (training)
-    std::string m_KeyPHNameForValid; // Key for placeholder name in the tensorflow model (validation)
+    std::string m_KeyPHNameForTrain; // Key for placeholder name in the TensorFlow model (training)
+    std::string m_KeyPHNameForValid; // Key for placeholder name in the TensorFlow model (validation)
    std::string m_KeyPszX;           // Key for samples sizes X
    std::string m_KeyPszY;           // Key for samples sizes Y
  };
@@ -122,18 +122,19 @@ public:
     // Parameter group keys
     ss_key_tr_in << ss_key_tr_group.str() << ".il";
     ss_key_val_in << ss_key_val_group.str() << ".il";
-    ss_key_dims_x << ss_key_tr_group.str() << ".fovx";
-    ss_key_dims_y << ss_key_tr_group.str() << ".fovy";
+    ss_key_dims_x << ss_key_tr_group.str() << ".patchsizex";
+    ss_key_dims_y << ss_key_tr_group.str() << ".patchsizey";
     ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder";
-    ss_key_val_ph << ss_key_val_group.str() << ".placeholder";
+    ss_key_val_ph << ss_key_val_group.str() << ".name";
 
     // Parameter group descriptions
     ss_desc_tr_in << "Input image (or list to stack) for source #" << inputNumber << " (training)";
     ss_desc_val_in << "Input image (or list to stack) for source #" << inputNumber << " (validation)";
-    ss_desc_dims_x << "Field of view width for source #" << inputNumber;
-    ss_desc_dims_y << "Field of view height for source #" << inputNumber;
+    ss_desc_dims_x << "Patch size (x) for source #" << inputNumber;
+    ss_desc_dims_y << "Patch size (y) for source #" << inputNumber;
     ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)";
-    ss_desc_val_ph << "Name of the input placeholder for source #" << inputNumber << " (validation)";
+    ss_desc_val_ph << "Name of the input placeholder "
+        "or output tensor for source #" << inputNumber << " (validation)";
 
     // Populate group
     AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str());
@@ -194,21 +195,25 @@ public:
     AddParameter(ParameterType_StringList, "training.userplaceholders",
         "Additional single-valued placeholders for training. Supported types: int, float, bool.");
     MandatoryOff ("training.userplaceholders");
-    AddParameter(ParameterType_StringList, "training.targetnodesnames", "Names of the target nodes");
-    MandatoryOn ("training.targetnodesnames");
-    AddParameter(ParameterType_StringList, "training.outputtensorsnames", "Names of the output tensors to display");
-    MandatoryOff ("training.outputtensorsnames");
+    AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes");
+    MandatoryOn ("training.targetnodes");
+    AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display");
+    MandatoryOff ("training.outputtensors");
+    AddParameter(ParameterType_Bool, "training.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
+    MandatoryOff ("training.usestreaming");
 
     // Metrics
     AddParameter(ParameterType_Group, "validation", "Validation parameters");
     MandatoryOff ("validation");
     AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
     AddChoice ("validation.mode.none", "No validation step");
     AddChoice ("validation.mode.class", "Classification metrics");
     AddChoice ("validation.mode.rmse", "Root mean square error");
     AddParameter(ParameterType_StringList, "validation.userplaceholders",
         "Additional single-valued placeholders for validation. Supported types: int, float, bool.");
     MandatoryOff ("validation.userplaceholders");
+    AddParameter(ParameterType_Bool, "validation.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
+    MandatoryOff ("validation.usestreaming");
 
     // Input/output images
     AddAnInputImage();
@@ -220,15 +225,15 @@ public:
     // Example
     SetDocExampleParameterValue("source1.il", "spot6pms.tif");
     SetDocExampleParameterValue("source1.placeholder", "x1");
-    SetDocExampleParameterValue("source1.fovx", "16");
-    SetDocExampleParameterValue("source1.fovy", "16");
+    SetDocExampleParameterValue("source1.patchsizex", "16");
+    SetDocExampleParameterValue("source1.patchsizey", "16");
     SetDocExampleParameterValue("source2.il", "labels.tif");
     SetDocExampleParameterValue("source2.placeholder", "y1");
-    SetDocExampleParameterValue("source2.fovx", "1");
-    SetDocExampleParameterValue("source2.fovy", "1");
+    SetDocExampleParameterValue("source2.patchsizex", "1");
+    SetDocExampleParameterValue("source2.patchsizex", "1");
     SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
     SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2");
-    SetDocExampleParameterValue("training.targetnodenames", "optimizer");
+    SetDocExampleParameterValue("training.targetnodes", "optimizer");
     SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model_vars1");
   }
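
Likewise for TensorflowModelTrain, a minimal sketch using the renamed training-side keys (patchsizex/patchsizey instead of fovx/fovy, training.targetnodes instead of training.targetnodesnames). Keys mirror the diff above; paths, values and the application loading are illustrative assumptions, and the new training.usestreaming / validation.usestreaming flags are left at their defaults.

// Minimal sketch, not part of this commit: TensorflowModelTrain with the renamed keys.
#include "otbWrapperApplicationRegistry.h"

int main()
{
  otb::Wrapper::Application::Pointer app =
    otb::Wrapper::ApplicationRegistry::CreateApplication("TensorflowModelTrain");

  app->SetParameterStringList("source1.il", {"spot6pms.tif"});
  app->SetParameterString("source1.placeholder", "x1");
  app->SetParameterInt("source1.patchsizex", 16);  // formerly source1.fovx
  app->SetParameterInt("source1.patchsizey", 16);  // formerly source1.fovy
  app->SetParameterStringList("source2.il", {"labels.tif"});
  app->SetParameterString("source2.placeholder", "y1");
  app->SetParameterInt("source2.patchsizex", 1);
  app->SetParameterInt("source2.patchsizey", 1);
  app->SetParameterString("model.dir", "/tmp/my_saved_model/");
  app->SetParameterStringList("training.userplaceholders", {"is_training=true", "dropout=0.2"});
  app->SetParameterStringList("training.targetnodes", {"optimizer"});
  app->SetParameterString("model.saveto", "/tmp/my_saved_model_vars1");
  app->ExecuteAndWriteOutput();

  return 0;
}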
@@ -240,10 +245,10 @@ public:
   //   -Placeholders
   //   -PatchSize
   //   -ImageSource
-  // 2.Validation/Test
+  // 2.Learning/Validation
   //   -Placeholders (if input) or Tensor name (if target)
   //   -PatchSize (which is the same as for training)
-  //   -ImageSource (depending if it's for test or validation)
+  //   -ImageSource (depending if it's for learning or validation)
   //
   // TODO: a bit of refactoring. We could simply rely on m_Bundles
   // if we can keep trace of indices of sources for
@@ -262,12 +267,12 @@ public:
     // Clear bundles
     m_InputSourcesForTraining.clear();
-    m_InputSourcesForTest.clear();
-    m_InputSourcesForValidation.clear();
+    m_InputSourcesForEvaluationAgainstLearningData.clear();
+    m_InputSourcesForEvaluationAgainstValidationData.clear();
 
     m_TargetTensorsNames.clear();
-    m_InputTargetsForValidation.clear();
-    m_InputTargetsForTest.clear();
+    m_InputTargetsForEvaluationAgainstValidationData.clear();
+    m_InputTargetsForEvaluationAgainstLearningData.clear();
 
     // Prepare the bundles
@@ -289,7 +294,7 @@ public:
       m_InputPatchesSizeForTraining.push_back(patchSize);
 
       otbAppLogINFO("New source:");
-      otbAppLogINFO("Field of view : "<< patchSize);
+      otbAppLogINFO("Patch size : "<< patchSize);
       otbAppLogINFO("Placeholder (training) : "<< placeholderForTraining);
 
       // Prepare validation sources
@@ -314,8 +319,8 @@ public:
         if (placeholderForValidation.compare(placeholderForTraining) == 0)
         {
           // Source
-          m_InputSourcesForValidation.push_back(bundle.tfSourceForValidation.Get());
-          m_InputSourcesForTest.push_back(bundle.tfSource.Get());
+          m_InputSourcesForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get());
+          m_InputSourcesForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get());
 
           // Placeholder
           m_InputPlaceholdersForValidation.push_back(placeholderForValidation);
@@ -330,8 +335,8 @@ public:
         else
         {
           // Source
-          m_InputTargetsForValidation.push_back(bundle.tfSourceForValidation.Get());
-          m_InputTargetsForTest.push_back(bundle.tfSource.Get());
+          m_InputTargetsForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get());
+          m_InputTargetsForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get());
 
           // Placeholder
           m_TargetTensorsNames.push_back(placeholderForValidation);
@@ -350,13 +355,13 @@ public:
   //
   // Get user placeholders
   //
-  TrainModelFilterType::DictListType GetUserPlaceholders(const std::string key)
+  TrainModelFilterType::DictType GetUserPlaceholders(const std::string key)
   {
-    TrainModelFilterType::DictListType dict;
+    TrainModelFilterType::DictType dict;
     TrainModelFilterType::StringList expressions = GetParameterStringList(key);
     for (auto& exp: expressions)
     {
-      TrainModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
+      TrainModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
       dict.push_back(entry);
       otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
@@ -414,16 +419,19 @@ public:
     m_TrainModelFilter = TrainModelFilterType::New();
     m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
     m_TrainModelFilter->SetSession(m_SavedModel.session.get());
-    m_TrainModelFilter->SetOutputTensorsNames(GetParameterStringList("training.outputtensorsnames"));
-    m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodesnames"));
+    m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors"));
+    m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes"));
     m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
     m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders"));
+    m_TrainModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
 
-    // Set input bundles
+    // Set inputs
     for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
     {
-      m_TrainModelFilter->PushBackInputBundle(m_InputPlaceholdersForTraining[i],
-          m_InputPatchesSizeForTraining[i], m_InputSourcesForTraining[i]);
+      m_TrainModelFilter->PushBackInputTensorBundle(
+          m_InputPlaceholdersForTraining[i],
+          m_InputPatchesSizeForTraining[i],
+          m_InputSourcesForTraining[i]);
     }
 
     // Train the model
@@ -449,55 +457,55 @@ public:
       m_ValidateModelFilter = ValidateModelFilterType::New();
       m_ValidateModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
       m_ValidateModelFilter->SetSession(m_SavedModel.session.get());
-      m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
       m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
      m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
+      // AS we use the learning data here, it's rational to use the same option as streaming during training
+      m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
 
-      // Evaluate the metrics against the learning data (test)
-      for (unsigned int i = 0 ; i < m_InputSourcesForTest.size() ; i++)
-      {
-        m_ValidateModelFilter->PushBackInputBundle(m_InputPlaceholdersForValidation[i],
-            m_InputPatchesSizeForValidation[i], m_InputSourcesForTest[i]);
-      }
-      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
+      // 1. Evaluate the metrics against the learning data
+      for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
       {
-        m_ValidateModelFilter->PushBackInputReference(m_InputTargetsForTest[i], m_TargetPatchesSize[i]);
+        m_ValidateModelFilter->PushBackInputTensorBundle(
+            m_InputPlaceholdersForValidation[i],
+            m_InputPatchesSizeForValidation[i],
+            m_InputSourcesForEvaluationAgainstLearningData[i]);
       }
+      m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames);
+      m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
+      m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize);
 
-      // Evaluate the model (test)
-      AddProcess(m_ValidateModelFilter, "Evaluate model (Test)");
+      // Update
+      AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
       m_ValidateModelFilter->Update();
 
+      // Print some metrics
       for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
      {
        otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
        PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
      }
 
-      // Evaluate the metrics against the validation data
-      for (unsigned int i = 0 ; i < m_InputSourcesForValidation.size() ; i++)
-      {
-        m_ValidateModelFilter->SetInput(i, m_InputSourcesForValidation[i]);
-      }
-      m_ValidateModelFilter->ClearInputReferences();
-      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
+      // 2. Evaluate the metrics against the validation data
+      // Here we just change the input sources and references
+      for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
      {
-        m_ValidateModelFilter->PushBackInputReference(m_InputTargetsForValidation[i], m_TargetPatchesSize[i]);
+        m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
      }
+      m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
+      m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
 
-      // Evaluate the model (validation)
-      AddProcess(m_ValidateModelFilter, "Evaluate model (Validation)");
+      // Update
+      AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
      m_ValidateModelFilter->Update();
 
+      // Print some metrics
      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
      {
        otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
        PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
      }
    }
    else if (GetParameterInt("validation.mode")==2) // rmse)
    {
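
Both evaluation passes end by printing classification scores derived from a confusion matrix. As a reminder of what such scores boil down to, the sketch below computes overall accuracy and Cohen's kappa with the standard formulas; the function name and the sample matrix are illustrative and this is not the actual PrintClassificationMetrics implementation.

// Illustration only: overall accuracy and kappa from a confusion matrix
// (rows = reference classes, columns = predicted classes).
#include <iostream>
#include <vector>

void PrintBasicMetrics(const std::vector<std::vector<double>>& cm)
{
  const std::size_t n = cm.size();
  double total = 0.0, diag = 0.0, chance = 0.0;
  std::vector<double> rowSum(n, 0.0), colSum(n, 0.0);

  for (std::size_t i = 0; i < n; i++)
    for (std::size_t j = 0; j < n; j++)
    {
      total += cm[i][j];
      rowSum[i] += cm[i][j];
      colSum[j] += cm[i][j];
      if (i == j) diag += cm[i][j];
    }

  for (std::size_t i = 0; i < n; i++)
    chance += (rowSum[i] / total) * (colSum[i] / total);

  const double accuracy = diag / total;
  const double kappa = (accuracy - chance) / (1.0 - chance);
  std::cout << "Overall accuracy: " << accuracy << ", kappa: " << kappa << std::endl;
}

int main()
{
  // Made-up 2-class example: 85/100 samples on the diagonal
  PrintBasicMetrics({{50.0, 10.0}, {5.0, 35.0}});  // prints accuracy 0.85, kappa ~0.69
  return 0;
}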
@@ -511,22 +519,31 @@ public:
 private:
 
-  tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
-
+  // Filters
   TrainModelFilterType::Pointer m_TrainModelFilter;
   ValidateModelFilterType::Pointer m_ValidateModelFilter;
+  tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
 
+  // Inputs
   BundleList m_Bundles;
 
+  // Patches size
   SizeList m_InputPatchesSizeForTraining;
   SizeList m_InputPatchesSizeForValidation;
   SizeList m_TargetPatchesSize;
 
+  // Placeholders and Tensors names
   StringList m_InputPlaceholdersForTraining;
   StringList m_InputPlaceholdersForValidation;
   StringList m_TargetTensorsNames;
 
+  // Image sources
   std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTraining;
-  std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTest;
-  std::vector<FloatVectorImageType::Pointer> m_InputTargetsForTest;
-  std::vector<FloatVectorImageType::Pointer> m_InputSourcesForValidation;
-  std::vector<FloatVectorImageType::Pointer> m_InputTargetsForValidation;
+  std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstLearningData;
+  std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstValidationData;
+  std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstLearningData;
+  std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstValidationData;