otbTensorflowSampler.h 5.13 KB
Newer Older
remi cresson's avatar
remi cresson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowSampler_h
#define otbTensorflowSampler_h

#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"

// Extract ROI
#include "otbMultiChannelExtractROI.h"

// TF common
#include "otbTensorflowCommon.h"

Cresson Remi's avatar
Cresson Remi committed
24
25
26
// Tree iterator
#include "itkPreOrderTreeIterator.h"

remi cresson's avatar
remi cresson committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
namespace otb
{

/**
 * \class TensorflowSampler
 * \brief This process object performs sample extraction from input images.
 *
 * Input images are registered with PushBackInputWithPatchSize(), each one
 * together with its own patch size. Fixed-size samples are extracted from
 * the input images; the samples are concatenated in the y dimension to form
 * big images of extracted patches, returned as a list by
 * GetOutputPatchImages() (presumably one patch image per input, in
 * push-back order -- TODO confirm against the implementation).
 * A label image (GetOutputLabelImage()) is also created from the value of
 * the m_Field field (SetField()) of the input vector data
 * (SetInputVectorData()).
 *
 * All the work is done at once in Update() (no streaming, see TODO).
 *
 * TODO:
 * - inherit from itk::ImageToImageFilter
 * - implement a streaming mechanism: the input requested region of the
 *   image should be computed from the output requested region of the
 *   patches. This would allow computing huge patch images and speed up the
 *   whole process. This might be achieved using an indexing structure like
 *   an R-tree on the sample positions (in image coordinates).
 *
 * \ingroup OTBTensorflow
 */
template <class TInputImage, class TVectorData>
class ITK_EXPORT TensorflowSampler :
public itk::ProcessObject
{
public:

  /** Standard class typedefs. */
  typedef TensorflowSampler                       Self;
  typedef itk::ProcessObject                      Superclass;
  typedef itk::SmartPointer<Self>                 Pointer;
  typedef itk::SmartPointer<const Self>           ConstPointer;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowSampler, itk::ProcessObject);

  /** Images typedefs */
  typedef TInputImage                             ImageType;
  typedef typename TInputImage::Pointer           ImagePointerType;
  typedef typename TInputImage::InternalPixelType InternalPixelType;
  typedef typename TInputImage::PixelType         PixelType;
  typedef typename TInputImage::RegionType        RegionType;
  typedef typename TInputImage::PointType         PointType;
  typedef typename TInputImage::SizeType          SizeType;
  typedef typename TInputImage::IndexType         IndexType;
  typedef typename otb::MultiChannelExtractROI<InternalPixelType,
      InternalPixelType>                          ExtractROIMultiFilterType;
  typedef typename ExtractROIMultiFilterType::Pointer
                                                  ExtractROIMultiFilterPointerType;
  typedef typename std::vector<ImagePointerType>  ImagePointerListType;
  typedef typename std::vector<SizeType>          SizeListType;

  /** Vector data typedefs */
  typedef TVectorData                             VectorDataType;
  typedef typename VectorDataType::Pointer        VectorDataPointer;
  typedef typename VectorDataType::DataTreeType   DataTreeType;
  typedef typename itk::PreOrderTreeIterator<DataTreeType>
                                                  TreeIteratorType;
  typedef typename VectorDataType::DataNodeType   DataNodeType;
  typedef typename DataNodeType::Pointer          DataNodePointer;
  typedef typename DataNodeType::PolygonListPointerType
                                                  PolygonListPointerType;

  /** Set / get the name of the vector data field used to build the label
   *  image. The getter is const and returns a const reference (no copy). */
  itkSetMacro(Field, std::string);
  itkGetConstReferenceMacro(Field, std::string);

  /** Set / get vector data */
  itkSetMacro(InputVectorData, VectorDataPointer);
  itkGetConstMacro(InputVectorData, VectorDataPointer);

  /** Add an input image together with the size of the patches to extract
   *  from it.
   *  NOTE(review): patchSize is taken by non-const reference, so temporaries
   *  cannot be passed; consider `const SizeType &` (the out-of-line
   *  definition must be changed together with this declaration). */
  virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize);

  /** Get the input image at the given index (presumably the push-back
   *  order of PushBackInputWithPatchSize() -- confirm in the implementation). */
  const ImageType* GetInput(unsigned int index);

  /** Do the real work */
  virtual void Update();

  /** Get outputs.
   *  The getters are const so they can be used on a const sampler; the
   *  patch image list is returned by const reference to avoid copying the
   *  vector (and bumping every SmartPointer's reference count) on each call. */
  itkGetConstReferenceMacro(OutputPatchImages, ImagePointerListType);
  itkGetConstMacro(OutputLabelImage, ImagePointerType);
  itkGetConstMacro(NumberOfAcceptedSamples, unsigned long);
  itkGetConstMacro(NumberOfRejectedSamples, unsigned long);

protected:
  TensorflowSampler();
  virtual ~TensorflowSampler() {}

  /** Presumably (re)sizes `image` so that it can hold nbSamples patches of
   *  size patchSize -- TODO confirm against the implementation. */
  virtual void ResizeImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples);

  /** Presumably allocates `image` for nbSamples patches of size patchSize
   *  with nbComponents components per pixel -- TODO confirm. */
  virtual void AllocateImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples, unsigned int nbComponents);

private:
  TensorflowSampler(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  // Parameters
  std::string          m_Field;            // vector data field used for the labels
  SizeListType         m_PatchSizes;       // one patch size per pushed input
  VectorDataPointer    m_InputVectorData;

  // Read only (outputs, exposed through the const getters above)
  ImagePointerListType m_OutputPatchImages;
  ImagePointerType     m_OutputLabelImage;
  unsigned long        m_NumberOfAcceptedSamples;
  unsigned long        m_NumberOfRejectedSamples;

}; // end class

} // end namespace otb

#include "otbTensorflowSampler.hxx"

#endif