From 78ce0cff3b157047c1369c65c708a6b699f37c49 Mon Sep 17 00:00:00 2001 From: remi cresson <remi.cresson@teledetection.fr> Date: Fri, 21 Sep 2018 20:01:40 +0200 Subject: [PATCH] ADD: streamer filter --- app/otbTensorflowModelServe.cxx | 33 +++---- include/otbTensorflowStreamerFilter.h | 79 ++++++++++++++++ include/otbTensorflowStreamerFilter.hxx | 116 ++++++++++++++++++++++++ 3 files changed, 209 insertions(+), 19 deletions(-) create mode 100644 include/otbTensorflowStreamerFilter.h create mode 100644 include/otbTensorflowStreamerFilter.hxx diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx index cfed551..bbe204f 100644 --- a/app/otbTensorflowModelServe.cxx +++ b/app/otbTensorflowModelServe.cxx @@ -30,8 +30,7 @@ #include "otbTensorflowSource.h" // Streaming -#include "otbImageRegionSquareTileSplitter.h" -#include "itkStreamingImageFilter.h" +#include "otbTensorflowStreamerFilter.h" namespace otb { @@ -58,7 +57,7 @@ public: /** Typedef for streaming */ typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType; - typedef itk::StreamingImageFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType; + typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType; /** Typedefs for images */ typedef FloatVectorImageType::SizeType SizeType; @@ -198,9 +197,12 @@ public: AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling"); MandatoryOff ("optim.disabletiling"); SetParameterDescription ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it"); - AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output"); - SetMinimumParameterIntValue ("optim.tilesize", 1); - SetDefaultParameterInt ("optim.tilesize", 16); + AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output"); + SetMinimumParameterIntValue ("optim.tilesizex", 1); + SetDefaultParameterInt ("optim.tilesizex", 16); + AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output"); + SetMinimumParameterIntValue ("optim.tilesizey", 1); + SetDefaultParameterInt ("optim.tilesizey", 16); // Output image AddParameter(ParameterType_OutputImage, "out", "output image"); @@ -292,22 +294,15 @@ public: if (GetParameterInt("optim.disabletiling") != 1) { // Get the tile size - const unsigned int tileSize = GetParameterInt("optim.tilesize"); - otbAppLogINFO("Force tiling with squared tiles of " << tileSize) + SizeType gridSize; + gridSize[0] = GetParameterInt("optim.tilesizex"); + gridSize[1] = GetParameterInt("optim.tilesizey"); - // Update the TensorFlow filter output information to get the output image size - m_TFFilter->UpdateOutputInformation(); + otbAppLogINFO("Force tiling with squared tiles of " << gridSize) - // Splitting using square tiles - TileSplitterType::Pointer splitter = TileSplitterType::New(); - splitter->SetTileSizeAlignment(tileSize); - unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>( - double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels() ) / (tileSize * tileSize) ); - - // Use an itk::StreamingImageFilter to force the computation tile by tile + // Force the computation tile by tile m_StreamFilter = StreamingFilterType::New(); - m_StreamFilter->SetRegionSplitter(splitter); - m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles); + m_StreamFilter->SetOutputGridSize(gridSize); m_StreamFilter->SetInput(m_TFFilter->GetOutput()); SetParameterOutputImage("out", m_StreamFilter->GetOutput()); diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h new file mode 100644 index 0000000..9ac462d --- /dev/null +++ b/include/otbTensorflowStreamerFilter.h @@ -0,0 +1,79 @@ +/*========================================================================= + + Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbTensorflowStreamerFilter_h +#define otbTensorflowStreamerFilter_h + +// Image2image +#include "itkImageToImageFilter.h" + +namespace otb +{ + +/** + * \class TensorflowStreamerFilter + * \brief This filter generates an output image with an internal + * explicit streaming mechanism. + * + * \ingroup OTBTensorflow + */ +template <class TInputImage, class TOutputImage> +class ITK_EXPORT TensorflowStreamerFilter : +public itk::ImageToImageFilter<TInputImage, TOutputImage> +{ + +public: + + /** Standard class typedefs. */ + typedef TensorflowStreamerFilter Self; + typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Method for creation through the object factory. */ + itkNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(TensorflowStreamerFilter, itk::ImageToImageFilter); + + /** Images typedefs */ + typedef typename Superclass::InputImageType ImageType; + typedef typename ImageType::IndexType IndexType; + typedef typename ImageType::IndexValueType IndexValueType; + typedef typename ImageType::SizeType SizeType; + typedef typename Superclass::InputImageRegionType RegionType; + + typedef TOutputImage OutputImageType; + + itkSetMacro(OutputGridSize, SizeType); + itkGetMacro(OutputGridSize, SizeType); + +protected: + TensorflowStreamerFilter(); + virtual ~TensorflowStreamerFilter() {}; + + virtual void GenerateInputRequestedRegion(void); + + virtual void GenerateData(); + +private: + TensorflowStreamerFilter(const Self&); //purposely not implemented + void operator=(const Self&); //purposely not implemented + + SizeType m_OutputGridSize; // Output grid size + +}; // end class + + +} // end namespace otb + +#include "otbTensorflowStreamerFilter.hxx" + +#endif diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx new file mode 100644 index 0000000..a2afd3f --- /dev/null +++ b/include/otbTensorflowStreamerFilter.hxx @@ -0,0 +1,116 @@ +/*========================================================================= + + Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbTensorflowStreamerFilter_txx +#define otbTensorflowStreamerFilter_txx + +#include "otbTensorflowStreamerFilter.h" +#include "itkImageAlgorithm.h" + +namespace otb +{ + +template <class TInputImage, class TOutputImage> +TensorflowStreamerFilter<TInputImage, TOutputImage> +::TensorflowStreamerFilter() + { + m_OutputGridSize.Fill(1); + } + + +template <class TInputImage, class TOutputImage> +void +TensorflowStreamerFilter<TInputImage, TOutputImage> +::GenerateInputRequestedRegion() + { + // We intentionally break the pipeline + ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) ); + RegionType nullRegion; + inputImage->SetRequestedRegion(nullRegion); + } + +/** + * Compute the output image + */ +template <class TInputImage, class TOutputImage> +void +TensorflowStreamerFilter<TInputImage, TOutputImage> +::GenerateData() + { + // Output pointer and requested region + OutputImageType * outputPtr = this->GetOutput(); + const RegionType outputReqRegion = outputPtr->GetRequestedRegion(); + outputPtr->SetRegions(outputReqRegion); + outputPtr->Allocate(); + + // Compute the aligned region + RegionType region; + for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) + { + // Get corners + IndexValueType lower = outputReqRegion.GetIndex(dim); + IndexValueType upper = lower + outputReqRegion.GetSize(dim); + + // Compute deltas between corners and the grid + const IndexValueType deltaLo = lower % m_OutputGridSize[dim]; + const IndexValueType deltaUp = upper % m_OutputGridSize[dim]; + + // Move corners to aligned positions + lower -= deltaLo; + if (deltaUp > 0) + { + upper += m_OutputGridSize[dim] - deltaUp; + } + + // Update region + region.SetIndex(dim, lower); + region.SetSize(dim, upper - lower); + + } + + // Compute the number of subregions to process + const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0]; + const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1]; + + // Progress + itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY); + + // For each tile, propagate the input region and recopy the output + ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) ); + unsigned int tx, ty; + RegionType subRegion; + subRegion.SetSize(m_OutputGridSize); + for (ty = 0; ty < nbTilesY; ty++) + { + subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1)); + for (tx = 0; tx < nbTilesX; tx++) + { + // Update the input subregion + subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0)); + inputImage->SetRequestedRegion(subRegion); + inputImage->PropagateRequestedRegion(); + inputImage->UpdateOutputData(); + + // Copy the subregion to output + RegionType cpyRegion(subRegion); + cpyRegion.Crop(outputReqRegion); + itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion ); + + progress.CompletedPixel(); + } + } + + } + + +} // end namespace otb + + +#endif -- GitLab