otbTensorflowMultisourceModelBase.h 7.51 KB
Newer Older
remi cresson's avatar
remi cresson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowMultisourceModelBase_h
#define otbTensorflowMultisourceModelBase_h

#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "itkImageToImageFilter.h"
#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"

// Tensorflow
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

// Tensorflow helpers
#include "otbTensorflowGraphOperations.h"
#include "otbTensorflowDataTypeBridge.h"
#include "otbTensorflowCopyUtils.h"
#include "otbTensorflowCommon.h"

namespace otb
{

/**
 * \class TensorflowMultisourceModelBase
33
 * \brief This filter is the base class for all TensorFlow model filters.
remi cresson's avatar
remi cresson committed
34
 *
35
36
 * This abstract class implements a number of generic methods that are used in
 * filters that use the TensorFlow engine.
remi cresson's avatar
remi cresson committed
37
 *
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
 * The filter has N input images (Input), each one corresponding to a placeholder
 * that will feed the TensorFlow model. For each input, the name of the
 * placeholder (InputPlaceholders, a std::vector of std::string) and the
 * receptive field (InputReceptiveFields, a std::vector of SizeType) i.e. the
 * input space that the model will "see", must be provided. Hence the number of
 * input images, and the size of InputPlaceholders and InputReceptiveFields must
 * be the same. If not, an exception will be thrown during the method
 * GenerateOutputInformation().
 *
 * The TensorFlow graph and session must be set using the SetGraph() and
 * SetSession() methods.
 *
 * The names of the TensorFlow graph target nodes that must be triggered can be
 * set with SetTargetNodesNames().
 *
53
 * OutputTensors is a std::vector of std::string holding the names of the
 * tensors that will be computed during the session.
 * As for input placeholders, the expression field of each output tensor
 * (OutputExpressionFields, a std::vector of SizeType), i.e. the output
 * space that the TensorFlow model will "generate", must be provided.
 *
 * Finally, a list of scalar placeholders can be fed in the form of a std::vector
 * of std::string, each one expressing the assignment of a single-valued
 * placeholder, e.g. "drop_rate=0.5 learning_rate=0.002 toto=true".
 * See otb::tf::ExpressionToTensor() for details on the syntax.
remi cresson's avatar
remi cresson committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
 *
 * \ingroup OTBTensorflow
 */
template <class TInputImage, class TOutputImage=TInputImage>
class ITK_EXPORT TensorflowMultisourceModelBase :
public itk::ImageToImageFilter<TInputImage, TOutputImage>
{

public:

  /** Standard class typedefs. */
  typedef TensorflowMultisourceModelBase             Self;
  // NOTE(review): the class derives from itk::ImageToImageFilter, yet
  // Superclass names itk::ProcessObject. Left unchanged because code outside
  // this header may depend on what Superclass:: resolves to -- confirm intent.
  typedef itk::ProcessObject                         Superclass;
  typedef itk::SmartPointer<Self>                    Pointer;
  typedef itk::SmartPointer<const Self>              ConstPointer;

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowMultisourceModelBase, itk::ImageToImageFilter);

  /** Images typedefs (all derived from the input image type) */
  typedef TInputImage                                ImageType;
  typedef typename TInputImage::Pointer              ImagePointerType;
  typedef typename TInputImage::PixelType            PixelType;
  typedef typename TInputImage::InternalPixelType    InternalPixelType;
  typedef typename TInputImage::IndexType            IndexType;
  typedef typename TInputImage::IndexValueType       IndexValueType;
  typedef typename TInputImage::PointType            PointType;
  typedef typename TInputImage::SizeType             SizeType;
  typedef typename TInputImage::SizeValueType        SizeValueType;
  typedef typename TInputImage::SpacingType          SpacingType;
  typedef typename TInputImage::RegionType           RegionType;

  /** Typedefs for parameters */
  typedef std::pair<std::string, tensorflow::Tensor> DictElementType;      // (name, tensor) pair
  typedef std::vector<std::string>                   StringList;
  typedef std::vector<SizeType>                      SizeListType;
  typedef std::vector<DictElementType>               DictType;             // list of (name, tensor) pairs
  typedef std::vector<tensorflow::DataType>          DataTypeListType;
  typedef std::vector<tensorflow::TensorShapeProto>  TensorShapeProtoList;
  typedef std::vector<tensorflow::Tensor>            TensorListType;

