diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx
index cfed5516f534f8240ed86e0bbf3a309ce39ab3e7..cc5e073494d10931bc04a039619f4411adc968bd 100644
--- a/app/otbTensorflowModelServe.cxx
+++ b/app/otbTensorflowModelServe.cxx
@@ -30,8 +30,7 @@
#include "otbTensorflowSource.h"
// Streaming
-#include "otbImageRegionSquareTileSplitter.h"
-#include "itkStreamingImageFilter.h"
+#include "otbTensorflowStreamerFilter.h"
namespace otb
{
@@ -58,7 +57,7 @@ public:
/** Typedef for streaming */
typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
- typedef itk::StreamingImageFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
+ typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
/** Typedefs for images */
typedef FloatVectorImageType::SizeType SizeType;
@@ -198,9 +197,12 @@ public:
AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
MandatoryOff ("optim.disabletiling");
SetParameterDescription ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it");
- AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
- SetMinimumParameterIntValue ("optim.tilesize", 1);
- SetDefaultParameterInt ("optim.tilesize", 16);
+ AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output");
+ SetMinimumParameterIntValue ("optim.tilesizex", 1);
+ SetDefaultParameterInt ("optim.tilesizex", 16);
+ AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output");
+ SetMinimumParameterIntValue ("optim.tilesizey", 1);
+ SetDefaultParameterInt ("optim.tilesizey", 16);
// Output image
AddParameter(ParameterType_OutputImage, "out", "output image");
@@ -292,22 +294,28 @@ public:
if (GetParameterInt("optim.disabletiling") != 1)
{
// Get the tile size
- const unsigned int tileSize = GetParameterInt("optim.tilesize");
- otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
+ SizeType tileSize;
+ tileSize[0] = GetParameterInt("optim.tilesizex");
+ tileSize[1] = GetParameterInt("optim.tilesizey");
+
+ // Check that the tile size is aligned to the field of expression
+ for (unsigned int i = 0 ; i < FloatVectorImageType::ImageDimension ; i++)
+ if (tileSize[i] % foe[i] != 0)
+ {
+ SizeType::SizeValueType newSize = 1 + std::floor(tileSize[i] / foe[i]);
+ newSize *= foe[i];
- // Update the TensorFlow filter output information to get the output image size
- m_TFFilter->UpdateOutputInformation();
+ otbAppLogWARNING("Aligning the tiling to the output expression field "
+ << "for better performances (dim " << i << "). New value set to " << newSize)
- // Splitting using square tiles
- TileSplitterType::Pointer splitter = TileSplitterType::New();
- splitter->SetTileSizeAlignment(tileSize);
- unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>(
- double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels() ) / (tileSize * tileSize) );
+ tileSize[i] = newSize;
+ }
+
+ otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
- // Use an itk::StreamingImageFilter to force the computation tile by tile
+ // Force the computation tile by tile
m_StreamFilter = StreamingFilterType::New();
- m_StreamFilter->SetRegionSplitter(splitter);
- m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles);
+ m_StreamFilter->SetOutputGridSize(tileSize);
m_StreamFilter->SetInput(m_TFFilter->GetOutput());
SetParameterOutputImage("out", m_StreamFilter->GetOutput());
diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx
index 9aaeeda91856eb83e2532cf6eefc3fbc8a18843b..77902bbdf6c045fe8fa2a9f2136919dd781ff54c 100644
--- a/include/otbTensorflowMultisourceModelFilter.hxx
+++ b/include/otbTensorflowMultisourceModelFilter.hxx
@@ -368,9 +368,6 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
RegionType outputAlignedReqRegion(outputReqRegion);
EnlargeToAlignedRegion(outputAlignedReqRegion);
- // Add a progress reporter
- itk::ProgressReporter progress(this, 0, outputReqRegion.GetNumberOfPixels());
-
const unsigned int nInputs = this->GetNumberOfInputs();
// Create input tensors list
diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h
new file mode 100644
index 0000000000000000000000000000000000000000..b58e92923b5b5b4ad78d5cc3d48c487c97594bfe
--- /dev/null
+++ b/include/otbTensorflowStreamerFilter.h
@@ -0,0 +1,79 @@
+/*=========================================================================
+
+ Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE. See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowStreamerFilter_h
+#define otbTensorflowStreamerFilter_h
+
+// Image2image
+#include "itkImageToImageFilter.h"
+
+namespace otb
+{
+
+/**
+ * \class TensorflowStreamerFilter
+ * \brief This filter generates an output image with an internal
+ * explicit streaming mechanism.
+ *
+ * \ingroup OTBTensorflow
+ */
+template <class TInputImage, class TOutputImage>
+class ITK_EXPORT TensorflowStreamerFilter :
+public itk::ImageToImageFilter<TInputImage, TOutputImage>
+{
+
+public:
+
+ /** Standard class typedefs. */
+ typedef TensorflowStreamerFilter Self;
+ typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
+ typedef itk::SmartPointer<Self> Pointer;
+ typedef itk::SmartPointer<const Self> ConstPointer;
+
+ /** Method for creation through the object factory. */
+ itkNewMacro(Self);
+
+ /** Run-time type information (and related methods). */
+ itkTypeMacro(TensorflowStreamerFilter, itk::ImageToImageFilter);
+
+ /** Images typedefs */
+ typedef typename Superclass::InputImageType ImageType;
+ typedef typename ImageType::IndexType IndexType;
+ typedef typename ImageType::IndexValueType IndexValueType;
+ typedef typename ImageType::SizeType SizeType;
+ typedef typename Superclass::InputImageRegionType RegionType;
+
+ typedef TOutputImage OutputImageType;
+
+ itkSetMacro(OutputGridSize, SizeType);
+ itkGetMacro(OutputGridSize, SizeType);
+
+protected:
+ TensorflowStreamerFilter();
+ virtual ~TensorflowStreamerFilter() {};
+
+ virtual void UpdateOutputData(itk::DataObject *output){(void) output; this->GenerateData();}
+
+ virtual void GenerateData();
+
+private:
+ TensorflowStreamerFilter(const Self&); //purposely not implemented
+ void operator=(const Self&); //purposely not implemented
+
+ SizeType m_OutputGridSize; // Output grid size
+
+}; // end class
+
+
+} // end namespace otb
+
+#include "otbTensorflowStreamerFilter.hxx"
+
+#endif
diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa7231421f04baebe7a16cfbd68acb851157e7
--- /dev/null
+++ b/include/otbTensorflowStreamerFilter.hxx
@@ -0,0 +1,103 @@
+/*=========================================================================
+
+ Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE. See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowStreamerFilter_txx
+#define otbTensorflowStreamerFilter_txx
+
+#include "otbTensorflowStreamerFilter.h"
+#include "itkImageAlgorithm.h"
+
+namespace otb
+{
+
+template <class TInputImage, class TOutputImage>
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::TensorflowStreamerFilter()
+ {
+ m_OutputGridSize.Fill(1);
+ }
+
+/**
+ * Compute the output image
+ */
+template <class TInputImage, class TOutputImage>
+void
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::GenerateData()
+ {
+ // Output pointer and requested region
+ OutputImageType * outputPtr = this->GetOutput();
+ const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
+ outputPtr->SetBufferedRegion(outputReqRegion);
+ outputPtr->Allocate();
+
+ // Compute the aligned region
+ RegionType region;
+ for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
+ {
+ // Get corners
+ IndexValueType lower = outputReqRegion.GetIndex(dim);
+ IndexValueType upper = lower + outputReqRegion.GetSize(dim);
+
+ // Compute deltas between corners and the grid
+ const IndexValueType deltaLo = lower % m_OutputGridSize[dim];
+ const IndexValueType deltaUp = upper % m_OutputGridSize[dim];
+
+ // Move corners to aligned positions
+ lower -= deltaLo;
+ if (deltaUp > 0)
+ {
+ upper += m_OutputGridSize[dim] - deltaUp;
+ }
+
+ // Update region
+ region.SetIndex(dim, lower);
+ region.SetSize(dim, upper - lower);
+
+ }
+
+ // Compute the number of subregions to process
+ const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0];
+ const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1];
+
+ // Progress
+ itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY);
+
+ // For each tile, propagate the input region and recopy the output
+ ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) );
+ unsigned int tx, ty;
+ RegionType subRegion;
+ subRegion.SetSize(m_OutputGridSize);
+ for (ty = 0; ty < nbTilesY; ty++)
+ {
+ subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1));
+ for (tx = 0; tx < nbTilesX; tx++)
+ {
+ // Update the input subregion
+ subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0));
+ inputImage->SetRequestedRegion(subRegion);
+ inputImage->PropagateRequestedRegion();
+ inputImage->UpdateOutputData();
+
+ // Copy the subregion to output
+ RegionType cpyRegion(subRegion);
+ cpyRegion.Crop(outputReqRegion);
+ itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion );
+
+ progress.CompletedPixel();
+ }
+ }
+ }
+
+
+} // end namespace otb
+
+
+#endif