// otbTensorflowModelServe.cxx
/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#include <string>
#include <vector>

#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"

// Application engine
#include "otbStandardFilterWatcher.h"

// Tensorflow stuff
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

// Tensorflow model filter
#include "otbTensorflowMultisourceModelFilter.h"

// Tensorflow graph load
#include "otbTensorflowGraphOperations.h"

// Layerstack
#include "otbTensorflowSource.h"

// Streaming
#include "otbTensorflowStreamerFilter.h"

namespace otb
{

namespace Wrapper
{

class TensorflowModelServe : public Application
{
public:
  /** Standard class typedefs. */
  typedef TensorflowModelServe                       Self;
  typedef Application                                Superclass;
  typedef itk::SmartPointer<Self>                    Pointer;
  typedef itk::SmartPointer<const Self>              ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);
  itkTypeMacro(TensorflowModelServe, Application);

  /** Typedefs for tensorflow */
  typedef otb::TensorflowMultisourceModelFilter<FloatVectorImageType, FloatVectorImageType> TFModelFilterType;
  typedef otb::TensorflowSource<FloatVectorImageType> InputImageSource;

  /** Typedef for streaming */
  typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
remi cresson's avatar
remi cresson committed
60
  typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
remi cresson's avatar
remi cresson committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

  /** Typedefs for images */
  typedef FloatVectorImageType::SizeType SizeType;

  void DoUpdateParameters()
  {
  }

  //
  // Store stuff related to one source
  //
  struct ProcessObjectsBundle
  {
    InputImageSource m_ImageSource;
    SizeType         m_PatchSize;
    std::string      m_Placeholder;

    // Parameters keys
    std::string m_KeyIn;     // Key of input image list
    std::string m_KeyPszX;   // Key for samples sizes X
    std::string m_KeyPszY;   // Key for samples sizes Y
    std::string m_KeyPHName; // Key for placeholder name in the tensorflow model
  };

  //
  // Add an input source, which includes:
  // -an input image list
  // -an input patchsize (dimensions of samples)
  //
  void AddAnInputImage()
  {
    // Number of source
    unsigned int inputNumber = m_Bundles.size() + 1;

    // Create keys and descriptions
    std::stringstream ss_key_group, ss_desc_group,
    ss_key_in, ss_desc_in,
    ss_key_dims_x, ss_desc_dims_x,
    ss_key_dims_y, ss_desc_dims_y,
    ss_key_ph, ss_desc_ph;

    // Parameter group key/description
    ss_key_group  << "source"                  << inputNumber;
    ss_desc_group << "Parameters for source #" << inputNumber;

    // Parameter group keys
    ss_key_in      << ss_key_group.str() << ".il";
remi cresson's avatar
remi cresson committed
108
109
    ss_key_dims_x  << ss_key_group.str() << ".rfieldx";
    ss_key_dims_y  << ss_key_group.str() << ".rfieldy";
remi cresson's avatar
remi cresson committed
110
111
112
113
    ss_key_ph      << ss_key_group.str() << ".placeholder";

    // Parameter group descriptions
    ss_desc_in     << "Input image (or list to stack) for source #" << inputNumber;
remi cresson's avatar
remi cresson committed
114
115
    ss_desc_dims_x << "Input receptive field (width) for source #"  << inputNumber;
    ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber;
remi cresson's avatar
remi cresson committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
    ss_desc_ph     << "Name of the input placeholder for source #"  << inputNumber;

    // Populate group
    AddParameter(ParameterType_Group,          ss_key_group.str(),  ss_desc_group.str());
    AddParameter(ParameterType_InputImageList, ss_key_in.str(),     ss_desc_in.str() );
    AddParameter(ParameterType_Int,            ss_key_dims_x.str(), ss_desc_dims_x.str());
    SetMinimumParameterIntValue               (ss_key_dims_x.str(), 1);
    AddParameter(ParameterType_Int,            ss_key_dims_y.str(), ss_desc_dims_y.str());
    SetMinimumParameterIntValue               (ss_key_dims_y.str(), 1);
    AddParameter(ParameterType_String,         ss_key_ph.str(),     ss_desc_ph.str());

    // Add a new bundle
    ProcessObjectsBundle bundle;
    bundle.m_KeyIn     = ss_key_in.str();
    bundle.m_KeyPszX   = ss_key_dims_x.str();
    bundle.m_KeyPszY   = ss_key_dims_y.str();
    bundle.m_KeyPHName = ss_key_ph.str();

    m_Bundles.push_back(bundle);

  }

