Commit 4d9c113a authored by Cresson Remi's avatar Cresson Remi
Browse files

Merge branch 'refactoring' into develop

parents 20bb0192 93390d3b
This diff is collapsed.
...@@ -92,9 +92,11 @@ private: ...@@ -92,9 +92,11 @@ private:
ShareParameter("deepmodel", "tfmodel.model", ShareParameter("deepmodel", "tfmodel.model",
"Deep net model parameters", "Deep net model parameters"); "Deep net model parameters", "Deep net model parameters");
ShareParameter("output", "tfmodel.output", ShareParameter("output", "tfmodel.output",
"Deep net outputs parameters", "Deep net outputs parameters"); "Deep net outputs parameters",
ShareParameter("finetuning", "tfmodel.finetuning", "Deep net outputs parameters");
"Deep net fine tuning parameters","Deep net fine tuning parameters"); ShareParameter("optim", "tfmodel.optim",
"This group of parameters allows optimization of processing time",
"This group of parameters allows optimization of processing time");
// Classify shared parameters // Classify shared parameters
ShareParameter("model" , "classif.model" , "Model file" , "Model file" ); ShareParameter("model" , "classif.model" , "Model file" , "Model file" );
......
...@@ -106,14 +106,14 @@ public: ...@@ -106,14 +106,14 @@ public:
// Parameter group keys // Parameter group keys
ss_key_in << ss_key_group.str() << ".il"; ss_key_in << ss_key_group.str() << ".il";
ss_key_dims_x << ss_key_group.str() << ".fovx"; ss_key_dims_x << ss_key_group.str() << ".rfieldx";
ss_key_dims_y << ss_key_group.str() << ".fovy"; ss_key_dims_y << ss_key_group.str() << ".rfieldy";
ss_key_ph << ss_key_group.str() << ".placeholder"; ss_key_ph << ss_key_group.str() << ".placeholder";
// Parameter group descriptions // Parameter group descriptions
ss_desc_in << "Input image (or list to stack) for source #" << inputNumber; ss_desc_in << "Input image (or list to stack) for source #" << inputNumber;
ss_desc_dims_x << "Field of view width for source #" << inputNumber; ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber;
ss_desc_dims_y << "Field of view height for source #" << inputNumber; ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
ss_desc_ph << "Name of the input placeholder for source #" << inputNumber; ss_desc_ph << "Name of the input placeholder for source #" << inputNumber;
// Populate group // Populate group
...@@ -182,22 +182,22 @@ public: ...@@ -182,22 +182,22 @@ public:
MandatoryOn ("output.names"); MandatoryOn ("output.names");
// Output Field of Expression // Output Field of Expression
AddParameter(ParameterType_Int, "output.foex", "The output field of expression (x)"); AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)");
SetMinimumParameterIntValue ("output.foex", 1); SetMinimumParameterIntValue ("output.efieldx", 1);
SetDefaultParameterInt ("output.foex", 1); SetDefaultParameterInt ("output.efieldx", 1);
MandatoryOn ("output.foex"); MandatoryOn ("output.efieldx");
AddParameter(ParameterType_Int, "output.foey", "The output field of expression (y)"); AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)");
SetMinimumParameterIntValue ("output.foey", 1); SetMinimumParameterIntValue ("output.efieldy", 1);
SetDefaultParameterInt ("output.foey", 1); SetDefaultParameterInt ("output.efieldy", 1);
MandatoryOn ("output.foey"); MandatoryOn ("output.efieldy");
// Fine tuning // Fine tuning
AddParameter(ParameterType_Group, "finetuning" , "Fine tuning performance or consistency parameters"); AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time");
AddParameter(ParameterType_Bool, "finetuning.disabletiling", "Disable tiling"); AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
MandatoryOff ("finetuning.disabletiling"); MandatoryOff ("optim.disabletiling");
AddParameter(ParameterType_Int, "finetuning.tilesize", "Tile width used to stream the filter output"); AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("finetuning.tilesize", 1); SetMinimumParameterIntValue ("optim.tilesize", 1);
SetDefaultParameterInt ("finetuning.tilesize", 16); SetDefaultParameterInt ("optim.tilesize", 16);
// Output image // Output image
AddParameter(ParameterType_OutputImage, "out", "output image"); AddParameter(ParameterType_OutputImage, "out", "output image");
...@@ -205,8 +205,8 @@ public: ...@@ -205,8 +205,8 @@ public:
// Example // Example
SetDocExampleParameterValue("source1.il", "spot6pms.tif"); SetDocExampleParameterValue("source1.il", "spot6pms.tif");
SetDocExampleParameterValue("source1.placeholder", "x1"); SetDocExampleParameterValue("source1.placeholder", "x1");
SetDocExampleParameterValue("source1.fovx", "16"); SetDocExampleParameterValue("source1.rfieldx", "16");
SetDocExampleParameterValue("source1.fovy", "16"); SetDocExampleParameterValue("source1.rfieldy", "16");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0"); SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0");
SetDocExampleParameterValue("output.names", "out_predict1 out_proba1"); SetDocExampleParameterValue("output.names", "out_predict1 out_proba1");
...@@ -248,16 +248,16 @@ public: ...@@ -248,16 +248,16 @@ public:
m_TFFilter = TFModelFilterType::New(); m_TFFilter = TFModelFilterType::New();
m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TFFilter->SetSession(m_SavedModel.session.get()); m_TFFilter->SetSession(m_SavedModel.session.get());
m_TFFilter->SetOutputTensorsNames(GetParameterStringList("output.names")); m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale")); m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale()); otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
// Get user placeholders // Get user placeholders
TFModelFilterType::DictListType dict;
TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders"); TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
TFModelFilterType::DictType dict;
for (auto& exp: expressions) for (auto& exp: expressions)
{ {
TFModelFilterType::DictType entry = tf::ExpressionToTensor(exp); TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry); dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second)); otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
...@@ -267,7 +267,7 @@ public: ...@@ -267,7 +267,7 @@ public:
// Input sources // Input sources
for (auto& bundle: m_Bundles) for (auto& bundle: m_Bundles)
{ {
m_TFFilter->PushBackInputBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get()); m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
} }
// Fully convolutional mode on/off // Fully convolutional mode on/off
...@@ -281,15 +281,15 @@ public: ...@@ -281,15 +281,15 @@ public:
FloatVectorImageType::SizeType foe; FloatVectorImageType::SizeType foe;
foe[0] = GetParameterInt("output.foex"); foe[0] = GetParameterInt("output.foex");
foe[1] = GetParameterInt("output.foey"); foe[1] = GetParameterInt("output.foey");
m_TFFilter->SetOutputFOESize(foe); m_TFFilter->SetOutputExpressionFields({foe});
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputFOESize()); otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
// Streaming // Streaming
if (GetParameterInt("finetuning.disabletiling")!=1) if (GetParameterInt("optim.disabletiling")!=1)
{ {
// Get the tile size // Get the tile size
const unsigned int tileSize = GetParameterInt("finetuning.tilesize"); const unsigned int tileSize = GetParameterInt("optim.tilesize");
otbAppLogINFO("Force tiling with squared tiles of " << tileSize) otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
// Update the TF filter to get the output image size // Update the TF filter to get the output image size
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
itkNewMacro(Self); itkNewMacro(Self);
itkTypeMacro(TensorflowModelTrain, Application); itkTypeMacro(TensorflowModelTrain, Application);
/** Typedefs for tensorflow */ /** Typedefs for TensorFlow */
typedef otb::TensorflowMultisourceModelTrain<FloatVectorImageType> TrainModelFilterType; typedef otb::TensorflowMultisourceModelTrain<FloatVectorImageType> TrainModelFilterType;
typedef otb::TensorflowMultisourceModelValidate<FloatVectorImageType> ValidateModelFilterType; typedef otb::TensorflowMultisourceModelValidate<FloatVectorImageType> ValidateModelFilterType;
typedef otb::TensorflowSource<FloatVectorImageType> TFSource; typedef otb::TensorflowSource<FloatVectorImageType> TFSource;
...@@ -75,8 +75,8 @@ public: ...@@ -75,8 +75,8 @@ public:
// Parameters keys // Parameters keys
std::string m_KeyInForTrain; // Key of input image list (training) std::string m_KeyInForTrain; // Key of input image list (training)
std::string m_KeyInForValid; // Key of input image list (validation) std::string m_KeyInForValid; // Key of input image list (validation)
std::string m_KeyPHNameForTrain; // Key for placeholder name in the tensorflow model (training) std::string m_KeyPHNameForTrain; // Key for placeholder name in the TensorFlow model (training)
std::string m_KeyPHNameForValid; // Key for placeholder name in the tensorflow model (validation) std::string m_KeyPHNameForValid; // Key for placeholder name in the TensorFlow model (validation)
std::string m_KeyPszX; // Key for samples sizes X std::string m_KeyPszX; // Key for samples sizes X
std::string m_KeyPszY; // Key for samples sizes Y std::string m_KeyPszY; // Key for samples sizes Y
}; };
...@@ -122,18 +122,19 @@ public: ...@@ -122,18 +122,19 @@ public:
// Parameter group keys // Parameter group keys
ss_key_tr_in << ss_key_tr_group.str() << ".il"; ss_key_tr_in << ss_key_tr_group.str() << ".il";
ss_key_val_in << ss_key_val_group.str() << ".il"; ss_key_val_in << ss_key_val_group.str() << ".il";
ss_key_dims_x << ss_key_tr_group.str() << ".fovx"; ss_key_dims_x << ss_key_tr_group.str() << ".patchsizex";
ss_key_dims_y << ss_key_tr_group.str() << ".fovy"; ss_key_dims_y << ss_key_tr_group.str() << ".patchsizey";
ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder"; ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder";
ss_key_val_ph << ss_key_val_group.str() << ".placeholder"; ss_key_val_ph << ss_key_val_group.str() << ".name";
// Parameter group descriptions // Parameter group descriptions
ss_desc_tr_in << "Input image (or list to stack) for source #" << inputNumber << " (training)"; ss_desc_tr_in << "Input image (or list to stack) for source #" << inputNumber << " (training)";
ss_desc_val_in << "Input image (or list to stack) for source #" << inputNumber << " (validation)"; ss_desc_val_in << "Input image (or list to stack) for source #" << inputNumber << " (validation)";
ss_desc_dims_x << "Field of view width for source #" << inputNumber; ss_desc_dims_x << "Patch size (x) for source #" << inputNumber;
ss_desc_dims_y << "Field of view height for source #" << inputNumber; ss_desc_dims_y << "Patch size (y) for source #" << inputNumber;
ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)"; ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)";
ss_desc_val_ph << "Name of the input placeholder for source #" << inputNumber << " (validation)"; ss_desc_val_ph << "Name of the input placeholder "
"or output tensor for source #" << inputNumber << " (validation)";
// Populate group // Populate group
AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str()); AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str());
...@@ -194,21 +195,25 @@ public: ...@@ -194,21 +195,25 @@ public:
AddParameter(ParameterType_StringList, "training.userplaceholders", AddParameter(ParameterType_StringList, "training.userplaceholders",
"Additional single-valued placeholders for training. Supported types: int, float, bool."); "Additional single-valued placeholders for training. Supported types: int, float, bool.");
MandatoryOff ("training.userplaceholders"); MandatoryOff ("training.userplaceholders");
AddParameter(ParameterType_StringList, "training.targetnodesnames", "Names of the target nodes"); AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes");
MandatoryOn ("training.targetnodesnames"); MandatoryOn ("training.targetnodes");
AddParameter(ParameterType_StringList, "training.outputtensorsnames", "Names of the output tensors to display"); AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display");
MandatoryOff ("training.outputtensorsnames"); MandatoryOff ("training.outputtensors");
AddParameter(ParameterType_Bool, "training.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
MandatoryOff ("training.usestreaming");
// Metrics // Metrics
AddParameter(ParameterType_Group, "validation", "Validation parameters"); AddParameter(ParameterType_Group, "validation", "Validation parameters");
MandatoryOff ("validation"); MandatoryOff ("validation");
AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute"); AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
AddChoice ("validation.mode.none", "No validation step"); AddChoice ("validation.mode.none", "No validation step");
AddChoice ("validation.mode.class", "Classification metrics"); AddChoice ("validation.mode.class", "Classification metrics");
AddChoice ("validation.mode.rmse", "Root mean square error"); AddChoice ("validation.mode.rmse", "Root mean square error");
AddParameter(ParameterType_StringList, "validation.userplaceholders", AddParameter(ParameterType_StringList, "validation.userplaceholders",
"Additional single-valued placeholders for validation. Supported types: int, float, bool."); "Additional single-valued placeholders for validation. Supported types: int, float, bool.");
MandatoryOff ("validation.userplaceholders"); MandatoryOff ("validation.userplaceholders");
AddParameter(ParameterType_Bool, "validation.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
MandatoryOff ("validation.usestreaming");
// Input/output images // Input/output images
AddAnInputImage(); AddAnInputImage();
...@@ -220,15 +225,15 @@ public: ...@@ -220,15 +225,15 @@ public:
// Example // Example
SetDocExampleParameterValue("source1.il", "spot6pms.tif"); SetDocExampleParameterValue("source1.il", "spot6pms.tif");
SetDocExampleParameterValue("source1.placeholder", "x1"); SetDocExampleParameterValue("source1.placeholder", "x1");
SetDocExampleParameterValue("source1.fovx", "16"); SetDocExampleParameterValue("source1.patchsizex", "16");
SetDocExampleParameterValue("source1.fovy", "16"); SetDocExampleParameterValue("source1.patchsizey", "16");
SetDocExampleParameterValue("source2.il", "labels.tif"); SetDocExampleParameterValue("source2.il", "labels.tif");
SetDocExampleParameterValue("source2.placeholder", "y1"); SetDocExampleParameterValue("source2.placeholder", "y1");
SetDocExampleParameterValue("source2.fovx", "1"); SetDocExampleParameterValue("source2.patchsizex", "1");
SetDocExampleParameterValue("source2.fovy", "1"); SetDocExampleParameterValue("source2.patchsizex", "1");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2"); SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2");
SetDocExampleParameterValue("training.targetnodenames", "optimizer"); SetDocExampleParameterValue("training.targetnodes", "optimizer");
SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model_vars1"); SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model_vars1");
} }
...@@ -289,7 +294,7 @@ public: ...@@ -289,7 +294,7 @@ public:
m_InputPatchesSizeForTraining.push_back(patchSize); m_InputPatchesSizeForTraining.push_back(patchSize);
otbAppLogINFO("New source:"); otbAppLogINFO("New source:");
otbAppLogINFO("Field of view : "<< patchSize); otbAppLogINFO("Patch size : "<< patchSize);
otbAppLogINFO("Placeholder (training) : "<< placeholderForTraining); otbAppLogINFO("Placeholder (training) : "<< placeholderForTraining);
// Prepare validation sources // Prepare validation sources
...@@ -350,13 +355,13 @@ public: ...@@ -350,13 +355,13 @@ public:
// //
// Get user placeholders // Get user placeholders
// //
TrainModelFilterType::DictListType GetUserPlaceholders(const std::string key) TrainModelFilterType::DictType GetUserPlaceholders(const std::string key)
{ {
TrainModelFilterType::DictListType dict; TrainModelFilterType::DictType dict;
TrainModelFilterType::StringList expressions = GetParameterStringList(key); TrainModelFilterType::StringList expressions = GetParameterStringList(key);
for (auto& exp: expressions) for (auto& exp: expressions)
{ {
TrainModelFilterType::DictType entry = tf::ExpressionToTensor(exp); TrainModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry); dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second)); otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
...@@ -414,15 +419,16 @@ public: ...@@ -414,15 +419,16 @@ public:
m_TrainModelFilter = TrainModelFilterType::New(); m_TrainModelFilter = TrainModelFilterType::New();
m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TrainModelFilter->SetSession(m_SavedModel.session.get()); m_TrainModelFilter->SetSession(m_SavedModel.session.get());
m_TrainModelFilter->SetOutputTensorsNames(GetParameterStringList("training.outputtensorsnames")); m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors"));
m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodesnames")); m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes"));
m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders")); m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders"));
m_TrainModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// Set inputs // Set inputs
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
{ {
m_TrainModelFilter->PushBackInputBundle( m_TrainModelFilter->PushBackInputTensorBundle(
m_InputPlaceholdersForTraining[i], m_InputPlaceholdersForTraining[i],
m_InputPatchesSizeForTraining[i], m_InputPatchesSizeForTraining[i],
m_InputSourcesForTraining[i]); m_InputSourcesForTraining[i]);
...@@ -454,18 +460,21 @@ public: ...@@ -454,18 +460,21 @@ public:
m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders")); m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
// AS we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// 1. Evaluate the metrics against the learning data // 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{ {
m_ValidateModelFilter->PushBackInputBundle( m_ValidateModelFilter->PushBackInputTensorBundle(
m_InputPlaceholdersForValidation[i], m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i], m_InputPatchesSizeForValidation[i],
m_InputSourcesForEvaluationAgainstLearningData[i]); m_InputSourcesForEvaluationAgainstLearningData[i]);
} }
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames); m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames);
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData); m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize); m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize);
// Update // Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
...@@ -485,6 +494,7 @@ public: ...@@ -485,6 +494,7 @@ public:
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]); m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
} }
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData); m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
// Update // Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)"); AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
...@@ -509,10 +519,13 @@ public: ...@@ -509,10 +519,13 @@ public:
private: private:
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
// Filters
TrainModelFilterType::Pointer m_TrainModelFilter; TrainModelFilterType::Pointer m_TrainModelFilter;
ValidateModelFilterType::Pointer m_ValidateModelFilter; ValidateModelFilterType::Pointer m_ValidateModelFilter;
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
// Inputs
BundleList m_Bundles; BundleList m_Bundles;
// Patches size // Patches size
......
...@@ -63,7 +63,6 @@ private: ...@@ -63,7 +63,6 @@ private:
} }
void DoInit() void DoInit()
{ {
...@@ -91,7 +90,7 @@ private: ...@@ -91,7 +90,7 @@ private:
} }
ShareParameter("model", "tfmodel.model", "Deep net model parameters", "Deep net model parameters"); ShareParameter("model", "tfmodel.model", "Deep net model parameters", "Deep net model parameters");
ShareParameter("output", "tfmodel.output", "Deep net outputs parameters", "Deep net outputs parameters"); ShareParameter("output", "tfmodel.output", "Deep net outputs parameters", "Deep net outputs parameters");
ShareParameter("finetuning", "tfmodel.finetuning", "Deep net fine tuning parameters", "Deep net fine tuning parameters"); ShareParameter("optim", "tfmodel.optim", "This group of parameters allows optimization of processing time", "This group of parameters allows optimization of processing time");
// Train shared parameters // Train shared parameters
ShareParameter("vd" , "train.io.vd" , "Input vector data list" , "Input vector data list" ); ShareParameter("vd" , "train.io.vd" , "Input vector data list" , "Input vector data list" );
......
...@@ -105,7 +105,9 @@ void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::st ...@@ -105,7 +105,9 @@ void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::st
if (node.name().compare((*nameIt)) == 0) if (node.name().compare((*nameIt)) == 0)
{ {
found = true; found = true;
tensorflow::DataType ts_dt;
// Set default to DT_FLOAT
tensorflow::DataType ts_dt = tensorflow::DT_FLOAT;
// Default (input?) tensor type // Default (input?) tensor type
auto test_is_output = node.attr().find("T"); auto test_is_output = node.attr().find("T");
......
...@@ -30,15 +30,36 @@ namespace otb ...@@ -30,15 +30,36 @@ namespace otb
/** /**
* \class TensorflowMultisourceModelBase * \class TensorflowMultisourceModelBase
* \brief This filter is base for TensorFlow model over multiple input images. * \brief This filter is the base class for all TensorFlow model filters.
* *
* The filter takes N input images and feed the TensorFlow model. * This abstract class implements a number of generic methods that are used in
* Names of input placeholders must be specified using the * filters that use the TensorFlow engine.
* SetInputPlaceholdersNames method
* *
* TODO: * The filter has N input images (Input), each one corresponding to a placeholder
* Replace FOV (Field Of View) --> RF (Receptive Field) * that will fed the TensorFlow model. For each input, the name of the
* Replace FEO (Field Of Expr) --> EF (Expression Field) * placeholder (InputPlaceholders, a std::vector of std::string) and the
* receptive field (InputReceptiveFields, a std::vector of SizeType) i.e. the
* input space that the model will "see", must be provided. Hence the number of
* input images, and the size of InputPlaceholders and InputReceptiveFields must
* be the same. If not, an exception will be thrown during the method
* GenerateOutputInformation().
*
* The TensorFlow graph and session must be set using the SetGraph() and
* SetSession() methods.
*
* Target nodes names of the TensorFlow graph that must be triggered can be set
* with the SetTargetNodesNames.
*
* The OutputTensorNames consists in a strd::vector of std::string, and
* corresponds to the names of tensors that will be computed during the session.
* As for input placeholders, output tensors field of expression
* (OutputExpressionFields, a std::vector of SizeType), i.e. the output
* space that the TensorFlow model will "generate", must be provided.
*
* Finally, a list of scalar placeholders can be fed in the form of std::vector
* of std::string, each one expressing the assigment of a signle valued
* placeholder, e.g. "drop_rate=0.5 learning_rate=0.002 toto=true".
* See otb::tf::ExpressionToTensor() to know more about syntax.
* *
* \ingroup OTBTensorflow * \ingroup OTBTensorflow
*/ */
...@@ -72,10 +93,10 @@ public: ...@@ -72,10 +93,10 @@ public:
typedef typename TInputImage::RegionType RegionType; typedef typename TInputImage::RegionType RegionType;
/** Typedefs for parameters */ /** Typedefs for parameters */
typedef std::pair<std::string, tensorflow::Tensor> DictType; typedef std::pair<std::string, tensorflow::Tensor> DictElementType;
typedef std::vector<std::string> StringList; typedef std::vector<std::string> StringList;
typedef std::vector<SizeType> SizeListType; typedef std::vector<SizeType> SizeListType;
typedef std::vector<DictType> DictListType; typedef std::vector<DictElementType> DictType;
typedef std::vector<tensorflow::DataType> DataTypeListType; typedef std::vector<tensorflow::DataType> DataTypeListType;
typedef std::vector<tensorflow::TensorShapeProto> TensorShapeProtoList; typedef std::vector<tensorflow::TensorShapeProto> TensorShapeProtoList;
typedef std::vector<tensorflow::Tensor> TensorListType; typedef std::vector<tensorflow::Tensor> TensorListType;
...@@ -87,27 +108,28 @@ public: ...@@ -87,27 +108,28 @@ public:
tensorflow::Session * GetSession() { return m_Session; } tensorflow::Session * GetSession() { return m_Session; }
/** Model parameters */ /** Model parameters */
void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image); void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
void PushBackOuputTensorBundle(std::string name, SizeType expressionField);
/** Input placeholders names */
itkSetMacro(InputPlaceholders, StringList);
itkGetMacro(InputPlaceholders, StringList);
// /** Input placeholders names */ /** Receptive field */
// itkSetMacro(InputPlaceholdersNames, StringList); itkSetMacro(InputReceptiveFields, SizeListType);
itkGetMacro(InputPlaceholdersNames, StringList); itkGetMacro(InputReceptiveFields, SizeListType);
//
// /** Receptive field */
// itkSetMacro(InputFOVSizes, SizeListType);
itkGetMacro(InputFOVSizes, SizeListType);
/** Output tensors names */ /** Output tensors names */
itkSetMacro(OutputTensorsNames, StringList); itkSetMacro(OutputTensors, StringList);
itkGetMacro(OutputTensorsNames, StringList); itkGetMacro(OutputTensors, StringList);
/** Expression field */ /** Expression field */
itkSetMacro(OutputFOESizes, SizeListType); itkSetMacro(OutputExpressionFields, SizeListType);
itkGetMacro(OutputFOESizes, SizeListType); itkGetMacro(OutputExpressionFields, SizeListType);