Commit e42a2d60 authored by remi cresson's avatar remi cresson
Browse files

REFAC: wip

parent 20bb0192
......@@ -72,10 +72,10 @@ public:
typedef typename TInputImage::RegionType RegionType;
/** Typedefs for parameters */
typedef std::pair<std::string, tensorflow::Tensor> DictType;
typedef std::pair<std::string, tensorflow::Tensor> DictElementType;
typedef std::vector<std::string> StringList;
typedef std::vector<SizeType> SizeListType;
typedef std::vector<DictType> DictListType;
typedef std::vector<DictElementType> DictType;
typedef std::vector<tensorflow::DataType> DataTypeListType;
typedef std::vector<tensorflow::TensorShapeProto> TensorShapeProtoList;
typedef std::vector<tensorflow::Tensor> TensorListType;
......@@ -87,27 +87,28 @@ public:
tensorflow::Session * GetSession() { return m_Session; }
/** Model parameters */
void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image);
void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
void PushBackOuputTensorBundle(std::string name, SizeType expressionField);
// /** Input placeholders names */
// itkSetMacro(InputPlaceholdersNames, StringList);
itkGetMacro(InputPlaceholdersNames, StringList);
//
// /** Receptive field */
// itkSetMacro(InputFOVSizes, SizeListType);
itkGetMacro(InputFOVSizes, SizeListType);
/** Input placeholders names */
itkSetMacro(InputPlaceholders, StringList);
itkGetMacro(InputPlaceholders, StringList);
/** Receptive field */
itkSetMacro(InputReceptiveFields, SizeListType);
itkGetMacro(InputReceptiveFields, SizeListType);
/** Output tensors names */
itkSetMacro(OutputTensorsNames, StringList);
itkGetMacro(OutputTensorsNames, StringList);
itkSetMacro(OutputTensors, StringList);
itkGetMacro(OutputTensors, StringList);
/** Expression field */
itkSetMacro(OutputFOESizes, SizeListType);
itkGetMacro(OutputFOESizes, SizeListType);
itkSetMacro(OutputExpressionFields, SizeListType);
itkGetMacro(OutputExpressionFields, SizeListType);
/** User placeholders */
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; }
void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; }
DictType GetUserPlaceholders() { return m_UserPlaceholders; }
/** Target nodes names */
itkSetMacro(TargetNodesNames, StringList);
......@@ -125,9 +126,9 @@ protected:
TensorflowMultisourceModelBase();
virtual ~TensorflowMultisourceModelBase() {};
virtual std::stringstream GenerateDebugReport(DictListType & inputs, TensorListType & outputs);
virtual std::stringstream GenerateDebugReport(DictType & inputs);
virtual void RunSession(DictListType & inputs, TensorListType & outputs);
virtual void RunSession(DictType & inputs, TensorListType & outputs);
private:
TensorflowMultisourceModelBase(const Self&); //purposely not implemented
......@@ -138,11 +139,11 @@ private:
tensorflow::Session * m_Session; // The tensorflow session
// Model parameters
StringList m_InputPlaceholdersNames; // Input placeholders names
SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes
SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes
DictListType m_UserPlaceholders; // User placeholders
StringList m_OutputTensorsNames; // User tensors
StringList m_InputPlaceholders; // Input placeholders names
SizeListType m_InputReceptiveFields; // Input receptive fields
StringList m_OutputTensors; // Output tensors names
SizeListType m_OutputExpressionFields; // Output expression fields
DictType m_UserPlaceholders; // User placeholders
StringList m_TargetNodesNames; // User target tensors
// Read-only
......
......@@ -20,22 +20,23 @@ template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::TensorflowMultisourceModelBase()
{
m_Session = nullptr;
}
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{
Superclass::PushBackInput(image);
m_InputFOVSizes.push_back(receptiveField);
m_InputPlaceholdersNames.push_back(placeholder);
m_InputReceptiveFields.push_back(receptiveField);
m_InputPlaceholders.push_back(placeholder);
}
template <class TInputImage, class TOutputImage>
std::stringstream
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateDebugReport(DictListType & inputs, TensorListType & outputs)
::GenerateDebugReport(DictType & inputs)
{
// Create a debug report
std::stringstream debugReport;
......@@ -69,7 +70,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::RunSession(DictListType & inputs, TensorListType & outputs)
::RunSession(DictType & inputs, TensorListType & outputs)
{
// Add the user's placeholders
......@@ -82,11 +83,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
// The session will initialize the outputs
// Run the session, evaluating our output tensors from the graph
auto status = this->GetSession()->Run(inputs, m_OutputTensorsNames, m_TargetNodesNames, &outputs);
auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs);
if (!status.ok()) {
// Create a debug report
std::stringstream debugReport = GenerateDebugReport(inputs, outputs);
std::stringstream debugReport = GenerateDebugReport(inputs);
// Throw an exception with the report
itkExceptionMacro("Can't run the tensorflow session !\n" <<
......@@ -108,11 +109,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
// - patches sizes
// - input image
const unsigned int nbInputs = this->GetNumberOfInputs();
if (nbInputs != m_InputFOVSizes.size() || nbInputs != m_InputPlaceholdersNames.size())
if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size())
{
itkExceptionMacro("Number of input images is " << nbInputs <<
" but the number of input patches size is " << m_InputFOVSizes.size() <<
" and the number of input tensors names is " << m_InputPlaceholdersNames.size());
" but the number of input patches size is " << m_InputReceptiveFields.size() <<
" and the number of input tensors names is " << m_InputPlaceholders.size());
}
//////////////////////////////////////////////////////////////////////////////////////////
......@@ -120,8 +121,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
//////////////////////////////////////////////////////////////////////////////////////////
// Get input and output tensors datatypes and shapes
tf::GetTensorAttributes(m_Graph, m_InputPlaceholdersNames, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_OutputTensorsNames, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
}
......
......@@ -30,27 +30,26 @@ namespace otb
*
* The filter takes N input images and feed the TensorFlow model to produce
* one output image of desired TF op results.
* Names of input/output placeholders/tensors must be specified using the
* SetInputPlaceholdersNames/SetOutputTensorNames.
* Names of input placeholders and output tensors must be specified using the
* SetInputPlaceholders() and SetOutputTensors() methods.
*
* Example: we have a tensorflow model which runs the input images "x1" and "x2"
* Example: we have a TensorFlow model which runs the input images "x1" and "x2"
* and produces the output image "y".
* "x1" and "x2" are two TF placeholders, we set InputTensorNames={"x1","x2"}
* "y1" corresponds to one TF op output, we set OutputTensorNames={"y1"}
* "x1" and "x2" are two TF placeholders, we set InputPlaceholders={"x1","x2"}
* "y1" corresponds to one TF op output, we set OutputTensors={"y1"}
*
* The reference grid for the output image is the same as the first input image.
* This grid can be scaled by setting the OutputSpacingScale value.
* This can be used to run models which downsize the output image spacing
* (typically fully convolutional model with strides) or to produce the result
* (e.g. fully convolutional model with strides) or to produce the result
* of a patch-based network at regular intervals.
*
* For each input, input field of view (FOV) must be set.
* For each input (resp. output), receptive field (resp. expression field) must be set.
* If the number of values in the output tensors (produced by the model) doesn't
* fit with the output image region, exception will be thrown.
* fit with the output image region, an exception will be thrown.
*
*
* The tensorflow Graph is passed using the SetGraph() method
* The tensorflow Session is passed using the SetSession() method
* The TensorFlow Graph is passed using the SetGraph() method
* The TensorFlow Session is passed using the SetSession() method
*
* \ingroup OTBTensorflow
*/
......@@ -94,6 +93,7 @@ public:
typedef typename itk::ImageRegionConstIterator<TInputImage> InputConstIteratorType;
/* Typedefs for parameters */
typedef typename Superclass::DictElementType DictElementType;
typedef typename Superclass::DictType DictType;
typedef typename Superclass::StringList StringList;
typedef typename Superclass::SizeListType SizeListType;
......@@ -101,8 +101,6 @@ public:
typedef typename Superclass::TensorListType TensorListType;
typedef std::vector<float> ScaleListType;
itkSetMacro(OutputFOESize, SizeType);
itkGetMacro(OutputFOESize, SizeType);
itkSetMacro(OutputGridSize, SizeType);
itkGetMacro(OutputGridSize, SizeType);
itkSetMacro(ForceOutputGridSize, bool);
......@@ -132,7 +130,6 @@ private:
TensorflowMultisourceModelFilter(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
SizeType m_OutputFOESize; // Output tensors field of expression (FOE) sizes
SizeType m_OutputGridSize; // Output grid size
bool m_ForceOutputGridSize; // Force output grid size
bool m_FullyConvolutional; // Convolution mode
......
......@@ -216,7 +216,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Update output image extent
PointType currentInputImageExtentInf, currentInputImageExtentSup;
ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputFOVSizes()[imageIndex]);
ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]);
for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
{
extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]);
......@@ -236,7 +236,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
if (!m_ForceOutputGridSize)
{
// Default is the output field of expression
m_OutputGridSize = m_OutputFOESize;
m_OutputGridSize = this->GetOutputExpressionFields().at(0);
}
// Resize the largestPossibleRegion to be a multiple of the grid size
......@@ -315,9 +315,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Compute the FOV-scale*FOE radius to pad
SizeType toPad(this->GetInputFOVSizes().at(i));
toPad[0] -= 1 + (m_OutputFOESize[0] - 1) * m_OutputSpacingScale;
toPad[1] -= 1 + (m_OutputFOESize[1] - 1) * m_OutputSpacingScale;
SizeType toPad(this->GetInputReceptiveFields().at(i));
toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale;
toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale;
// Pad with radius
SmartPad(inRegion, toPad);
......@@ -365,7 +365,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
const unsigned int nInputs = this->GetNumberOfInputs();
// Create input tensors list
DictListType inputs;
DictType inputs;
// Populate input tensors
for (unsigned int i = 0 ; i < nInputs ; i++)
......@@ -374,7 +374,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));
// Patch size of tensor #i
const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
// Input image requested region
const RegionType reqRegion = inputPtr->GetRequestedRegion();
......@@ -395,7 +395,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0);
// Input #1 : the tensor of patches (aka the batch)
DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
}
else
......@@ -429,7 +429,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Input #1 : the tensor of patches (aka the batch)
DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
} // mode is not full convolutional
......
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_h
#define otbTensorflowMultisourceModelLearningBase_h
#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"
// Base
#include "otbTensorflowMultisourceModelBase.h"
// Shuffle
#include <random>
#include <algorithm>
#include <iterator>
namespace otb
{
/**
* \class TensorflowMultisourceModelLearningBase
* \brief This filter is the base class for learning filters.
*
* \ingroup OTBTensorflow
*/
template <class TInputImage>
class ITK_EXPORT TensorflowMultisourceModelLearningBase :
public TensorflowMultisourceModelBase<TInputImage>
{
public:

  /** Standard class typedefs. */
  typedef TensorflowMultisourceModelLearningBase       Self;
  typedef TensorflowMultisourceModelBase<TInputImage>  Superclass;
  typedef itk::SmartPointer<Self>                      Pointer;
  typedef itk::SmartPointer<const Self>                ConstPointer;

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase);