  /**
   * Set and Get the Tensorflow session and graph.
   *
   * SetGraph() receives the graph by value and moves it into the member, so a
   * call costs a single GraphDef copy instead of two. GetGraph() returns a copy.
   *
   * The session is held as a raw, non-owning pointer: nothing in this header
   * releases it (the destructor body is empty), so the caller must keep it
   * alive while the filter uses it. GetSession() returns a null pointer until
   * SetSession() has been called.
   *
   * NOTE(review): unlike the itkSetMacro-generated setters below, these
   * hand-written setters do not call Modified(). Confirm whether changing the
   * graph or the session is expected to mark the filter as modified.
   */
  void SetGraph(tensorflow::GraphDef graph)      { m_Graph = std::move(graph); }
  tensorflow::GraphDef GetGraph() const          { return m_Graph;             }
  void SetSession(tensorflow::Session * session) { m_Session = session;        }
  tensorflow::Session * GetSession() const       { return m_Session;           }

  /**
   * Model parameters.
   *
   * Both methods are implemented outside this header. Judging from their
   * signatures they register one input (placeholder name, receptive field and
   * image) or one output (tensor name and expression field) at a time --
   * TODO confirm against the implementation.
   * NOTE(review): "Ouput" is misspelled in PushBackOuputTensorBundle; the name
   * is kept as is because it is part of the public interface.
   */
  void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
  void PushBackOuputTensorBundle(std::string name, SizeType expressionField);

  /** Input placeholders names */
  itkSetMacro(InputPlaceholders, StringList);
  itkGetMacro(InputPlaceholders, StringList);

  /** Receptive field */
  itkSetMacro(InputReceptiveFields, SizeListType);
  itkGetMacro(InputReceptiveFields, SizeListType);

  /** Output tensors names */
  itkSetMacro(OutputTensors, StringList);
  itkGetMacro(OutputTensors, StringList);

  /** Expression field */
  itkSetMacro(OutputExpressionFields, SizeListType);
  itkGetMacro(OutputExpressionFields, SizeListType);

  /**
   * User placeholders, as a list of (name, tensor) pairs.
   * The setter moves its by-value argument into the member; the getter
   * returns a copy. As with SetGraph(), Modified() is not called.
   */
  void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = std::move(dict); }
  DictType GetUserPlaceholders() const    { return m_UserPlaceholders;            }

  /** Target nodes names */
  itkSetMacro(TargetNodesNames, StringList);
  itkGetMacro(TargetNodesNames, StringList);

  /** Read only methods (no setter is exposed for these members) */
  itkGetMacro(InputTensorsDataTypes, DataTypeListType);
  itkGetMacro(OutputTensorsDataTypes, DataTypeListType);
  itkGetMacro(InputTensorsShapes, TensorShapeProtoList);
  itkGetMacro(OutputTensorsShapes, TensorShapeProtoList);

  /**
   * Implemented outside this header. According to the class documentation,
   * an exception is thrown here when the number of input images, the number
   * of input placeholders and the number of receptive fields differ.
   */
  virtual void GenerateOutputInformation();

protected:
  TensorflowMultisourceModelBase();
  virtual ~TensorflowMultisourceModelBase() {}

  /** Build a report from the given inputs (implemented outside this header). */
  virtual std::stringstream GenerateDebugReport(DictType & inputs);

  /**
   * Run the session with the given inputs, filling the outputs list
   * (implemented outside this header).
   */
  virtual void RunSession(DictType & inputs, TensorListType & outputs);

private:
  TensorflowMultisourceModelBase(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  // Tensorflow graph and session
  tensorflow::GraphDef       m_Graph;                   // The TensorFlow graph
  // Non-owning; defaults to null so that it is never read uninitialized,
  // whatever the out-of-line constructor does.
  tensorflow::Session *      m_Session = nullptr;       // The TensorFlow session

  // Model parameters
  StringList                 m_InputPlaceholders;       // Input placeholders names
  SizeListType               m_InputReceptiveFields;    // Input receptive fields
  StringList                 m_OutputTensors;           // Output tensors names
  SizeListType               m_OutputExpressionFields;  // Output expression fields
  DictType                   m_UserPlaceholders;        // User placeholders
  StringList                 m_TargetNodesNames;        // User nodes target

  // Internal, read-only
  DataTypeListType           m_InputTensorsDataTypes;   // Input tensors datatype
  DataTypeListType           m_OutputTensorsDataTypes;  // Output tensors datatype
  TensorShapeProtoList       m_InputTensorsShapes;      // Input tensors shapes
  TensorShapeProtoList       m_OutputTensorsShapes;     // Output tensors shapes

}; // end class


} // end namespace otb

#include "otbTensorflowMultisourceModelBase.hxx"

#endif