/*=========================================================================
Copyright (c) 2018-2019 IRSTEA
Copyright (c) 2020-2021 INRAE
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Tensorflow SavedModel
#include "tensorflow/cc/saved_model/loader.h"
// Tensorflow model filter
#include "otbTensorflowMultisourceModelFilter.h"
// Tensorflow graph load
#include "otbTensorflowGraphOperations.h"
// Layerstack
#include "otbTensorflowSource.h"
// Streaming
#include "otbTensorflowStreamerFilter.h"
namespace otb
{
namespace Wrapper
{
class TensorflowModelServe : public Application
{
public:
/** Standard class typedefs. */
typedef TensorflowModelServe Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(TensorflowModelServe, Application);
/** Typedefs for tensorflow */
typedef otb::TensorflowMultisourceModelFilter<FloatVectorImageType, FloatVectorImageType> TFModelFilterType;
typedef otb::TensorflowSource<FloatVectorImageType> InputImageSource;
/** Typedef for streaming */
typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
/** Typedefs for images */
typedef FloatVectorImageType::SizeType SizeType;
//
// Store stuff related to one source
//
struct ProcessObjectsBundle
{
InputImageSource m_ImageSource;
SizeType m_PatchSize;
std::string m_Placeholder;
float m_NodataValue;
bool m_HasNodata;
// Parameters keys
std::string m_KeyIn; // Key of input image list
std::string m_KeyPszX; // Key for receptive field size in X
std::string m_KeyPszY; // Key for receptive field size in Y
std::string m_KeyND; // Key for no-data value
std::string m_KeyPHName; // Key for placeholder name in the tensorflow model
};
//
// Add an input source, which includes:
// - an input image list
// - an input patch size (the dimensions of the samples)
//
void AddAnInputImage()
{
// Number of source
unsigned int inputNumber = m_Bundles.size() + 1;
// Create keys and descriptions
std::stringstream ss_key_group, ss_desc_group,
ss_key_in, ss_desc_in,
ss_key_dims_x, ss_desc_dims_x,
ss_key_dims_y, ss_desc_dims_y,
ss_key_ph, ss_desc_ph,
ss_key_nd, ss_desc_nd;
// Parameter group key/description
ss_key_group << "source" << inputNumber;
ss_desc_group << "Parameters for source #" << inputNumber;
// Parameter group keys
ss_key_in << ss_key_group.str() << ".il";
ss_key_dims_x << ss_key_group.str() << ".rfieldx";
ss_key_dims_y << ss_key_group.str() << ".rfieldy";
ss_key_ph << ss_key_group.str() << ".placeholder";
ss_key_nd << ss_key_group.str() << ".nodata";
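// e.g. for the first source, this yields the parameter keys source1.il,
// source1.rfieldx, source1.rfieldy, source1.placeholder and source1.nodata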
// Parameter group descriptions
ss_desc_in << "Input image (or list to stack) for source #" << inputNumber;
ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber;
ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
ss_desc_ph << "Name of the input placeholder for source #" << inputNumber;
ss_desc_nd << "No-data value for pixels of source #" << inputNumber;
// Populate group
AddParameter(ParameterType_Group, ss_key_group.str(), ss_desc_group.str());
AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() );
AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str());
SetMinimumParameterIntValue (ss_key_dims_x.str(), 1);
SetDefaultParameterInt (ss_key_dims_x.str(), 1);
AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str());
SetMinimumParameterIntValue (ss_key_dims_y.str(), 1);
SetDefaultParameterInt (ss_key_dims_y.str(), 1);
AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str());
MandatoryOff (ss_key_ph.str());
AddParameter(ParameterType_Float, ss_key_nd.str(), ss_desc_nd.str());
MandatoryOff (ss_key_nd.str());
// Add a new bundle
ProcessObjectsBundle bundle;
bundle.m_KeyIn = ss_key_in.str();
bundle.m_KeyPszX = ss_key_dims_x.str();
bundle.m_KeyPszY = ss_key_dims_y.str();
bundle.m_KeyPHName = ss_key_ph.str();
bundle.m_KeyND = ss_key_nd.str();
m_Bundles.push_back(bundle);
}
void DoInit()
{
// Documentation
SetName("TensorflowModelServe");
SetDescription("Multisource deep learning classifier using TensorFlow. Change the "
+ tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of sources.");
SetDocLongDescription("The application runs a TensorFlow model over multiple data sources. "
"The number of input sources can be changed at runtime by setting the system "
"environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". For each source, you have to "
"set (1) the placeholder name, as named in the TensorFlow model, (2) the receptive "
"field and (3) the source image(s). The output is a multiband image stacking all "
"output tensors together: you have to specify (1) the names of the output tensors, as "
"named in the TensorFlow model (typically, an operator's output), and (2) the expression "
"field of each output tensor. The output tensor values will be stacked in the same "
"order as they appear in the \"output.names\" parameter (you can use a space separator "
"between names). You might consider using extended filenames to bypass the automatic "
"memory footprint calculator of the OTB application engine, and set a good splitting "
"strategy (square tiles work well for convolutional networks), or use the \"optim\" "
"parameter group to impose your own square tile size.");
SetDocAuthors("Remi Cresson");
AddDocTag(Tags::Learning);
// Input/output images
AddAnInputImage();
for (int i = 1; i < tf::GetNumberOfSources() ; i++)
AddAnInputImage();
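// The number of source groups is read from the environment variable whose name is
// given by tf::ENV_VAR_NAME_NSOURCES (typically OTB_TF_NSOURCES; this name is an
// assumption based on the project conventions). For instance, exporting
// OTB_TF_NSOURCES=2 before launching the application exposes a second "source2" group.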
// Input model
AddParameter(ParameterType_Group, "model", "model parameters");
AddParameter(ParameterType_Directory, "model.dir", "TensorFlow SavedModel directory");
MandatoryOn ("model.dir");
SetParameterDescription ("model.dir", "The model directory should contains the model Google Protobuf (.pb) and variables");
AddParameter(ParameterType_StringList, "model.userplaceholders", "Additional single-valued placeholders. Supported types: int, float, bool.");
MandatoryOff ("model.userplaceholders");
SetParameterDescription ("model.userplaceholders", "Syntax to use is \"placeholder_1=value_1 ... placeholder_N=value_N\"");
AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional");
MandatoryOff ("model.fullyconv");
AddParameter(ParameterType_StringList, "model.tagsets", "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Currently, only one tag is supported. Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`");
MandatoryOff ("model.tagsets");
// Output tensors parameters
AddParameter(ParameterType_Group, "output", "Output tensors parameters");
AddParameter(ParameterType_Float, "output.spcscale", "The output spacing scale, related to the first input");
SetDefaultParameterFloat ("output.spcscale", 1.0);
SetParameterDescription ("output.spcscale", "The output image size/scale and spacing*scale where size and spacing corresponds to the first input");
AddParameter(ParameterType_StringList, "output.names", "Names of the output tensors");
MandatoryOff ("output.names");
// Output background value
AddParameter(ParameterType_Float, "output.bv", "Output background value");
SetDefaultParameterFloat ("output.bv", 0.0);
SetParameterDescription ("output.bv", "The value used when one input has only no-data values in its receptive field");
// Output Field of Expression
AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)");
SetMinimumParameterIntValue ("output.efieldx", 1);
SetDefaultParameterInt ("output.efieldx", 1);
MandatoryOn ("output.efieldx");
AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)");
SetMinimumParameterIntValue ("output.efieldy", 1);
SetDefaultParameterInt ("output.efieldy", 1);
MandatoryOn ("output.efieldy");
// Fine tuning
AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time");
AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
MandatoryOff ("optim.disabletiling");
SetParameterDescription ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it");
AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesizex", 1);
SetDefaultParameterInt ("optim.tilesizex", 16);
AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesizey", 1);
SetDefaultParameterInt ("optim.tilesizey", 16);
// Output image
AddParameter(ParameterType_OutputImage, "out", "output image");
// Example
SetDocExampleParameterValue("source1.il", "spot6pms.tif");
SetDocExampleParameterValue("source1.placeholder", "x1");
SetDocExampleParameterValue("source1.rfieldx", "16");
SetDocExampleParameterValue("source1.rfieldy", "16");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0");
SetDocExampleParameterValue("output.names", "out_predict1 out_proba1");
SetDocExampleParameterValue("out", "\"classif128tgt.tif?&streaming:type=tiled&streaming:sizemode=height&streaming:sizevalue=256\"");
}
//
// Prepare the bundles from the application parameters
//
void PrepareInputs()
{
for (auto& bundle: m_Bundles)
{
// Setting the image source
FloatVectorImageListType::Pointer list = GetParameterImageList(bundle.m_KeyIn);
bundle.m_ImageSource.Set(list);
bundle.m_Placeholder = GetParameterAsString(bundle.m_KeyPHName);
bundle.m_PatchSize[0] = GetParameterInt(bundle.m_KeyPszX);
bundle.m_PatchSize[1] = GetParameterInt(bundle.m_KeyPszY);
bundle.m_HasNodata = HasValue(bundle.m_KeyND);
bundle.m_NodataValue = (bundle.m_HasNodata == true) ? GetParameterFloat(bundle.m_KeyND) : 0;
otbAppLogINFO("Source info :");
otbAppLogINFO("Receptive field : " << bundle.m_PatchSize );
otbAppLogINFO("Placeholder name : " << bundle.m_Placeholder);
if (bundle.m_HasNodata == true)
otbAppLogINFO("No-data value : " << bundle.m_NodataValue);
}
}
void DoExecute()
{
// Load the Tensorflow bundle
tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, GetParameterStringList("model.tagsets"));
// Prepare inputs
PrepareInputs();
// Setup filter
m_TFFilter = TFModelFilterType::New();
m_TFFilter->SetSavedModel(&m_SavedModel);
m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
// Get user placeholders
TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
TFModelFilterType::DictType dict;
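// Each "name=value" expression is converted into a (placeholder name, tensor) pair,
// e.g. "is_training=false" feeds a boolean tensor to the "is_training" placeholder
// (single-valued int, float and bool placeholders are supported, as stated in the
// "model.userplaceholders" parameter description)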
for (auto& exp: expressions)
{
TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
}
m_TFFilter->SetUserPlaceholders(dict);
// Input sources
for (auto& bundle: m_Bundles)
{
m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get(), bundle.m_HasNodata, bundle.m_NodataValue);
}
// Fully convolutional mode on/off
if (GetParameterInt("model.fullyconv")==1)
{
otbAppLogINFO("The TensorFlow model is used in fully convolutional mode");
m_TFFilter->SetFullyConvolutional(true);
}
// Output background value
const float outBV = GetParameterFloat("output.bv");
otbAppLogINFO("Setting background value to " << outBV);
m_TFFilter->SetOutputBackgroundValue(outBV);
// Output field of expression
FloatVectorImageType::SizeType foe;
foe[0] = GetParameterInt("output.efieldx");
foe[1] = GetParameterInt("output.efieldy");
m_TFFilter->SetOutputExpressionFields({foe});
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
// Streaming
if (GetParameterInt("optim.disabletiling") != 1)
{
// Get the tile size
SizeType tileSize;
tileSize[0] = GetParameterInt("optim.tilesizex");
tileSize[1] = GetParameterInt("optim.tilesizey");
// Check that the tile size is aligned to the field of expression
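// If a dimension is not aligned, the tile size is rounded up to the next multiple of
// the expression field: newSize = (tileSize / foe + 1) * foe (integer division),
// e.g. a 16-pixel tile with a 10-pixel expression field becomes 20 pixels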
for (unsigned int i = 0 ; i < FloatVectorImageType::ImageDimension ; i++)
if (tileSize[i] % foe[i] != 0)
{
SizeType::SizeValueType newSize = 1 + std::floor(tileSize[i] / foe[i]);
newSize *= foe[i];
otbAppLogWARNING("Aligning the tiling to the output expression field "
<< "for better performances (dim " << i << "). New value set to " << newSize)
tileSize[i] = newSize;
}
otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
// Force the computation tile by tile
m_StreamFilter = StreamingFilterType::New();
m_StreamFilter->SetOutputGridSize(tileSize);
m_StreamFilter->SetInput(m_TFFilter->GetOutput());
SetParameterOutputImage("out", m_StreamFilter->GetOutput());
}
else
{
otbAppLogINFO("Tiling disabled");
SetParameterOutputImage("out", m_TFFilter->GetOutput());
}
}
void DoUpdateParameters()
{
}
private:
TFModelFilterType::Pointer m_TFFilter;
StreamingFilterType::Pointer m_StreamFilter;
tensorflow::SavedModelBundle m_SavedModel; // must stay alive during the whole execution of the application!
std::vector<ProcessObjectsBundle> m_Bundles;
}; // end of class
} // namespace Wrapper
} // namespace otb
OTB_APPLICATION_EXPORT( otb::Wrapper::TensorflowModelServe )