/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_h
#define otbTensorflowMultisourceModelLearningBase_h

#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"

// Base
#include "otbTensorflowMultisourceModelBase.h"

namespace otb
{

/**
 * \class TensorflowMultisourceModelLearningBase
 * \brief This filter is the base class for all filters that input patches images.
 *
 * One input patches image consists of an image of size (pszx, pszy*n, nbands) where:
 * -pszx   : is the width of one patch
 * -pszy   : is the height of one patch
 * -n      : is the number of patches in the patches image
 * -nbands : is the number of channels in the patches image
 *
 * This filter verifies that all patches images are consistent.
 *
 * The batch size can be set using the SetBatchSize() method.
 * The streaming can be activated to allow the processing of huge datasets.
 * However, it should be noted that the process is significantly slower due to
 * multiple reads of input patches. When streaming is deactivated, the whole
 * patches images are read and kept in memory, guaranteeing fast patches access.
 *
 * The GenerateData() method implements a loop over batches that calls the
 * ProcessBatch() method for each one.
 * The ProcessBatch() function is a pure virtual method that must be implemented in
 * child classes.
 *
 * The PopulateInputTensors() method converts input patches images into placeholders
 * that will be fed to the model. It is a common method to learning filters, and
 * is intended to be used in child classes, as a kind of helper.
 *
 * \ingroup OTBTensorflow
 */
template <class TInputImage>
class ITK_EXPORT TensorflowMultisourceModelLearningBase :
public TensorflowMultisourceModelBase<TInputImage>
{
public:

  /** Standard class typedefs. */
  typedef TensorflowMultisourceModelLearningBase       Self;
  typedef TensorflowMultisourceModelBase<TInputImage>  Superclass;
  typedef itk::SmartPointer<Self>                      Pointer;
  typedef itk::SmartPointer<const Self>                ConstPointer;

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase);

  /** Images typedefs */
  typedef typename Superclass::ImageType         ImageType;
  typedef typename Superclass::ImagePointerType  ImagePointerType;
  typedef typename Superclass::RegionType        RegionType;
  typedef typename Superclass::SizeType          SizeType;
  typedef typename Superclass::IndexType         IndexType;

  /* Typedefs for parameters */
  typedef typename Superclass::DictType          DictType;
  typedef typename Superclass::DictElementType   DictElementType;
  typedef typename Superclass::StringList        StringList;
  typedef typename Superclass::SizeListType      SizeListType;
  typedef typename Superclass::TensorListType    TensorListType;

  /* Typedefs for index */
  typedef typename ImageType::IndexValueType     IndexValueType;
  typedef std::vector<IndexValueType>            IndexListType;

  /** Batch size (number of samples per batch handed to ProcessBatch()). */
  itkSetMacro(BatchSize, IndexValueType);
  itkGetMacro(BatchSize, IndexValueType);

  /** Use streaming (on/off). See the class documentation for the trade-off. */
  itkSetMacro(UseStreaming, bool);
  itkGetMacro(UseStreaming, bool);

  /** Get number of samples (read only). */
  itkGetMacro(NumberOfSamples, IndexValueType);

protected:
  TensorflowMultisourceModelLearningBase();
  virtual ~TensorflowMultisourceModelLearningBase() {}

  virtual void GenerateOutputInformation(void);

  virtual void GenerateInputRequestedRegion();

  /** Loops over batches and calls ProcessBatch() for each one. */
  virtual void GenerateData();

  /**
   * Helper for child classes: converts input patches images into placeholders
   * stored in \a inputs, for the batch of \a batchSize samples starting at
   * \a sampleStart.
   * NOTE(review): \a order presumably maps batch positions to sample indices
   * (e.g. for shuffling) -- confirm against the .hxx implementation.
   */
  virtual void PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize, const IndexListType & order);

  /** Process one batch. Pure virtual: must be implemented in child classes. */
  virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize) = 0;

private:
  TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  // Stored with the same type the Set/GetBatchSize() accessors expose, so the
  // setter does not silently narrow a signed IndexValueType into an unsigned int
  // and batch arithmetic does not mix signed and unsigned operands.
  IndexValueType        m_BatchSize;       // Batch size
  bool                  m_UseStreaming;    // Use streaming on/off

  // Read only
  IndexValueType        m_NumberOfSamples; // Number of samples

}; // end class


} // end namespace otb

#include "otbTensorflowMultisourceModelLearningBase.hxx"

#endif