  void DoInit()
  {

    // Documentation
    SetName("TensorflowModelServe");
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
    SetDescription("Multisource deep learning classifier using TensorFlow. Change the "
        + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of sources.");
    SetDocLongDescription("The application run a TensorFlow model over multiple data sources. "
        "The number of input sources can be changed at runtime by setting the system "
        "environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". For each source, you have to "
        "set (1) the placeholder name, as named in the TensorFlow model, (2) the receptive "
        "field and (3) the image(s) source. The output is a multiband image, stacking all "
        "outputs tensors together: you have to specify (1) the names of the output tensors, as "
        "named in the TensorFlow model (typically, an operator's output) and (2) the expression "
        "field of each output tensor. The output tensors values will be stacked in the same "
        "order as they appear in the \"model.output\" parameter (you can use a space separator "
        "between names). You might consider to use extended filename to bypass the automatic "
        "memory footprint calculator of the otb application engine, and set a good splitting "
        "strategy (Square tiles is good for convolutional networks) or use the \"optim\" "
        "parameter group to impose your squared tiles sizes");
remi cresson's avatar
remi cresson committed
158
159
160
161
162
163
164
165
166
    SetDocAuthors("Remi Cresson");

    // Input/output images
    AddAnInputImage();
    for (int i = 1; i < tf::GetNumberOfSources() ; i++)
      AddAnInputImage();

    // Input model
    AddParameter(ParameterType_Group,         "model",           "model parameters");
167
    AddParameter(ParameterType_Directory,     "model.dir",       "TensorFlow model_save directory");
remi cresson's avatar
remi cresson committed
168
    MandatoryOn                              ("model.dir");
169
170
    SetParameterDescription                  ("model.dir", "The model directory should contains the model Google Protobuf (.pb) and variables");

remi cresson's avatar
remi cresson committed
171
172
    AddParameter(ParameterType_StringList,    "model.userplaceholders",    "Additional single-valued placeholders. Supported types: int, float, bool.");
    MandatoryOff                             ("model.userplaceholders");
173
    SetParameterDescription                  ("model.userplaceholders", "Syntax to use is \"placeholder_1=value_1 ... placeholder_N=value_N\"");
remi cresson's avatar
remi cresson committed
174
175
176
177
178
    AddParameter(ParameterType_Bool,          "model.fullyconv", "Fully convolutional");
    MandatoryOff                             ("model.fullyconv");

    // Output tensors parameters
    AddParameter(ParameterType_Group,         "output",          "Output tensors parameters");
179
    AddParameter(ParameterType_Float,         "output.spcscale", "The output spacing scale, related to the first input");
remi cresson's avatar
remi cresson committed
180
    SetDefaultParameterFloat                 ("output.spcscale", 1.0);
181
    SetParameterDescription                  ("output.spcscale", "The output image size/scale and spacing*scale where size and spacing corresponds to the first input");
remi cresson's avatar
remi cresson committed
182
183
184
185
    AddParameter(ParameterType_StringList,    "output.names",    "Names of the output tensors");
    MandatoryOn                              ("output.names");

    // Output Field of Expression
remi cresson's avatar
remi cresson committed
186
187
188
189
190
191
192
193
    AddParameter(ParameterType_Int,           "output.efieldx", "The output expression field (width)");
    SetMinimumParameterIntValue              ("output.efieldx", 1);
    SetDefaultParameterInt                   ("output.efieldx", 1);
    MandatoryOn                              ("output.efieldx");
    AddParameter(ParameterType_Int,           "output.efieldy", "The output expression field (height)");
    SetMinimumParameterIntValue              ("output.efieldy", 1);
    SetDefaultParameterInt                   ("output.efieldy", 1);
    MandatoryOn                              ("output.efieldy");
remi cresson's avatar
remi cresson committed
194
195

    // Fine tuning
remi cresson's avatar
remi cresson committed
196
197
198
    AddParameter(ParameterType_Group,         "optim" , "This group of parameters allows optimization of processing time");
    AddParameter(ParameterType_Bool,          "optim.disabletiling", "Disable tiling");
    MandatoryOff                             ("optim.disabletiling");
199
    SetParameterDescription                  ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it");
remi cresson's avatar
remi cresson committed
200
201
202
203
204
205
    AddParameter(ParameterType_Int,           "optim.tilesizex", "Tile width used to stream the filter output");
    SetMinimumParameterIntValue              ("optim.tilesizex", 1);
    SetDefaultParameterInt                   ("optim.tilesizex", 16);
    AddParameter(ParameterType_Int,           "optim.tilesizey", "Tile height used to stream the filter output");
    SetMinimumParameterIntValue              ("optim.tilesizey", 1);
    SetDefaultParameterInt                   ("optim.tilesizey", 16);
remi cresson's avatar
remi cresson committed
206
207
208
209
210
211
212

    // Output image
    AddParameter(ParameterType_OutputImage, "out", "output image");

    // Example
    SetDocExampleParameterValue("source1.il",             "spot6pms.tif");
    SetDocExampleParameterValue("source1.placeholder",    "x1");
remi cresson's avatar
remi cresson committed
213
214
    SetDocExampleParameterValue("source1.rfieldx",        "16");
    SetDocExampleParameterValue("source1.rfieldy",        "16");
remi cresson's avatar
remi cresson committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
    SetDocExampleParameterValue("model.dir",              "/tmp/my_saved_model/");
    SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0");
    SetDocExampleParameterValue("output.names",           "out_predict1 out_proba1");
    SetDocExampleParameterValue("out",                    "\"classif128tgt.tif?&streaming:type=tiled&streaming:sizemode=height&streaming:sizevalue=256\"");

  }

  //
  // Prepare bundles from the number of points
  //
  void PrepareInputs()
  {

    for (auto& bundle: m_Bundles)
    {
      // Setting the image source
      FloatVectorImageListType::Pointer list = GetParameterImageList(bundle.m_KeyIn);
      bundle.m_ImageSource.Set(list);
      bundle.m_Placeholder = GetParameterAsString(bundle.m_KeyPHName);
      bundle.m_PatchSize[0] = GetParameterInt(bundle.m_KeyPszX);
      bundle.m_PatchSize[1] = GetParameterInt(bundle.m_KeyPszY);

      otbAppLogINFO("Source info :");
238
239
      otbAppLogINFO("Receptive field  : " << bundle.m_PatchSize  );
      otbAppLogINFO("Placeholder name : " << bundle.m_Placeholder);
remi cresson's avatar
remi cresson committed
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
    }
  }

  void DoExecute()
  {

    // Load the Tensorflow bundle
    tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel);

    // Prepare inputs
    PrepareInputs();

    // Setup filter
    m_TFFilter = TFModelFilterType::New();
    m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
    m_TFFilter->SetSession(m_SavedModel.session.get());
Cresson Remi's avatar
Cresson Remi committed
256
    m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
remi cresson's avatar
remi cresson committed
257
258
259
260
261
    m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
    otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());