  /** Images typedefs */
  typedef typename Superclass::ImageType           ImageType;
  typedef typename Superclass::ImagePointerType    ImagePointerType;
  typedef typename Superclass::RegionType          RegionType;
  typedef typename Superclass::SizeType            SizeType;
  typedef typename Superclass::IndexType           IndexType;

  /* Typedefs for parameters */
  typedef typename Superclass::DictType            DictType;
  typedef typename Superclass::DictElementType     DictElementType;
  typedef typename Superclass::StringList          StringList;
  typedef typename Superclass::SizeListType        SizeListType;
  typedef typename Superclass::TensorListType      TensorListType;

  /* Typedefs for index */
  typedef typename ImageType::IndexValueType       IndexValueType;
  typedef std::vector<IndexValueType>              IndexListType;

  /** Batch size: number of samples fed to the model in one session run */
  itkSetMacro(BatchSize, IndexValueType);
  itkGetMacro(BatchSize, IndexValueType);

  /** Use streaming: when enabled, input regions are propagated patch by
   *  patch instead of reading each input image entirely */
  itkSetMacro(UseStreaming, bool);
  itkGetMacro(UseStreaming, bool);

  /** Number of samples (read-only, computed in GenerateOutputInformation()
   *  from the input patches images) */
  itkGetMacro(NumberOfSamples, IndexValueType);

protected:
  TensorflowMultisourceModelLearningBase();
  virtual ~TensorflowMultisourceModelLearningBase() {};

