Commit 70004b8e authored by Cresson Remi

Merge branch 'develop'

parents 5c0ec1a9 4d9c113a
......@@ -92,9 +92,11 @@ private:
ShareParameter("deepmodel", "tfmodel.model",
"Deep net model parameters", "Deep net model parameters");
ShareParameter("output", "tfmodel.output",
"Deep net outputs parameters", "Deep net outputs parameters");
ShareParameter("finetuning", "tfmodel.finetuning",
"Deep net fine tuning parameters","Deep net fine tuning parameters");
"Deep net outputs parameters",
"Deep net outputs parameters");
ShareParameter("optim", "tfmodel.optim",
"This group of parameters allows optimization of processing time",
"This group of parameters allows optimization of processing time");
// Classify shared parameters
ShareParameter("model" , "classif.model" , "Model file" , "Model file" );
......
......@@ -106,14 +106,14 @@ public:
// Parameter group keys
ss_key_in << ss_key_group.str() << ".il";
ss_key_dims_x << ss_key_group.str() << ".fovx";
ss_key_dims_y << ss_key_group.str() << ".fovy";
ss_key_dims_x << ss_key_group.str() << ".rfieldx";
ss_key_dims_y << ss_key_group.str() << ".rfieldy";
ss_key_ph << ss_key_group.str() << ".placeholder";
// Parameter group descriptions
ss_desc_in << "Input image (or list to stack) for source #" << inputNumber;
ss_desc_dims_x << "Field of view width for source #" << inputNumber;
ss_desc_dims_y << "Field of view height for source #" << inputNumber;
ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber;
ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
ss_desc_ph << "Name of the input placeholder for source #" << inputNumber;
// Populate group
......@@ -182,22 +182,22 @@ public:
MandatoryOn ("output.names");
// Output Field of Expression
AddParameter(ParameterType_Int, "output.foex", "The output field of expression (x)");
SetMinimumParameterIntValue ("output.foex", 1);
SetDefaultParameterInt ("output.foex", 1);
MandatoryOn ("output.foex");
AddParameter(ParameterType_Int, "output.foey", "The output field of expression (y)");
SetMinimumParameterIntValue ("output.foey", 1);
SetDefaultParameterInt ("output.foey", 1);
MandatoryOn ("output.foey");
AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)");
SetMinimumParameterIntValue ("output.efieldx", 1);
SetDefaultParameterInt ("output.efieldx", 1);
MandatoryOn ("output.efieldx");
AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)");
SetMinimumParameterIntValue ("output.efieldy", 1);
SetDefaultParameterInt ("output.efieldy", 1);
MandatoryOn ("output.efieldy");
// Fine tuning
AddParameter(ParameterType_Group, "finetuning" , "Fine tuning performance or consistency parameters");
AddParameter(ParameterType_Bool, "finetuning.disabletiling", "Disable tiling");
MandatoryOff ("finetuning.disabletiling");
AddParameter(ParameterType_Int, "finetuning.tilesize", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("finetuning.tilesize", 1);
SetDefaultParameterInt ("finetuning.tilesize", 16);
AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time");
AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
MandatoryOff ("optim.disabletiling");
AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesize", 1);
SetDefaultParameterInt ("optim.tilesize", 16);
// Output image
AddParameter(ParameterType_OutputImage, "out", "output image");
......@@ -205,8 +205,8 @@ public:
// Example
SetDocExampleParameterValue("source1.il", "spot6pms.tif");
SetDocExampleParameterValue("source1.placeholder", "x1");
SetDocExampleParameterValue("source1.fovx", "16");
SetDocExampleParameterValue("source1.fovy", "16");
SetDocExampleParameterValue("source1.rfieldx", "16");
SetDocExampleParameterValue("source1.rfieldy", "16");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0");
SetDocExampleParameterValue("output.names", "out_predict1 out_proba1");
......@@ -248,16 +248,16 @@ public:
m_TFFilter = TFModelFilterType::New();
m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TFFilter->SetSession(m_SavedModel.session.get());
m_TFFilter->SetOutputTensorsNames(GetParameterStringList("output.names"));
m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
// Get user placeholders
TFModelFilterType::DictListType dict;
TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
TFModelFilterType::DictType dict;
for (auto& exp: expressions)
{
TFModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
......@@ -267,7 +267,7 @@ public:
// Input sources
for (auto& bundle: m_Bundles)
{
m_TFFilter->PushBackInputBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
}
// Fully convolutional mode on/off
......@@ -281,15 +281,15 @@ public:
FloatVectorImageType::SizeType foe;
foe[0] = GetParameterInt("output.foex");
foe[1] = GetParameterInt("output.foey");
m_TFFilter->SetOutputFOESize(foe);
m_TFFilter->SetOutputExpressionFields({foe});
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputFOESize());
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
// Streaming
if (GetParameterInt("finetuning.disabletiling")!=1)
if (GetParameterInt("optim.disabletiling")!=1)
{
// Get the tile size
const unsigned int tileSize = GetParameterInt("finetuning.tilesize");
const unsigned int tileSize = GetParameterInt("optim.tilesize");
otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
// Update the TF filter to get the output image size
......
......@@ -63,7 +63,6 @@ private:
}
void DoInit()
{
......@@ -91,7 +90,7 @@ private:
}
ShareParameter("model", "tfmodel.model", "Deep net model parameters", "Deep net model parameters");
ShareParameter("output", "tfmodel.output", "Deep net outputs parameters", "Deep net outputs parameters");
ShareParameter("finetuning", "tfmodel.finetuning", "Deep net fine tuning parameters", "Deep net fine tuning parameters");
ShareParameter("optim", "tfmodel.optim", "This group of parameters allows optimization of processing time", "This group of parameters allows optimization of processing time");
// Train shared parameters
ShareParameter("vd" , "train.io.vd" , "Input vector data list" , "Input vector data list" );
......
......@@ -105,7 +105,9 @@ void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::st
if (node.name().compare((*nameIt)) == 0)
{
found = true;
tensorflow::DataType ts_dt;
// Set default to DT_FLOAT
tensorflow::DataType ts_dt = tensorflow::DT_FLOAT;
// Default (input?) tensor type
auto test_is_output = node.attr().find("T");
......
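The hunk is truncated here, but the pattern is clear: the datatype now defaults to DT_FLOAT and is only overridden when the node carries a type attribute. A hedged sketch of that defaulting logic (the override branch is an assumption, since it is not visible in the hunk):

// Default to DT_FLOAT, then override with the node's "T" attribute if present
tensorflow::DataType ts_dt = tensorflow::DT_FLOAT;
auto test_is_output = node.attr().find("T");
if (test_is_output != node.attr().end())
{
  ts_dt = test_is_output->second.type();  // declared tensor datatype
}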
......@@ -30,12 +30,36 @@ namespace otb
/**
* \class TensorflowMultisourceModelBase
* \brief This filter is base for TensorFlow model over multiple input images.
* \brief This filter is the base class for all TensorFlow model filters.
*
* The filter takes N input images and feed the TensorFlow model.
* Names of input placeholders must be specified using the
* SetInputPlaceholdersNames method
* This abstract class implements a number of generic methods that are shared by
* the filters built on the TensorFlow engine.
*
* The filter has N input images (Input), each one corresponding to a placeholder
* that feeds the TensorFlow model. For each input, the name of the
* placeholder (InputPlaceholders, a std::vector of std::string) and the
* receptive field (InputReceptiveFields, a std::vector of SizeType), i.e. the
* input space that the model will "see", must be provided. Hence the number of
* input images and the sizes of InputPlaceholders and InputReceptiveFields must
* be the same; if not, an exception is thrown during
* GenerateOutputInformation().
*
* The TensorFlow graph and session must be set using the SetGraph() and
* SetSession() methods.
*
* The names of the target nodes of the TensorFlow graph that must be triggered
* can be set with SetTargetNodesNames().
*
* OutputTensors consists of a std::vector of std::string, and
* corresponds to the names of the tensors that will be computed during the session.
* As for input placeholders, the output tensors' fields of expression
* (OutputExpressionFields, a std::vector of SizeType), i.e. the output
* space that the TensorFlow model will "generate", must be provided.
*
* Finally, a list of scalar placeholders can be fed in the form of a std::vector
* of std::string, each one expressing the assignment of a single-valued
* placeholder, e.g. "drop_rate=0.5 learning_rate=0.002 toto=true".
* See otb::tf::ExpressionToTensor() for more details about the syntax.
*
* \ingroup OTBTensorflow
*/
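To make this contract concrete, here is a hedged configuration sketch using the concrete subclass TensorflowMultisourceModelFilter (declared later in this diff); the SavedModel bundle and the input images are assumptions, as is the use of FloatVectorImageType:

// Hedged sketch: wiring a multisource model filter, assuming a loaded
// SavedModel (savedModel) and two input images (image1, image2).
using FilterType = otb::TensorflowMultisourceModelFilter<FloatVectorImageType,
                                                         FloatVectorImageType>;
FilterType::Pointer filter = FilterType::New();
filter->SetGraph(savedModel.meta_graph_def.graph_def());
filter->SetSession(savedModel.session.get());

// One placeholder name and one receptive field per input image
FilterType::SizeType rfield;
rfield.Fill(16);
filter->PushBackInputTensorBundle("x1", rfield, image1);
filter->PushBackInputTensorBundle("x2", rfield, image2);

// One expression field per output tensor
filter->SetOutputTensors({"y1"});
FilterType::SizeType efield;
efield.Fill(16);
filter->SetOutputExpressionFields({efield});

// Scalar placeholders, using the expression syntax described above
filter->SetUserPlaceholders({otb::tf::ExpressionToTensor("drop_rate=0.5")});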
......@@ -69,10 +93,10 @@ public:
typedef typename TInputImage::RegionType RegionType;
/** Typedefs for parameters */
typedef std::pair<std::string, tensorflow::Tensor> DictType;
typedef std::pair<std::string, tensorflow::Tensor> DictElementType;
typedef std::vector<std::string> StringList;
typedef std::vector<SizeType> SizeListType;
typedef std::vector<DictType> DictListType;
typedef std::vector<DictElementType> DictType;
typedef std::vector<tensorflow::DataType> DataTypeListType;
typedef std::vector<tensorflow::TensorShapeProto> TensorShapeProtoList;
typedef std::vector<tensorflow::Tensor> TensorListType;
......@@ -84,15 +108,30 @@ public:
tensorflow::Session * GetSession() { return m_Session; }
/** Model parameters */
void PushBackInputBundle(std::string placeholder, SizeType fieldOfView, ImagePointerType image);
itkSetMacro(InputPlaceholdersNames, StringList);
itkGetMacro(InputPlaceholdersNames, StringList);
itkSetMacro(InputFOVSizes, SizeListType);
itkGetMacro(InputFOVSizes, SizeListType);
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; }
itkSetMacro(OutputTensorsNames, StringList);
itkGetMacro(OutputTensorsNames, StringList);
void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
void PushBackOuputTensorBundle(std::string name, SizeType expressionField);
/** Input placeholders names */
itkSetMacro(InputPlaceholders, StringList);
itkGetMacro(InputPlaceholders, StringList);
/** Receptive field */
itkSetMacro(InputReceptiveFields, SizeListType);
itkGetMacro(InputReceptiveFields, SizeListType);
/** Output tensors names */
itkSetMacro(OutputTensors, StringList);
itkGetMacro(OutputTensors, StringList);
/** Expression field */
itkSetMacro(OutputExpressionFields, SizeListType);
itkGetMacro(OutputExpressionFields, SizeListType);
/** User placeholders */
void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; }
DictType GetUserPlaceholders() { return m_UserPlaceholders; }
/** Target nodes names */
itkSetMacro(TargetNodesNames, StringList);
itkGetMacro(TargetNodesNames, StringList);
......@@ -108,24 +147,27 @@ protected:
TensorflowMultisourceModelBase();
virtual ~TensorflowMultisourceModelBase() {};
virtual void RunSession(DictListType & inputs, TensorListType & outputs);
virtual std::stringstream GenerateDebugReport(DictType & inputs);
virtual void RunSession(DictType & inputs, TensorListType & outputs);
private:
TensorflowMultisourceModelBase(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
// Tensorflow graph and session
tensorflow::GraphDef m_Graph; // The tensorflow graph
tensorflow::Session * m_Session; // The tensorflow session
tensorflow::GraphDef m_Graph; // The TensorFlow graph
tensorflow::Session * m_Session; // The TensorFlow session
// Model parameters
StringList m_InputPlaceholdersNames; // Input placeholders names
SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes
DictListType m_UserPlaceholders; // User placeholders
StringList m_OutputTensorsNames; // User tensors
StringList m_TargetNodesNames; // User target tensors
// Read-only
StringList m_InputPlaceholders; // Input placeholders names
SizeListType m_InputReceptiveFields; // Input receptive fields
StringList m_OutputTensors; // Output tensors names
SizeListType m_OutputExpressionFields; // Output expression fields
DictType m_UserPlaceholders; // User placeholders
StringList m_TargetNodesNames; // User nodes target
// Internal, read-only
DataTypeListType m_InputTensorsDataTypes; // Input tensors datatype
DataTypeListType m_OutputTensorsDataTypes; // Output tensors datatype
TensorShapeProtoList m_InputTensorsShapes; // Input tensors shapes
......
......@@ -20,22 +20,57 @@ template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::TensorflowMultisourceModelBase()
{
m_Session = nullptr;
}
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputBundle(std::string placeholder, SizeType fieldOfView, ImagePointerType image)
::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{
Superclass::PushBackInput(image);
m_InputFOVSizes.push_back(fieldOfView);
m_InputPlaceholdersNames.push_back(placeholder);
m_InputReceptiveFields.push_back(receptiveField);
m_InputPlaceholders.push_back(placeholder);
}
template <class TInputImage, class TOutputImage>
std::stringstream
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateDebugReport(DictType & inputs)
{
// Create a debug report
std::stringstream debugReport;
// Describe the output requested region
ImagePointerType outputPtr = this->GetOutput();
const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
debugReport << "Output image requested region: " << outputReqRegion << "\n";
// Describe inputs
for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
{
const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));
const RegionType reqRegion = inputPtr->GetRequestedRegion();
debugReport << "Input #" << i << ":\n";
debugReport << "Requested region: " << reqRegion << "\n";
debugReport << "Tensor shape (\"" << inputs[i].first << "\": " << tf::PrintTensorShape(inputs[i].second.shape()) << "\n";
}
// Show user placeholders
debugReport << "User placeholders:\n" ;
for (auto& dict: this->GetUserPlaceholders())
{
debugReport << dict.first << " " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl;
}
return debugReport;
}
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::RunSession(DictListType & inputs, TensorListType & outputs)
::RunSession(DictType & inputs, TensorListType & outputs)
{
// Add the user's placeholders
......@@ -48,33 +83,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
// The session will initialize the outputs
// Run the session, evaluating our output tensors from the graph
auto status = this->GetSession()->Run(inputs, m_OutputTensorsNames, m_TargetNodesNames, &outputs);
auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs);
if (!status.ok()) {
// Create a debug report
std::stringstream debugReport;
// Describe the output buffered region
ImagePointerType outputPtr = this->GetOutput();
const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
debugReport << "Output image buffered region: " << outputReqRegion << "\n";
// Describe inputs
for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
{
const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));
const RegionType reqRegion = inputPtr->GetRequestedRegion();
debugReport << "Input #" << i << ":\n";
debugReport << "Requested region: " << reqRegion << "\n";
debugReport << "Tensor shape (\"" << inputs[i].first << "\": " << tf::PrintTensorShape(inputs[i].second.shape()) << "\n";
}
// Show user placeholders
debugReport << "User placeholders:\n" ;
for (auto& dict: this->GetUserPlaceholders())
{
debugReport << dict.first << " " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl;
}
std::stringstream debugReport = GenerateDebugReport(inputs);
// Throw an exception with the report
itkExceptionMacro("Can't run the tensorflow session !\n" <<
......@@ -92,15 +105,24 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
{
// Check that the number of the following is the same
// - placeholders names
// - patches sizes
// - input image
// - input placeholders names
// - input receptive fields
// - input images
const unsigned int nbInputs = this->GetNumberOfInputs();
if (nbInputs != m_InputFOVSizes.size() || nbInputs != m_InputPlaceholdersNames.size())
if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size())
{
itkExceptionMacro("Number of input images is " << nbInputs <<
" but the number of input patches size is " << m_InputFOVSizes.size() <<
" and the number of input tensors names is " << m_InputPlaceholdersNames.size());
" but the number of input patches size is " << m_InputReceptiveFields.size() <<
" and the number of input tensors names is " << m_InputPlaceholders.size());
}
// Check that the number of the following is the same
// - output tensors names
// - output expression fields
if (m_OutputExpressionFields.size() != m_OutputTensors.size())
{
itkExceptionMacro("Number of output tensors names is " << m_OutputTensors.size() <<
" but the number of output fields of expression is " << m_OutputExpressionFields.size());
}
//////////////////////////////////////////////////////////////////////////////////////////
......@@ -108,8 +130,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
//////////////////////////////////////////////////////////////////////////////////////////
// Get input and output tensors datatypes and shapes
tf::GetTensorAttributes(m_Graph, m_InputPlaceholdersNames, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_OutputTensorsNames, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
}
......
......@@ -26,31 +26,53 @@ namespace otb
/**
* \class TensorflowMultisourceModelFilter
* \brief This filter apply a TensorFlow model over multiple input images.
* \brief This filter applies a TensorFlow model over multiple input images and
* generates one output image corresponding to the outputs of the model.
*
* The filter takes N input images and feed the TensorFlow model to produce
* one output image of desired TF op results.
* Names of input/output placeholders/tensors must be specified using the
* SetInputPlaceholdersNames/SetOutputTensorNames.
* one output image corresponding to the desired results of the TensorFlow model.
* Names of input placeholders and output tensors must be specified using the
* SetInputPlaceholders() and SetOutputTensors() methods.
*
* Example: we have a tensorflow model which runs the input images "x1" and "x2"
* Example: we have a TensorFlow model which runs the input images "x1" and "x2"
* and produces the output image "y1".
* "x1" and "x2" are two TF placeholders, we set InputTensorNames={"x1","x2"}
* "y1" corresponds to one TF op output, we set OutputTensorNames={"y1"}
* "x1" and "x2" are two placeholders, we set InputPlaceholder={"x1","x2"}
* "y1" corresponds to one output tensor, we set OutputTensors={"y1"}
*
* The filter can work in two modes:
*
* 1. Patch-based mode:
* Extract and process patches independently at regular intervals.
* Patch sizes are equal to the receptive field sizes of the inputs. For each input,
* a tensor with a number of elements equal to the number of patches is fed to the
* TensorFlow model.
*
* 2. Fully-convolutional mode:
* Unlike patch-based mode, it allows the processing of an entire requested region.
* For each input, a tensor composed of one single element, corresponding to the input
* requested region, is fed to the TF model. This mode requires that receptive fields,
* expression fields and scale factors are consistent with the operators implemented in
* the TensorFlow model, and with the input images' physical spacing and alignment.
* In this mode, the filter produces output blocks free of any blocking artifact. This
* is done by computing input image regions that are aligned to the expression field
* sizes of the model (input requested regions may be enlarged, but remain aligned),
* and by keeping only the subset of the output corresponding to the requested
* output region.
*
* The reference grid for the output image is the same as the first input image.
* This grid can be scaled by setting the OutputSpacingScale value.
* This can be used to run models which downsize the output image spacing
* (typically fully convolutional model with strides) or to produce the result
* (e.g. fully convolutional model with strides) or to produce the result
* of a patch-based network at regular intervals.
*
* For each input, input field of view (FOV) must be set.
* For each input (resp. output), the receptive field (resp. expression field) must be set.
* If the number of values in the output tensors (produced by the model) doesn't
* fit with the output image region, exception will be thrown.
* fit with the output image region, an exception will be thrown.
*
*
* TODO: the filter must be able to output multiple images, possibly at different
* resolutions/sizes/origins.
*
* The TensorFlow graph is passed using the SetGraph() method, and the
* TensorFlow session using the SetSession() method.
*
* \ingroup OTBTensorflow
*/
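Continuing the configuration sketch from the base class documentation above, a short example of how the two modes could be selected. SetFullyConvolutional() is an assumed setter name (only the m_FullyConvolutional member appears in this diff), while OutputSpacingScale, ForceOutputGridSize and OutputGridSize are used or declared below:

// Patch-based mode: one prediction at regular intervals, here every 16 pixels
// of the reference grid (scaled output spacing).
filter->SetOutputSpacingScale(16.0);

// Fully-convolutional mode (setter name assumed): the whole requested region
// is fed to the model as a single tensor element.
filter->SetFullyConvolutional(true);
filter->SetOutputSpacingScale(1.0);

// Optionally force the output grid size; the default grid is the first
// output expression field.
FilterType::SizeType grid;
grid.Fill(64);
filter->SetForceOutputGridSize(true);
filter->SetOutputGridSize(grid);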
......@@ -94,15 +116,13 @@ public:
typedef typename itk::ImageRegionConstIterator<TInputImage> InputConstIteratorType;
/* Typedefs for parameters */
typedef typename Superclass::DictElementType DictElementType;
typedef typename Superclass::DictType DictType;
typedef typename Superclass::StringList StringList;
typedef typename Superclass::SizeListType SizeListType;
typedef typename Superclass::DictListType DictListType;
typedef typename Superclass::TensorListType TensorListType;
typedef std::vector<float> ScaleListType;
itkSetMacro(OutputFOESize, SizeType);
itkGetMacro(OutputFOESize, SizeType);
itkSetMacro(OutputGridSize, SizeType);
itkGetMacro(OutputGridSize, SizeType);
itkSetMacro(ForceOutputGridSize, bool);
......@@ -132,7 +152,6 @@ private:
TensorflowMultisourceModelFilter(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
SizeType m_OutputFOESize; // Output tensors field of expression (FOE) sizes
SizeType m_OutputGridSize; // Output grid size
bool m_ForceOutputGridSize; // Force output grid size
bool m_FullyConvolutional; // Convolution mode
......@@ -142,6 +161,7 @@ private:
SpacingType m_OutputSpacing; // Output image spacing
PointType m_OutputOrigin; // Output image origin
SizeType m_OutputSize; // Output image size
PixelType m_NullPixel; // Pixel filled with zeros
}; // end class
......
......@@ -216,7 +216,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Update output image extent
PointType currentInputImageExtentInf, currentInputImageExtentSup;
ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputFOVSizes()[imageIndex]);
ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]);
for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
{
extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]);
......@@ -236,7 +236,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
if (!m_ForceOutputGridSize)
{
// Default is the output field of expression
m_OutputGridSize = m_OutputFOESize;
m_OutputGridSize = this->GetOutputExpressionFields().at(0);
}
// Resize the largestPossibleRegion to be a multiple of the grid size
......@@ -279,10 +279,14 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Set output image origin/spacing/size/projection
ImageType * outputPtr = this->GetOutput();
outputPtr->SetNumberOfComponentsPerPixel(outputPixelSize);
outputPtr->SetProjectionRef ( projectionRef );
outputPtr->SetOrigin ( m_OutputOrigin );
outputPtr->SetSignedSpacing ( m_OutputSpacing );
outputPtr->SetLargestPossibleRegion( largestPossibleRegion);
outputPtr->SetProjectionRef ( projectionRef );
outputPtr->SetOrigin ( m_OutputOrigin );
outputPtr->SetSignedSpacing ( m_OutputSpacing );
outputPtr->SetLargestPossibleRegion( largestPossibleRegion );
// Set null pixel
m_NullPixel.SetSize(outputPtr->GetNumberOfComponentsPerPixel());
m_NullPixel.Fill(0);
}
......@@ -315,9 +319,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Compute the radius to pad: receptive field minus (1 + scale * (expression field - 1))
SizeType toPad(this->GetInputFOVSizes().at(i));
toPad[0] -= 1 + (m_OutputFOESize[0] - 1) * m_OutputSpacingScale;
toPad[1] -= 1 + (m_OutputFOESize[1] - 1) * m_OutputSpacingScale;
SizeType toPad(this->GetInputReceptiveFields().at(i));
toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale;
toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale;
// Pad with radius
SmartPad(inRegion, toPad);
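A worked example of the padding radius computed above, with hypothetical values:

// Hypothetical values: receptive field = 32, expression field = 16,
// OutputSpacingScale = 1:
//   toPad[d] = 32 - (1 + (16 - 1) * 1) = 16
// so each input requested region is padded by a 16-pixel radius before
// being aligned and submitted to the model.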
......@@ -325,7 +329,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// We need to avoid some extrapolation when mode is patch-based.
// The reason is that, when some input have a lower spacing than the
// reference image, the requested region of this lower res input image
// can be one pixel larger when the input image regions are not physicaly
// can be one pixel larger when the input image regions are not physically
// aligned.
if (!m_FullyConvolutional)
{
......@@ -337,8 +341,6 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Update the requested region
inputImage->SetRequestedRegion(inRegion);
// std::cout << "Input #" << i << " region starts at " << inRegion.GetIndex() << " with size " << inRegion.GetSize() << std::endl;
} // next image
}
......@@ -365,7 +367,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
const unsigned int nInputs = this->GetNumberOfInputs();
// Create input tensors list
DictListType inputs;
DictType inputs;
// Populate input tensors
for (unsigned int i = 0 ; i < nInputs ; i++)
......@@ -374,7 +376,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));
// Patch size of tensor #i
const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
// Input image requested region
const RegionType reqRegion = inputPtr->GetRequestedRegion();
......@@ -394,14 +396,13 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Recopy the whole input
tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0);
// Input #1 : the tensor of patches (aka the batch)
DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
// Input is the tensor representing the subset of the image
DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
inputs.push_back(input);
}
else
{
// Preparing patches (not very optimized ! )
// It would be better to perform the loop inside the TF session using TF operators
// Preparing patches
// Shape of input tensor #i
tensorflow::int64 sz_n = outputReqRegion.GetNumberOfPixels();
tensorflow::int64 sz_y = inputPatchSize[1];
......@@ -428,9 +429,10 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
elemIndex++;
}
// Input #1 : the tensor of patches (aka the batch)
DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
// Input is the tensor of patches (aka the batch)
DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
inputs.push_back(input);
} // mode is not fully convolutional
} // next input tensor
......@@ -442,10 +444,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Fill the output buffer with zero value
outputPtr->SetBufferedRegion(outputReqRegion);
outputPtr->Allocate();
OutputPixelType nullpix;
nullpix.SetSize(outputPtr->GetNumberOfComponentsPerPixel());
nullpix.Fill(0);
outputPtr->FillBuffer(nullpix);
outputPtr->FillBuffer(m_NullPixel);
// Get output tensors
int bandOffset = 0;
......@@ -453,8 +452,19 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
{
// The offset (i.e. the starting index of the channel for the output tensor) is updated
// during this call
// TODO: implement a generic strategy enabling FOE copy in patch-based mode (see tf::CopyTensorToImageRegion)
tf::CopyTensorToImageRegion<TOutputImage> (outputs[i], outputAlignedReqRegion, outputPtr, outputReqRegion, bandOffset);
// TODO: implement a generic strategy enabling expression field copy in patch-based mode (see tf::CopyTensorToImageRegion)
try
{
tf::CopyTensorToImageRegion<TOutputImage> (outputs[i],
outputAlignedReqRegion, outputPtr, outputReqRegion, bandOffset);
}
catch( itk::ExceptionObject & err )
{
std::stringstream debugMsg = this->GenerateDebugReport(inputs);
itkExceptionMacro("Error occured during tensor to image conversion.\n"
<< "Context: " << debugMsg.str()
<< "Error:" << err);
}
}
}
......
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_h
#define otbTensorflowMultisourceModelLearningBase_h