An error occurred while loading the file. Please try again.
-
Dorchies David authored2d4548d2
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Tensorflow stuff
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
// Tensorflow model filter
#include "otbTensorflowMultisourceModelFilter.h"
// Tensorflow graph load
#include "otbTensorflowGraphOperations.h"
// Layerstack
#include "otbTensorflowSource.h"
// Streaming
#include "otbTensorflowStreamerFilter.h"
namespace otb
{
namespace Wrapper
{
class TensorflowModelServe : public Application
{
public:
/** Standard class typedefs. */
typedef TensorflowModelServe Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(TensorflowModelServe, Application);
/** Typedefs for tensorflow */
typedef otb::TensorflowMultisourceModelFilter<FloatVectorImageType, FloatVectorImageType> TFModelFilterType;
typedef otb::TensorflowSource<FloatVectorImageType> InputImageSource;
/** Typedef for streaming */
typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
/** Typedefs for images */
typedef FloatVectorImageType::SizeType SizeType;
/** Hook called by the application engine when parameter values change.
 *  Intentionally a no-op: no parameter is recomputed or adjusted here. */
void DoUpdateParameters() { /* nothing to update */ }
//
// Objects and parameter keys related to one input source
//
struct ProcessObjectsBundle
{
  // Image source for this input (presumably stacks the images of the
  // input list -- NOTE(review): confirm against DoExecute)
  InputImageSource m_ImageSource;

  // Patch (receptive field) size of this source. Value-initialized to
  // zero: AddAnInputImage copies a freshly built bundle into m_Bundles
  // before this field is assigned, and copying an uninitialized size
  // reads indeterminate values.
  SizeType m_PatchSize{};

  // Name of the model placeholder this source feeds
  // (presumably read from the m_KeyPHName parameter -- verify)
  std::string m_Placeholder;

  // Parameter keys, filled in by AddAnInputImage
  std::string m_KeyIn;      // Key of input image list
  std::string m_KeyPszX;    // Key for samples sizes X
  std::string m_KeyPszY;    // Key for samples sizes Y
  std::string m_KeyPHName;  // Key for placeholder name in the tensorflow model
};
//
// Add an input source, which includes:
// - an input image list
// - an input patch size (dimensions of samples)
// - the name of the model placeholder it feeds
//
void AddAnInputImage()
{
// Number of source
unsigned int inputNumber = m_Bundles.size() + 1;
// Create keys and descriptions
std::stringstream ss_key_group, ss_desc_group,
ss_key_in, ss_desc_in,
ss_key_dims_x, ss_desc_dims_x,
ss_key_dims_y, ss_desc_dims_y,
ss_key_ph, ss_desc_ph;
// Parameter group key/description
ss_key_group << "source" << inputNumber;
ss_desc_group << "Parameters for source #" << inputNumber;
// Parameter group keys
ss_key_in << ss_key_group.str() << ".il";
ss_key_dims_x << ss_key_group.str() << ".rfieldx";
ss_key_dims_y << ss_key_group.str() << ".rfieldy";
ss_key_ph << ss_key_group.str() << ".placeholder";
// Parameter group descriptions
ss_desc_in << "Input image (or list to stack) for source #" << inputNumber;
ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber;
ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
ss_desc_ph << "Name of the input placeholder for source #" << inputNumber;
// Populate group
AddParameter(ParameterType_Group, ss_key_group.str(), ss_desc_group.str());
AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() );
AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str());
SetMinimumParameterIntValue (ss_key_dims_x.str(), 1);
AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str());
SetMinimumParameterIntValue (ss_key_dims_y.str(), 1);
AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str());
// Add a new bundle
ProcessObjectsBundle bundle;
bundle.m_KeyIn = ss_key_in.str();
bundle.m_KeyPszX = ss_key_dims_x.str();
bundle.m_KeyPszY = ss_key_dims_y.str();
bundle.m_KeyPHName = ss_key_ph.str();
m_Bundles.push_back(bundle);
}
void DoInit()
{