Commit 4d9c113a authored by Cresson Remi

Merge branch 'refactoring' into develop

parents 20bb0192 93390d3b
......@@ -92,9 +92,11 @@ private:
ShareParameter("deepmodel", "tfmodel.model",
"Deep net model parameters", "Deep net model parameters");
ShareParameter("output", "tfmodel.output",
"Deep net outputs parameters", "Deep net outputs parameters");
ShareParameter("finetuning", "tfmodel.finetuning",
"Deep net fine tuning parameters","Deep net fine tuning parameters");
"Deep net outputs parameters",
"Deep net outputs parameters");
ShareParameter("optim", "tfmodel.optim",
"This group of parameters allows optimization of processing time",
"This group of parameters allows optimization of processing time");
// Classify shared parameters
ShareParameter("model" , "classif.model" , "Model file" , "Model file" );
......
......@@ -106,14 +106,14 @@ public:
// Parameter group keys
ss_key_in << ss_key_group.str() << ".il";
ss_key_dims_x << ss_key_group.str() << ".fovx";
ss_key_dims_y << ss_key_group.str() << ".fovy";
ss_key_dims_x << ss_key_group.str() << ".rfieldx";
ss_key_dims_y << ss_key_group.str() << ".rfieldy";
ss_key_ph << ss_key_group.str() << ".placeholder";
// Parameter group descriptions
ss_desc_in << "Input image (or list to stack) for source #" << inputNumber;
ss_desc_dims_x << "Field of view width for source #" << inputNumber;
ss_desc_dims_y << "Field of view height for source #" << inputNumber;
ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber;
ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
ss_desc_ph << "Name of the input placeholder for source #" << inputNumber;
// Populate group
......@@ -182,22 +182,22 @@ public:
MandatoryOn ("output.names");
// Output Field of Expression
AddParameter(ParameterType_Int, "output.foex", "The output field of expression (x)");
SetMinimumParameterIntValue ("output.foex", 1);
SetDefaultParameterInt ("output.foex", 1);
MandatoryOn ("output.foex");
AddParameter(ParameterType_Int, "output.foey", "The output field of expression (y)");
SetMinimumParameterIntValue ("output.foey", 1);
SetDefaultParameterInt ("output.foey", 1);
MandatoryOn ("output.foey");
AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)");
SetMinimumParameterIntValue ("output.efieldx", 1);
SetDefaultParameterInt ("output.efieldx", 1);
MandatoryOn ("output.efieldx");
AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)");
SetMinimumParameterIntValue ("output.efieldy", 1);
SetDefaultParameterInt ("output.efieldy", 1);
MandatoryOn ("output.efieldy");
// Fine tuning
AddParameter(ParameterType_Group, "finetuning" , "Fine tuning performance or consistency parameters");
AddParameter(ParameterType_Bool, "finetuning.disabletiling", "Disable tiling");
MandatoryOff ("finetuning.disabletiling");
AddParameter(ParameterType_Int, "finetuning.tilesize", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("finetuning.tilesize", 1);
SetDefaultParameterInt ("finetuning.tilesize", 16);
AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time");
AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
MandatoryOff ("optim.disabletiling");
AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesize", 1);
SetDefaultParameterInt ("optim.tilesize", 16);
// Output image
AddParameter(ParameterType_OutputImage, "out", "output image");
......@@ -205,8 +205,8 @@ public:
// Example
SetDocExampleParameterValue("source1.il", "spot6pms.tif");
SetDocExampleParameterValue("source1.placeholder", "x1");
SetDocExampleParameterValue("source1.fovx", "16");
SetDocExampleParameterValue("source1.fovy", "16");
SetDocExampleParameterValue("source1.rfieldx", "16");
SetDocExampleParameterValue("source1.rfieldy", "16");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0");
SetDocExampleParameterValue("output.names", "out_predict1 out_proba1");
......@@ -248,16 +248,16 @@ public:
m_TFFilter = TFModelFilterType::New();
m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TFFilter->SetSession(m_SavedModel.session.get());
m_TFFilter->SetOutputTensorsNames(GetParameterStringList("output.names"));
m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
// Get user placeholders
TFModelFilterType::DictListType dict;
TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
TFModelFilterType::DictType dict;
for (auto& exp: expressions)
{
TFModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
......@@ -267,7 +267,7 @@ public:
// Input sources
for (auto& bundle: m_Bundles)
{
m_TFFilter->PushBackInputBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
}
// Fully convolutional mode on/off
......@@ -281,15 +281,15 @@ public:
FloatVectorImageType::SizeType foe;
foe[0] = GetParameterInt("output.foex");
foe[1] = GetParameterInt("output.foey");
m_TFFilter->SetOutputFOESize(foe);
m_TFFilter->SetOutputExpressionFields({foe});
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputFOESize());
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
// Streaming
if (GetParameterInt("finetuning.disabletiling")!=1)
if (GetParameterInt("optim.disabletiling")!=1)
{
// Get the tile size
const unsigned int tileSize = GetParameterInt("finetuning.tilesize");
const unsigned int tileSize = GetParameterInt("optim.tilesize");
otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
// Update the TF filter to get the output image size
......
......@@ -53,7 +53,7 @@ public:
itkNewMacro(Self);
itkTypeMacro(TensorflowModelTrain, Application);
/** Typedefs for tensorflow */
/** Typedefs for TensorFlow */
typedef otb::TensorflowMultisourceModelTrain<FloatVectorImageType> TrainModelFilterType;
typedef otb::TensorflowMultisourceModelValidate<FloatVectorImageType> ValidateModelFilterType;
typedef otb::TensorflowSource<FloatVectorImageType> TFSource;
......@@ -75,8 +75,8 @@ public:
// Parameters keys
std::string m_KeyInForTrain; // Key of input image list (training)
std::string m_KeyInForValid; // Key of input image list (validation)
std::string m_KeyPHNameForTrain; // Key for placeholder name in the tensorflow model (training)
std::string m_KeyPHNameForValid; // Key for placeholder name in the tensorflow model (validation)
std::string m_KeyPHNameForTrain; // Key for placeholder name in the TensorFlow model (training)
std::string m_KeyPHNameForValid; // Key for placeholder name in the TensorFlow model (validation)
std::string m_KeyPszX; // Key for samples sizes X
std::string m_KeyPszY; // Key for samples sizes Y
};
......@@ -122,18 +122,19 @@ public:
// Parameter group keys
ss_key_tr_in << ss_key_tr_group.str() << ".il";
ss_key_val_in << ss_key_val_group.str() << ".il";
ss_key_dims_x << ss_key_tr_group.str() << ".fovx";
ss_key_dims_y << ss_key_tr_group.str() << ".fovy";
ss_key_dims_x << ss_key_tr_group.str() << ".patchsizex";
ss_key_dims_y << ss_key_tr_group.str() << ".patchsizey";
ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder";
ss_key_val_ph << ss_key_val_group.str() << ".placeholder";
ss_key_val_ph << ss_key_val_group.str() << ".name";
// Parameter group descriptions
ss_desc_tr_in << "Input image (or list to stack) for source #" << inputNumber << " (training)";
ss_desc_val_in << "Input image (or list to stack) for source #" << inputNumber << " (validation)";
ss_desc_dims_x << "Field of view width for source #" << inputNumber;
ss_desc_dims_y << "Field of view height for source #" << inputNumber;
ss_desc_dims_x << "Patch size (x) for source #" << inputNumber;
ss_desc_dims_y << "Patch size (y) for source #" << inputNumber;
ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)";
ss_desc_val_ph << "Name of the input placeholder for source #" << inputNumber << " (validation)";
ss_desc_val_ph << "Name of the input placeholder "
"or output tensor for source #" << inputNumber << " (validation)";
// Populate group
AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str());
......@@ -194,21 +195,25 @@ public:
AddParameter(ParameterType_StringList, "training.userplaceholders",
"Additional single-valued placeholders for training. Supported types: int, float, bool.");
MandatoryOff ("training.userplaceholders");
AddParameter(ParameterType_StringList, "training.targetnodesnames", "Names of the target nodes");
MandatoryOn ("training.targetnodesnames");
AddParameter(ParameterType_StringList, "training.outputtensorsnames", "Names of the output tensors to display");
MandatoryOff ("training.outputtensorsnames");
AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes");
MandatoryOn ("training.targetnodes");
AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display");
MandatoryOff ("training.outputtensors");
AddParameter(ParameterType_Bool, "training.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
MandatoryOff ("training.usestreaming");
// Metrics
AddParameter(ParameterType_Group, "validation", "Validation parameters");
AddParameter(ParameterType_Group, "validation", "Validation parameters");
MandatoryOff ("validation");
AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
AddChoice ("validation.mode.none", "No validation step");
AddChoice ("validation.mode.class", "Classification metrics");
AddChoice ("validation.mode.rmse", "Root mean square error");
AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute");
AddChoice ("validation.mode.none", "No validation step");
AddChoice ("validation.mode.class", "Classification metrics");
AddChoice ("validation.mode.rmse", "Root mean square error");
AddParameter(ParameterType_StringList, "validation.userplaceholders",
"Additional single-valued placeholders for validation. Supported types: int, float, bool.");
MandatoryOff ("validation.userplaceholders");
AddParameter(ParameterType_Bool, "validation.usestreaming", "Use the streaming through patches (slower but can process big dataset)");
MandatoryOff ("validation.usestreaming");
// Input/output images
AddAnInputImage();
......@@ -220,15 +225,15 @@ public:
// Example
SetDocExampleParameterValue("source1.il", "spot6pms.tif");
SetDocExampleParameterValue("source1.placeholder", "x1");
SetDocExampleParameterValue("source1.fovx", "16");
SetDocExampleParameterValue("source1.fovy", "16");
SetDocExampleParameterValue("source1.patchsizex", "16");
SetDocExampleParameterValue("source1.patchsizey", "16");
SetDocExampleParameterValue("source2.il", "labels.tif");
SetDocExampleParameterValue("source2.placeholder", "y1");
SetDocExampleParameterValue("source2.fovx", "1");
SetDocExampleParameterValue("source2.fovy", "1");
SetDocExampleParameterValue("source2.patchsizex", "1");
SetDocExampleParameterValue("source2.patchsizex", "1");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2");
SetDocExampleParameterValue("training.targetnodenames", "optimizer");
SetDocExampleParameterValue("training.targetnodes", "optimizer");
SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model_vars1");
}
......@@ -289,7 +294,7 @@ public:
m_InputPatchesSizeForTraining.push_back(patchSize);
otbAppLogINFO("New source:");
otbAppLogINFO("Field of view : "<< patchSize);
otbAppLogINFO("Patch size : "<< patchSize);
otbAppLogINFO("Placeholder (training) : "<< placeholderForTraining);
// Prepare validation sources
......@@ -350,13 +355,13 @@ public:
//
// Get user placeholders
//
TrainModelFilterType::DictListType GetUserPlaceholders(const std::string key)
TrainModelFilterType::DictType GetUserPlaceholders(const std::string key)
{
TrainModelFilterType::DictListType dict;
TrainModelFilterType::DictType dict;
TrainModelFilterType::StringList expressions = GetParameterStringList(key);
for (auto& exp: expressions)
{
TrainModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
TrainModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
......@@ -414,15 +419,16 @@ public:
m_TrainModelFilter = TrainModelFilterType::New();
m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TrainModelFilter->SetSession(m_SavedModel.session.get());
m_TrainModelFilter->SetOutputTensorsNames(GetParameterStringList("training.outputtensorsnames"));
m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodesnames"));
m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors"));
m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes"));
m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders"));
m_TrainModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// Set inputs
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
{
m_TrainModelFilter->PushBackInputBundle(
m_TrainModelFilter->PushBackInputTensorBundle(
m_InputPlaceholdersForTraining[i],
m_InputPatchesSizeForTraining[i],
m_InputSourcesForTraining[i]);
......@@ -454,18 +460,21 @@ public:
m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
// As we use the learning data here, it is rational to use the same streaming option as during training
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
// 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->PushBackInputBundle(
m_ValidateModelFilter->PushBackInputTensorBundle(
m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i],
m_InputSourcesForEvaluationAgainstLearningData[i]);
}
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames);
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize);
m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize);
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
......@@ -485,6 +494,7 @@ public:
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
......@@ -509,10 +519,13 @@ public:
private:
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
// Filters
TrainModelFilterType::Pointer m_TrainModelFilter;
ValidateModelFilterType::Pointer m_ValidateModelFilter;
tensorflow::SavedModelBundle m_SavedModel; // must be alive during the whole execution of the application!
// Inputs
BundleList m_Bundles;
// Patches size
......
......@@ -63,7 +63,6 @@ private:
}
void DoInit()
{
......@@ -91,7 +90,7 @@ private:
}
ShareParameter("model", "tfmodel.model", "Deep net model parameters", "Deep net model parameters");
ShareParameter("output", "tfmodel.output", "Deep net outputs parameters", "Deep net outputs parameters");
ShareParameter("finetuning", "tfmodel.finetuning", "Deep net fine tuning parameters", "Deep net fine tuning parameters");
ShareParameter("optim", "tfmodel.optim", "This group of parameters allows optimization of processing time", "This group of parameters allows optimization of processing time");
// Train shared parameters
ShareParameter("vd" , "train.io.vd" , "Input vector data list" , "Input vector data list" );
......
......@@ -105,7 +105,9 @@ void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::st
if (node.name().compare((*nameIt)) == 0)
{
found = true;
tensorflow::DataType ts_dt;
// Set default to DT_FLOAT
tensorflow::DataType ts_dt = tensorflow::DT_FLOAT;
// Default (input?) tensor type
auto test_is_output = node.attr().find("T");
......
......@@ -30,15 +30,36 @@ namespace otb
/**
* \class TensorflowMultisourceModelBase
* \brief This filter is base for TensorFlow model over multiple input images.
* \brief This filter is the base class for all TensorFlow model filters.
*
* The filter takes N input images and feed the TensorFlow model.
* Names of input placeholders must be specified using the
* SetInputPlaceholdersNames method
* This abstract class implements a number of generic methods that are used in
* filters that use the TensorFlow engine.
*
* TODO:
* Replace FOV (Field Of View) --> RF (Receptive Field)
* Replace FEO (Field Of Expr) --> EF (Expression Field)
* The filter has N input images (Input), each one corresponding to a placeholder
* that will feed the TensorFlow model. For each input, the name of the
* placeholder (InputPlaceholders, a std::vector of std::string) and the
* receptive field (InputReceptiveFields, a std::vector of SizeType), i.e. the
* input space that the model "sees", must be provided. Hence the number of
* input images and the sizes of InputPlaceholders and InputReceptiveFields must
* be the same; if not, an exception is thrown during
* GenerateOutputInformation().
*
* The TensorFlow graph and session must be set using the SetGraph() and
* SetSession() methods.
*
* The names of the TensorFlow graph nodes that must be triggered (target nodes)
* can be set with SetTargetNodesNames().
*
* OutputTensors is a std::vector of std::string holding the names of the
* tensors that will be computed during the session. As for the input
* placeholders, the expression fields of the output tensors
* (OutputExpressionFields, a std::vector of SizeType), i.e. the output
* space that the TensorFlow model "generates", must be provided.
*
* Finally, a list of scalar placeholders can be fed in the form of a std::vector
* of std::string, each one expressing the assignment of a single-valued
* placeholder, e.g. "drop_rate=0.5 learning_rate=0.002 toto=true".
* See otb::tf::ExpressionToTensor() for more details about the syntax.
*
* \ingroup OTBTensorflow
*/
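To make the refactored API concrete, here is a minimal usage sketch of a filter deriving from this base class. The SavedModel loading and image reading steps are omitted, and the concrete filter type and the variable names (savedModel, inputImage) are illustrative assumptions; only the setters documented above are taken from the class itself.

// Illustrative sketch only: assumes a loaded tensorflow::SavedModelBundle
// ("savedModel"), an already-read otb::VectorImage ("inputImage"), and a
// concrete filter deriving from TensorflowMultisourceModelBase.
using ImageType  = otb::VectorImage<float, 2>;
using FilterType = otb::TensorflowMultisourceModelFilter<ImageType, ImageType>;

FilterType::Pointer filter = FilterType::New();
filter->SetGraph(savedModel.meta_graph_def.graph_def());  // TensorFlow graph
filter->SetSession(savedModel.session.get());             // TensorFlow session

// One bundle per input source: placeholder name, receptive field, image
ImageType::SizeType rfield;
rfield.Fill(16);
filter->PushBackInputTensorBundle("x1", rfield, inputImage);

// Output tensor names and their expression fields (formerly the "FOE" sizes)
ImageType::SizeType efield;
efield.Fill(1);
filter->SetOutputTensors({"out_predict1"});
filter->SetOutputExpressionFields({efield});

// Scalar placeholders, each parsed from a "name=value" expression
FilterType::DictType dict;
dict.push_back(otb::tf::ExpressionToTensor("is_training=false"));
filter->SetUserPlaceholders(dict);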
......@@ -72,10 +93,10 @@ public:
typedef typename TInputImage::RegionType RegionType;
/** Typedefs for parameters */
typedef std::pair<std::string, tensorflow::Tensor> DictType;
typedef std::pair<std::string, tensorflow::Tensor> DictElementType;
typedef std::vector<std::string> StringList;
typedef std::vector<SizeType> SizeListType;
typedef std::vector<DictType> DictListType;
typedef std::vector<DictElementType> DictType;
typedef std::vector<tensorflow::DataType> DataTypeListType;
typedef std::vector<tensorflow::TensorShapeProto> TensorShapeProtoList;
typedef std::vector<tensorflow::Tensor> TensorListType;
......@@ -87,27 +108,28 @@ public:
tensorflow::Session * GetSession() { return m_Session; }
/** Model parameters */
void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image);
void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
void PushBackOuputTensorBundle(std::string name, SizeType expressionField);
/** Input placeholders names */
itkSetMacro(InputPlaceholders, StringList);
itkGetMacro(InputPlaceholders, StringList);
// /** Input placeholders names */
// itkSetMacro(InputPlaceholdersNames, StringList);
itkGetMacro(InputPlaceholdersNames, StringList);
//
// /** Receptive field */
// itkSetMacro(InputFOVSizes, SizeListType);
itkGetMacro(InputFOVSizes, SizeListType);
/** Receptive field */
itkSetMacro(InputReceptiveFields, SizeListType);
itkGetMacro(InputReceptiveFields, SizeListType);
/** Output tensors names */
itkSetMacro(OutputTensorsNames, StringList);
itkGetMacro(OutputTensorsNames, StringList);
itkSetMacro(OutputTensors, StringList);
itkGetMacro(OutputTensors, StringList);
/** Expression field */
itkSetMacro(OutputFOESizes, SizeListType);
itkGetMacro(OutputFOESizes, SizeListType);
itkSetMacro(OutputExpressionFields, SizeListType);
itkGetMacro(OutputExpressionFields, SizeListType);
/** User placeholders */
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; }
void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; }
DictType GetUserPlaceholders() { return m_UserPlaceholders; }
/** Target nodes names */
itkSetMacro(TargetNodesNames, StringList);
......@@ -125,27 +147,27 @@ protected:
TensorflowMultisourceModelBase();
virtual ~TensorflowMultisourceModelBase() {};
virtual std::stringstream GenerateDebugReport(DictListType & inputs, TensorListType & outputs);
virtual std::stringstream GenerateDebugReport(DictType & inputs);
virtual void RunSession(DictListType & inputs, TensorListType & outputs);
virtual void RunSession(DictType & inputs, TensorListType & outputs);
private:
TensorflowMultisourceModelBase(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
// Tensorflow graph and session
tensorflow::GraphDef m_Graph; // The tensorflow graph
tensorflow::Session * m_Session; // The tensorflow session
tensorflow::GraphDef m_Graph; // The TensorFlow graph
tensorflow::Session * m_Session; // The TensorFlow session
// Model parameters
StringList m_InputPlaceholdersNames; // Input placeholders names
SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes
SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes
DictListType m_UserPlaceholders; // User placeholders
StringList m_OutputTensorsNames; // User tensors
StringList m_TargetNodesNames; // User target tensors
// Read-only
StringList m_InputPlaceholders; // Input placeholders names
SizeListType m_InputReceptiveFields; // Input receptive fields
StringList m_OutputTensors; // Output tensors names
SizeListType m_OutputExpressionFields; // Output expression fields
DictType m_UserPlaceholders; // User placeholders
StringList m_TargetNodesNames; // User target nodes
// Internal, read-only
DataTypeListType m_InputTensorsDataTypes; // Input tensors datatype
DataTypeListType m_OutputTensorsDataTypes; // Output tensors datatype
TensorShapeProtoList m_InputTensorsShapes; // Input tensors shapes
......
......@@ -20,22 +20,23 @@ template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::TensorflowMultisourceModelBase()
{
m_Session = nullptr;
}
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{
Superclass::PushBackInput(image);
m_InputFOVSizes.push_back(receptiveField);
m_InputPlaceholdersNames.push_back(placeholder);
m_InputReceptiveFields.push_back(receptiveField);
m_InputPlaceholders.push_back(placeholder);
}
template <class TInputImage, class TOutputImage>
std::stringstream
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateDebugReport(DictListType & inputs, TensorListType & outputs)
::GenerateDebugReport(DictType & inputs)
{
// Create a debug report
std::stringstream debugReport;
......@@ -69,7 +70,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::RunSession(DictListType & inputs, TensorListType & outputs)
::RunSession(DictType & inputs, TensorListType & outputs)
{
// Add the user's placeholders
......@@ -82,11 +83,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
// The session will initialize the outputs
// Run the session, evaluating our output tensors from the graph
auto status = this->GetSession()->Run(inputs, m_OutputTensorsNames, m_TargetNodesNames, &outputs);
auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs);
if (!status.ok()) {
// Create a debug report
std::stringstream debugReport = GenerateDebugReport(inputs, outputs);
std::stringstream debugReport = GenerateDebugReport(inputs);
// Throw an exception with the report
itkExceptionMacro("Can't run the tensorflow session !\n" <<
......@@ -104,15 +105,24 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
{
// Check that the number of the following is the same
// - placeholders names
// - patches sizes
// - input image
// - input placeholders names
// - input receptive fields
// - input images
const unsigned int nbInputs = this->GetNumberOfInputs();
if (nbInputs != m_InputFOVSizes.size() || nbInputs != m_InputPlaceholdersNames.size())
if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size())
{
itkExceptionMacro("Number of input images is " << nbInputs <<
" but the number of input patches size is " << m_InputFOVSizes.size() <<
" and the number of input tensors names is " << m_InputPlaceholdersNames.size());
" but the number of input patches size is " << m_InputReceptiveFields.size() <<
" and the number of input tensors names is " << m_InputPlaceholders.size());
}
// Check that the number of the following is the same
// - output tensors names
// - output expression fields
if (m_OutputExpressionFields.size() != m_OutputTensors.size())
{
itkExceptionMacro("Number of output tensors names is " << m_OutputTensors.size() <<
" but the number of output fields of expression is " << m_OutputExpressionFields.size());
}
//////////////////////////////////////////////////////////////////////////////////////////
......@@ -120,8 +130,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
//////////////////////////////////////////////////////////////////////////////////////////
// Get input and output tensors datatypes and shapes
tf::GetTensorAttributes(m_Graph, m_InputPlaceholdersNames, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_OutputTensorsNames, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
}
......
......@@ -26,31 +26,53 @@ namespace otb
/**
* \class TensorflowMultisourceModelFilter
* \brief This filter apply a TensorFlow model over multiple input images.
* \brief This filter applies a TensorFlow model over multiple input images and
* generates one output image corresponding to the outputs of the model.
*
* The filter takes N input images and feeds the TensorFlow model to produce
* one output image of desired TF op results.
* Names of input/output placeholders/tensors must be specified using the
* SetInputPlaceholdersNames/SetOutputTensorNames.
* one output image corresponding to the desired results of the TensorFlow model.
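As a rough continuation of the sketch given earlier for the base class (still with hypothetical variable names; SetOutputSpacingScale() and the standard otb::ImageFileWriter are the only additional calls assumed, the former being used the same way in TensorflowModelServe above), the single output image can then be consumed by a regular OTB pipeline:

// The filter exposes one regular otb image output built from the model outputs.
filter->SetOutputSpacingScale(1.0);  // output spacing ratio, see "output.spcscale" above

using WriterType = otb::ImageFileWriter<ImageType>;
WriterType::Pointer writer = WriterType::New();
writer->SetInput(filter->GetOutput());
writer->SetFileName("model_output.tif");
writer->Update();  // triggers the pipeline, which runs the TensorFlow session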