/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowMultisourceModelBase_h
#define otbTensorflowMultisourceModelBase_h

#include "itkProcessObject.h"
#include "itkImageToImageFilter.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"

// Standard library
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Tensorflow
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

// Tensorflow helpers
#include "otbTensorflowGraphOperations.h"
#include "otbTensorflowDataTypeBridge.h"
#include "otbTensorflowCopyUtils.h"
#include "otbTensorflowCommon.h"

namespace otb
{

/**
 * \class TensorflowMultisourceModelBase
 * \brief This filter is the base for TensorFlow models over multiple input images.
 *
 * The filter takes N input images and feeds them to the TensorFlow model.
 * Names of input placeholders must be specified using the
 * SetInputPlaceholders method (or PushBackInputTensorBundle).
 *
 * TODO:
 *   Replace FOV (Field Of View) --> RF (Receptive Field)
 *   Replace FEO (Field Of Expression) --> EF (Expression Field)
 *
 * \ingroup OTBTensorflow
 */
template <class TInputImage, class TOutputImage=TInputImage>
class ITK_EXPORT TensorflowMultisourceModelBase :
public itk::ImageToImageFilter<TInputImage, TOutputImage>
{

public:

  /** Standard class typedefs.
   * Superclass names the actual base class, so that Superclass:: calls
   * (e.g. Superclass::GenerateOutputInformation()) go through the real
   * inheritance chain instead of jumping straight to itk::ProcessObject. */
  typedef TensorflowMultisourceModelBase                     Self;
  typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
  typedef itk::SmartPointer<Self>                            Pointer;
  typedef itk::SmartPointer<const Self>                      ConstPointer;

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowMultisourceModelBase, itk::ImageToImageFilter);

  /** Images typedefs */
  typedef TInputImage                                ImageType;
  typedef typename TInputImage::Pointer              ImagePointerType;
  typedef typename TInputImage::PixelType            PixelType;
  typedef typename TInputImage::InternalPixelType    InternalPixelType;
  typedef typename TInputImage::IndexType            IndexType;
  typedef typename TInputImage::IndexValueType       IndexValueType;
  typedef typename TInputImage::PointType            PointType;
  typedef typename TInputImage::SizeType             SizeType;
  typedef typename TInputImage::SizeValueType        SizeValueType;
  typedef typename TInputImage::SpacingType          SpacingType;
  typedef typename TInputImage::RegionType           RegionType;

  /** Typedefs for parameters */
  typedef std::pair<std::string, tensorflow::Tensor> DictElementType;       // (name, tensor) pair
  typedef std::vector<std::string>                   StringList;
  typedef std::vector<SizeType>                      SizeListType;
  typedef std::vector<DictElementType>               DictType;              // ordered list of (name, tensor) pairs
  typedef std::vector<tensorflow::DataType>          DataTypeListType;
  typedef std::vector<tensorflow::TensorShapeProto>  TensorShapeProtoList;
  typedef std::vector<tensorflow::Tensor>            TensorListType;

  /** Set and Get the TensorFlow graph.
   * The graph is stored by value: SetGraph copies it, GetGraph returns a copy. */
  void SetGraph(const tensorflow::GraphDef & graph) { m_Graph = graph;     }
  tensorflow::GraphDef GetGraph() const             { return m_Graph;      }

  /** Set and Get the TensorFlow session.
   * Only the raw pointer is stored; the destructor of this class does not
   * delete it, so the caller keeps ownership and must keep it alive while
   * the filter uses it. */
  void SetSession(tensorflow::Session * session)    { m_Session = session; }
  tensorflow::Session * GetSession() const          { return m_Session;    }

  /** Model parameters (defined out of line, not in this header).
   * NOTE: the "Ouput" spelling is kept as-is so existing callers and the
   * out-of-line definition keep matching. */
  void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
  void PushBackOuputTensorBundle(std::string name, SizeType expressionField);

  /** Input placeholders names */
  itkSetMacro(InputPlaceholders, StringList);
  itkGetMacro(InputPlaceholders, StringList);

  /** Receptive field (one size per input placeholder) */
  itkSetMacro(InputReceptiveFields, SizeListType);
  itkGetMacro(InputReceptiveFields, SizeListType);

  /** Output tensors names */
  itkSetMacro(OutputTensors, StringList);
  itkGetMacro(OutputTensors, StringList);

  /** Expression field (one size per output tensor) */
  itkSetMacro(OutputExpressionFields, SizeListType);
  itkGetMacro(OutputExpressionFields, SizeListType);

  /** User placeholders: extra (name, tensor) pairs supplied by the user.
   * Stored and returned by value (copies). */
  void SetUserPlaceholders(const DictType & dict) { m_UserPlaceholders = dict; }
  DictType GetUserPlaceholders() const            { return m_UserPlaceholders; }

  /** Target nodes names */
  itkSetMacro(TargetNodesNames, StringList);
  itkGetMacro(TargetNodesNames, StringList);

  /** Read only methods (no setters are exposed for these members) */
  itkGetMacro(InputTensorsDataTypes, DataTypeListType);
  itkGetMacro(OutputTensorsDataTypes, DataTypeListType);
  itkGetMacro(InputTensorsShapes, TensorShapeProtoList);
  itkGetMacro(OutputTensorsShapes, TensorShapeProtoList);

  virtual void GenerateOutputInformation();

protected:
  // NOTE(review): m_Session is a raw pointer with no in-class initializer;
  // confirm the out-of-line constructor initializes it (e.g. to NULL).
  TensorflowMultisourceModelBase();
  virtual ~TensorflowMultisourceModelBase() {}

  /** Build a debug report for the given (name, tensor) inputs.
   * Returning std::stringstream by value requires movable streams (C++11). */
  virtual std::stringstream GenerateDebugReport(DictType & inputs);

  /** Session execution hook: takes the (name, tensor) inputs and fills
   * `outputs` (implementation not shown in this header). */
  virtual void RunSession(DictType & inputs, TensorListType & outputs);

private:
  TensorflowMultisourceModelBase(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  // Tensorflow graph and session
  tensorflow::GraphDef       m_Graph;                   // The tensorflow graph
  tensorflow::Session *      m_Session;                 // The tensorflow session (not owned)

  // Model parameters
  StringList                 m_InputPlaceholders;       // Input placeholders names
  SizeListType               m_InputReceptiveFields;    // Input receptive fields
  StringList                 m_OutputTensors;           // Output tensors names
  SizeListType               m_OutputExpressionFields;  // Output expression fields
  DictType                   m_UserPlaceholders;        // User placeholders
  StringList                 m_TargetNodesNames;        // User target tensors

  // Read-only
  DataTypeListType           m_InputTensorsDataTypes;   // Input tensors datatype
  DataTypeListType           m_OutputTensorsDataTypes;  // Output tensors datatype
  TensorShapeProtoList       m_InputTensorsShapes;      // Input tensors shapes
  TensorShapeProtoList       m_OutputTensorsShapes;     // Output tensors shapes

}; // end class


} // end namespace otb

#include "otbTensorflowMultisourceModelBase.hxx"

#endif