  virtual void GenerateOutputInformation(void);

  virtual void GenerateInputRequestedRegion();

  virtual void GenerateData();

  /** Fill the input tensors list from the patches images, for samples
   *  [sampleStart, sampleStart + batchSize), optionally shuffled by "order" */
  virtual void PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize, const IndexListType & order = IndexListType());

  /** Process one batch of samples (pure virtual, implemented in subclasses) */
  virtual void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize) = 0;

private:
  TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  // Was "unsigned int": now IndexValueType so the member matches the
  // itkSetMacro/itkGetMacro(BatchSize, IndexValueType) accessors and avoids
  // silent signed/unsigned narrowing conversions.
  IndexValueType             m_BatchSize;        // Batch size
  bool                       m_UseStreaming;     // Use streaming on/off

  // Read only
  IndexValueType             m_NumberOfSamples;  // Number of samples

}; // end class
} // end namespace otb
#include "otbTensorflowMultisourceModelLearningBase.hxx"
#endif
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_txx
#define otbTensorflowMultisourceModelLearningBase_txx
#include "otbTensorflowMultisourceModelLearningBase.h"
namespace otb
{
template <class TInputImage>
TensorflowMultisourceModelLearningBase<TInputImage>
// Constructor: sensible defaults (batch of 100 samples, no streaming).
// Initializer list reordered to match the members' declaration order
// (m_BatchSize, m_UseStreaming, m_NumberOfSamples): members are always
// initialized in declaration order, and the previous list order triggered
// a -Wreorder warning and misled readers about the actual init sequence.
::TensorflowMultisourceModelLearningBase() : m_BatchSize(100),
  m_UseStreaming(false), m_NumberOfSamples(0)
{
}
template <class TInputImage>
void
TensorflowMultisourceModelLearningBase<TInputImage>
// Set up a dummy 1x1 single-band output (learning filters produce no real
// image), then count the number of samples from the input patches images and
// check that all inputs agree.
// Each input is a "patches image": width == receptive field X, height == a
// multiple of receptive field Y, one patch stacked per row group.
// Throws if an input is null, if a patches image size is inconsistent with
// its receptive field, or if inputs disagree on the number of samples.
::GenerateOutputInformation()
{
Superclass::GenerateOutputInformation();
// The output image is a placeholder: 1 component, 1x1 region
ImageType * outputPtr = this->GetOutput();
RegionType nullRegion;
nullRegion.GetModifiableSize().Fill(1);
outputPtr->SetNumberOfComponentsPerPixel(1);
outputPtr->SetLargestPossibleRegion( nullRegion );
// Count the number of samples
m_NumberOfSamples = 0;
for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
{
// Input image pointer
ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
// Make sure input is available
if ( inputPtr.IsNull() )
{
itkExceptionMacro(<< "Input " << i << " is null!");
}
// Update input information
inputPtr->UpdateOutputInformation();
// Patch size of tensor #i
const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
// Input image requested region
const RegionType reqRegion = inputPtr->GetLargestPossibleRegion();
// Check size X: the patches image must be exactly one patch wide
if (inputPatchSize[0] != reqRegion.GetSize(0))
itkExceptionMacro("Patch size for input " << i
<< " is " << inputPatchSize
<< " but input patches image size is " << reqRegion.GetSize());
// Check size Y: patches are stacked vertically, so the height must be a
// whole number of patches
if (reqRegion.GetSize(1) % inputPatchSize[1] != 0)
itkExceptionMacro("Input patches image must have a number of rows which is "
<< "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1)
<< " rows but patch size Y is " << inputPatchSize[1] << " for input " << i);
// Number of samples stacked in this input (one patch per row group)
const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1];
// Check the consistency with other inputs: all patches images must carry
// the same number of samples
if (m_NumberOfSamples == 0)
{
m_NumberOfSamples = currNumberOfSamples;
}
else if (m_NumberOfSamples != currNumberOfSamples)
{
itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples
<< " but input " << i
<< " has a batch size of " << currNumberOfSamples );
}
} // next input
}
template <class TInputImage>
void
TensorflowMultisourceModelLearningBase<TInputImage>
// Decide how much of each input image the pipeline must read.
// With streaming enabled, request an empty region now: the learning loop
// propagates per-patch regions on demand. Otherwise, request each input's
// whole largest possible region up front.
::GenerateInputRequestedRegion()
{
  Superclass::GenerateInputRequestedRegion();

  const unsigned int nbInputs = this->GetNumberOfInputs();
  for (unsigned int idx = 0 ; idx < nbInputs ; ++idx)
  {
    ImageType * img = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(idx) );
    if (m_UseStreaming)
    {
      // Streaming: don't read anything yet
      RegionType emptyRegion;
      img->SetRequestedRegion(emptyRegion);
    }
    else
    {
      // No streaming: read the full image
      img->SetRequestedRegion(img->GetLargestPossibleRegion());
    }
  } // next image
}
/**
*
*/
template <class TInputImage>
void
TensorflowMultisourceModelLearningBase<TInputImage>
// Iterate over all samples in batches of m_BatchSize, delegating each batch
// to ProcessBatch(). Two fixes vs. the previous version:
//  1. nBatches was vcl_ceil(m_NumberOfSamples / m_BatchSize): the division is
//     integer division, so the ceil was a no-op and any trailing partial
//     batch was silently dropped. Use integer ceiling division instead.
//  2. batchSize was set to "rest" for EVERY batch whenever
//     m_NumberOfSamples % m_BatchSize != 0, so full batches were truncated
//     and most samples were never processed. Only the LAST batch may be
//     shorter than m_BatchSize.
::GenerateData()
{
  // Integer ceiling division: number of batches including a trailing partial one
  const IndexValueType nBatches = (m_NumberOfSamples + m_BatchSize - 1) / m_BatchSize;
  // Size of the trailing partial batch (0 when samples divide evenly)
  const IndexValueType rest = m_NumberOfSamples % m_BatchSize;

  itk::ProgressReporter progress(this, 0, nBatches);

  for (IndexValueType batch = 0 ; batch < nBatches ; batch++)
  {

    // Create input tensors list
    TensorListType inputs;

    // Batch start and size
    const IndexValueType sampleStart = batch * m_BatchSize;
    IndexValueType batchSize = m_BatchSize;
    // Only the last batch is shortened, and only when the sample count is not
    // a multiple of the batch size
    if (rest != 0 && batch == nBatches - 1)
    {
      batchSize = rest;
    }

    // Process the batch
    ProcessBatch(inputs, sampleStart, batchSize);
    progress.CompletedPixel();
  } // Next batch
}
template <class TInputImage>
void
TensorflowMultisourceModelLearningBase<TInputImage>
::PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize, const IndexListType & order)
{
const bool reorder = order.size();
// Populate input tensors
for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
{
// Input image pointer
ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
// Patch size of tensor #i
const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
// Create the tensor for the batch
const tensorflow::int64 sz_n = batchSize;
const tensorflow::int64 sz_y = inputPatchSize[1];
const tensorflow::int64 sz_x = inputPatchSize[0];
const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});
tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);
// Populate the tensor
for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++)
{
const tensorflow::uint64 samplePos = sampleStart + elem;
IndexType start;
start[0] = 0;
if (reorder)
{
start[1] = order[samplePos] * sz_y;
}
else
{
start[1] = samplePos * sz_y;;
}
RegionType patchRegion(start, inputPatchSize);
if (m_UseStreaming)
{
// If streaming is enabled, we need to explicitly propagate requested region
tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion);
}
tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem );
}