From 78ce0cff3b157047c1369c65c708a6b699f37c49 Mon Sep 17 00:00:00 2001
From: remi cresson <remi.cresson@teledetection.fr>
Date: Fri, 21 Sep 2018 20:01:40 +0200
Subject: [PATCH] ADD: streamer filter

---
 app/otbTensorflowModelServe.cxx         |  33 +++----
 include/otbTensorflowStreamerFilter.h   |  79 ++++++++++++++++
 include/otbTensorflowStreamerFilter.hxx | 116 ++++++++++++++++++++++++
 3 files changed, 209 insertions(+), 19 deletions(-)
 create mode 100644 include/otbTensorflowStreamerFilter.h
 create mode 100644 include/otbTensorflowStreamerFilter.hxx

diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx
index cfed551..bbe204f 100644
--- a/app/otbTensorflowModelServe.cxx
+++ b/app/otbTensorflowModelServe.cxx
@@ -30,8 +30,7 @@
 #include "otbTensorflowSource.h"
 
 // Streaming
-#include "otbImageRegionSquareTileSplitter.h"
-#include "itkStreamingImageFilter.h"
+#include "otbTensorflowStreamerFilter.h"
 
 namespace otb
 {
@@ -58,7 +57,7 @@ public:
 
   /** Typedef for streaming */
   typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
-  typedef itk::StreamingImageFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
+  typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
 
   /** Typedefs for images */
   typedef FloatVectorImageType::SizeType SizeType;
@@ -198,9 +197,12 @@ public:
     AddParameter(ParameterType_Bool,          "optim.disabletiling", "Disable tiling");
     MandatoryOff                             ("optim.disabletiling");
     SetParameterDescription                  ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it");
-    AddParameter(ParameterType_Int,           "optim.tilesize", "Tile width used to stream the filter output");
-    SetMinimumParameterIntValue              ("optim.tilesize", 1);
-    SetDefaultParameterInt                   ("optim.tilesize", 16);
+    AddParameter(ParameterType_Int,           "optim.tilesizex", "Tile width used to stream the filter output");
+    SetMinimumParameterIntValue              ("optim.tilesizex", 1);
+    SetDefaultParameterInt                   ("optim.tilesizex", 16);
+    AddParameter(ParameterType_Int,           "optim.tilesizey", "Tile height used to stream the filter output");
+    SetMinimumParameterIntValue              ("optim.tilesizey", 1);
+    SetDefaultParameterInt                   ("optim.tilesizey", 16);
 
     // Output image
     AddParameter(ParameterType_OutputImage, "out", "output image");
@@ -292,22 +294,15 @@ public:
     if (GetParameterInt("optim.disabletiling") != 1)
     {
       // Get the tile size
-      const unsigned int tileSize = GetParameterInt("optim.tilesize");
-      otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
+      SizeType gridSize;
+      gridSize[0] = GetParameterInt("optim.tilesizex");
+      gridSize[1] = GetParameterInt("optim.tilesizey");
 
-      // Update the TensorFlow filter output information to get the output image size
-      m_TFFilter->UpdateOutputInformation();
+      otbAppLogINFO("Force tiling with squared tiles of " << gridSize)
 
-      // Splitting using square tiles
-      TileSplitterType::Pointer splitter = TileSplitterType::New();
-      splitter->SetTileSizeAlignment(tileSize);
-      unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>(
-          double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels() ) / (tileSize * tileSize) );
-
-      // Use an itk::StreamingImageFilter to force the computation tile by tile
+      // Force the computation tile by tile
       m_StreamFilter = StreamingFilterType::New();
-      m_StreamFilter->SetRegionSplitter(splitter);
-      m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles);
+      m_StreamFilter->SetOutputGridSize(gridSize);
       m_StreamFilter->SetInput(m_TFFilter->GetOutput());
 
       SetParameterOutputImage("out", m_StreamFilter->GetOutput());
diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h
new file mode 100644
index 0000000..9ac462d
--- /dev/null
+++ b/include/otbTensorflowStreamerFilter.h
@@ -0,0 +1,79 @@
+/*=========================================================================
+
+  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowStreamerFilter_h
+#define otbTensorflowStreamerFilter_h
+
+// Image2image
+#include "itkImageToImageFilter.h"
+
+namespace otb
+{
+
+/**
+ * \class TensorflowStreamerFilter
+ * \brief This filter generates an output image with an internal
+ * explicit streaming mechanism.
+ *
+ * \ingroup OTBTensorflow
+ */
+template <class TInputImage, class TOutputImage>
+class ITK_EXPORT TensorflowStreamerFilter :
+public itk::ImageToImageFilter<TInputImage, TOutputImage>
+{
+
+public:
+
+  /** Standard class typedefs. */
+  typedef TensorflowStreamerFilter                           Self;
+  typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
+  typedef itk::SmartPointer<Self>                            Pointer;
+  typedef itk::SmartPointer<const Self>                      ConstPointer;
+
+  /** Method for creation through the object factory. */
+  itkNewMacro(Self);
+
+  /** Run-time type information (and related methods). */
+  itkTypeMacro(TensorflowStreamerFilter, itk::ImageToImageFilter);
+
+  /** Images typedefs */
+  typedef typename Superclass::InputImageType       ImageType;
+  typedef typename ImageType::IndexType             IndexType;
+  typedef typename ImageType::IndexValueType        IndexValueType;
+  typedef typename ImageType::SizeType              SizeType;
+  typedef typename Superclass::InputImageRegionType RegionType;
+
+  typedef TOutputImage                             OutputImageType;
+
+  itkSetMacro(OutputGridSize, SizeType);
+  itkGetMacro(OutputGridSize, SizeType);
+
+protected:
+  TensorflowStreamerFilter();
+  virtual ~TensorflowStreamerFilter() {};
+
+  virtual void GenerateInputRequestedRegion(void);
+
+  virtual void GenerateData();
+
+private:
+  TensorflowStreamerFilter(const Self&); //purposely not implemented
+  void operator=(const Self&); //purposely not implemented
+
+  SizeType                   m_OutputGridSize;       // Output grid size
+
+}; // end class
+
+
+} // end namespace otb
+
+#include "otbTensorflowStreamerFilter.hxx"
+
+#endif
diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx
new file mode 100644
index 0000000..a2afd3f
--- /dev/null
+++ b/include/otbTensorflowStreamerFilter.hxx
@@ -0,0 +1,116 @@
+/*=========================================================================
+
+  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowStreamerFilter_txx
+#define otbTensorflowStreamerFilter_txx
+
+#include "otbTensorflowStreamerFilter.h"
+#include "itkImageAlgorithm.h"
+
+namespace otb
+{
+
+template <class TInputImage, class TOutputImage>
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::TensorflowStreamerFilter()
+ {
+  m_OutputGridSize.Fill(1);
+ }
+
+
+template <class TInputImage, class TOutputImage>
+void
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::GenerateInputRequestedRegion()
+ {
+  // We intentionally break the pipeline
+  ImageType * inputImage = static_cast<ImageType * >(  Superclass::ProcessObject::GetInput(0) );
+  RegionType nullRegion;
+  inputImage->SetRequestedRegion(nullRegion);
+ }
+
+/**
+ * Compute the output image
+ */
+template <class TInputImage, class TOutputImage>
+void
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::GenerateData()
+ {
+  // Output pointer and requested region
+  OutputImageType * outputPtr = this->GetOutput();
+  const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
+  outputPtr->SetRegions(outputReqRegion);
+  outputPtr->Allocate();
+
+  // Compute the aligned region
+  RegionType region;
+  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
+    {
+    // Get corners
+    IndexValueType lower = outputReqRegion.GetIndex(dim);
+    IndexValueType upper = lower + outputReqRegion.GetSize(dim);
+
+    // Compute deltas between corners and the grid
+    const IndexValueType deltaLo = lower % m_OutputGridSize[dim];
+    const IndexValueType deltaUp = upper % m_OutputGridSize[dim];
+
+    // Move corners to aligned positions
+    lower -= deltaLo;
+    if (deltaUp > 0)
+      {
+      upper += m_OutputGridSize[dim] - deltaUp;
+      }
+
+    // Update region
+    region.SetIndex(dim, lower);
+    region.SetSize(dim, upper - lower);
+
+    }
+
+  // Compute the number of subregions to process
+  const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0];
+  const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1];
+
+  // Progress
+  itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY);
+
+  // For each tile, propagate the input region and recopy the output
+  ImageType * inputImage = static_cast<ImageType * >(  Superclass::ProcessObject::GetInput(0) );
+  unsigned int tx, ty;
+  RegionType subRegion;
+  subRegion.SetSize(m_OutputGridSize);
+  for (ty = 0; ty < nbTilesY; ty++)
+  {
+    subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1));
+    for (tx = 0; tx < nbTilesX; tx++)
+    {
+      // Update the input subregion
+      subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0));
+      inputImage->SetRequestedRegion(subRegion);
+      inputImage->PropagateRequestedRegion();
+      inputImage->UpdateOutputData();
+
+      // Copy the subregion to output
+      RegionType cpyRegion(subRegion);
+      cpyRegion.Crop(outputReqRegion);
+      itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion );
+
+      progress.CompletedPixel();
+    }
+  }
+
+ }
+
+
+} // end namespace otb
+
+
+#endif
-- 
GitLab