// otbTensorflowModelServe.cxx
/*=========================================================================
  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Tensorflow stuff
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
// Tensorflow model filter
#include "otbTensorflowMultisourceModelFilter.h"
// Tensorflow graph load
#include "otbTensorflowGraphOperations.h"
// Layerstack
#include "otbTensorflowSource.h"
// Streaming
#include "otbTensorflowStreamerFilter.h"
namespace otb
namespace Wrapper
class TensorflowModelServe : public Application
public:
  /** Usual aliases for the class itself, its parent and its smart pointers. */
  using Self         = TensorflowModelServe;
  using Superclass   = Application;
  using Pointer      = itk::SmartPointer<Self>;
  using ConstPointer = itk::SmartPointer<const Self>;

  /** Standard macros */
  itkNewMacro(Self);
  itkTypeMacro(TensorflowModelServe, Application);

  /** Tensorflow side: the model filter and the per-source image provider,
   *  both working on float vector images. */
  using TFModelFilterType = otb::TensorflowMultisourceModelFilter<FloatVectorImageType, FloatVectorImageType>;
  using InputImageSource  = otb::TensorflowSource<FloatVectorImageType>;

  /** Streaming side: square tile splitter and streamer filter. */
  using TileSplitterType    = otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension>;
  using StreamingFilterType = otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType>;

  /** Image-related types. */
  using SizeType = FloatVectorImageType::SizeType;
  void DoUpdateParameters()
  // Store stuff related to one source
  //
// struct ProcessObjectsBundle { InputImageSource m_ImageSource; SizeType m_PatchSize; std::string m_Placeholder; // Parameters keys std::string m_KeyIn; // Key of input image list std::string m_KeyPszX; // Key for samples sizes X std::string m_KeyPszY; // Key for samples sizes Y std::string m_KeyPHName; // Key for placeholder name in the tensorflow model }; // // Add an input source, which includes: // -an input image list // -an input patchsize (dimensions of samples) // void AddAnInputImage() { // Number of source unsigned int inputNumber = m_Bundles.size() + 1; // Create keys and descriptions std::stringstream ss_key_group, ss_desc_group, ss_key_in, ss_desc_in, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, ss_desc_dims_y, ss_key_ph, ss_desc_ph; // Parameter group key/description ss_key_group << "source" << inputNumber; ss_desc_group << "Parameters for source #" << inputNumber; // Parameter group keys ss_key_in << ss_key_group.str() << ".il"; ss_key_dims_x << ss_key_group.str() << ".rfieldx"; ss_key_dims_y << ss_key_group.str() << ".rfieldy"; ss_key_ph << ss_key_group.str() << ".placeholder"; // Parameter group descriptions ss_desc_in << "Input image (or list to stack) for source #" << inputNumber; ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber; ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber; ss_desc_ph << "Name of the input placeholder for source #" << inputNumber; // Populate group AddParameter(ParameterType_Group, ss_key_group.str(), ss_desc_group.str()); AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() ); AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str()); // Add a new bundle 
ProcessObjectsBundle bundle; bundle.m_KeyIn = ss_key_in.str(); bundle.m_KeyPszX = ss_key_dims_x.str(); bundle.m_KeyPszY = ss_key_dims_y.str(); bundle.m_KeyPHName = ss_key_ph.str(); m_Bundles.push_back(bundle); } void DoInit() {