    // Get user placeholders
    TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
Cresson Remi's avatar
Cresson Remi committed
262
    TFModelFilterType::DictType dict;
remi cresson's avatar
remi cresson committed
263
264
    for (auto& exp: expressions)
    {
Cresson Remi's avatar
Cresson Remi committed
265
      TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
remi cresson's avatar
remi cresson committed
266
267
268
269
270
271
272
273
274
      dict.push_back(entry);

      otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
    }
    m_TFFilter->SetUserPlaceholders(dict);

    // Input sources
    for (auto& bundle: m_Bundles)
    {
Cresson Remi's avatar
Cresson Remi committed
275
      m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
remi cresson's avatar
remi cresson committed
276
277
278
279
280
    }

    // Fully convolutional mode on/off
    if (GetParameterInt("model.fullyconv")==1)
    {
281
      otbAppLogINFO("The TensorFlow model is used in fully convolutional mode");
remi cresson's avatar
remi cresson committed
282
283
284
285
286
      m_TFFilter->SetFullyConvolutional(true);
    }

    // Output field of expression
    FloatVectorImageType::SizeType foe;
287
288
    foe[0] = GetParameterInt("output.efieldx");
    foe[1] = GetParameterInt("output.efieldy");
Cresson Remi's avatar
Cresson Remi committed
289
    m_TFFilter->SetOutputExpressionFields({foe});
remi cresson's avatar
remi cresson committed
290

Cresson Remi's avatar
Cresson Remi committed
291
    otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
remi cresson's avatar
remi cresson committed
292
293

    // Streaming
Cresson Remi's avatar
Cresson Remi committed
294
    if (GetParameterInt("optim.disabletiling") != 1)
remi cresson's avatar
remi cresson committed
295
296
    {
      // Get the tile size
remi cresson's avatar
remi cresson committed
297
298
299
      SizeType gridSize;
      gridSize[0] = GetParameterInt("optim.tilesizex");
      gridSize[1] = GetParameterInt("optim.tilesizey");
remi cresson's avatar
remi cresson committed
300

remi cresson's avatar
remi cresson committed
301
      otbAppLogINFO("Force tiling with squared tiles of " << gridSize)
remi cresson's avatar
remi cresson committed
302

remi cresson's avatar
remi cresson committed
303
      // Force the computation tile by tile
remi cresson's avatar
remi cresson committed
304
      m_StreamFilter = StreamingFilterType::New();
remi cresson's avatar
remi cresson committed
305
      m_StreamFilter->SetOutputGridSize(gridSize);
remi cresson's avatar
remi cresson committed
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
      m_StreamFilter->SetInput(m_TFFilter->GetOutput());

      SetParameterOutputImage("out", m_StreamFilter->GetOutput());
    }
    else
    {
      otbAppLogINFO("Tiling disabled");
      SetParameterOutputImage("out", m_TFFilter->GetOutput());
    }
  }

private:

  TFModelFilterType::Pointer   m_TFFilter;
  StreamingFilterType::Pointer m_StreamFilter;
  tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !

  std::vector<ProcessObjectsBundle>           m_Bundles;

}; // end of class

} // namespace wrapper
} // namespace otb

OTB_APPLICATION_EXPORT( otb::Wrapper::TensorflowModelServe )