otbTensorflowMultisourceModelLearningBase.h 4.27 KB
Newer Older
remi cresson's avatar
remi cresson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_h
#define otbTensorflowMultisourceModelLearningBase_h

#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"

// Base
#include "otbTensorflowMultisourceModelBase.h"

namespace otb
{

/**
 * \class TensorflowMultisourceModelLearningBase
 * \brief This filter is the base class for all learning filters.
 *
 * The batch size can be set using the SetBatchSize() method.
 * Streaming can be activated to allow the processing of huge datasets.
 * However, it should be noted that the process is significantly slower due to
 * multiple reads of input patches. When streaming is deactivated, the whole
 * patch images are read and kept in memory, guaranteeing fast patch access.
 *
 * GenerateData() implements a loop over batches that calls the ProcessBatch()
 * method for each one.
 * ProcessBatch() is a pure virtual method that must be implemented in
 * child classes.
 *
 * The PopulateInputTensors() method converts input patch images into placeholders
 * that will be fed to the model. It is a method common to learning filters, and
 * is intended to be used in child classes as a kind of helper.
 *
 * \ingroup OTBTensorflow
 */
template <class TInputImage>
class ITK_EXPORT TensorflowMultisourceModelLearningBase :
public TensorflowMultisourceModelBase<TInputImage>
{
public:

  /** Standard class typedefs. */
  typedef TensorflowMultisourceModelLearningBase       Self;
  typedef TensorflowMultisourceModelBase<TInputImage>  Superclass;
  typedef itk::SmartPointer<Self>                      Pointer;
  typedef itk::SmartPointer<const Self>                ConstPointer;

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase);

  /** Images typedefs */
  typedef typename Superclass::ImageType         ImageType;
  typedef typename Superclass::ImagePointerType  ImagePointerType;
  typedef typename Superclass::RegionType        RegionType;
  typedef typename Superclass::SizeType          SizeType;
  typedef typename Superclass::IndexType         IndexType;

  /* Typedefs for parameters */
  typedef typename Superclass::DictType          DictType;
  typedef typename Superclass::DictElementType   DictElementType;
  typedef typename Superclass::StringList        StringList;
  typedef typename Superclass::SizeListType      SizeListType;
  typedef typename Superclass::TensorListType    TensorListType;

  /* Typedefs for index */
  typedef typename ImageType::IndexValueType     IndexValueType;
  typedef std::vector<IndexValueType>            IndexListType;

  /** Batch size (number of samples processed per call to ProcessBatch()).
   * Stored with the same type as the setter/getter so that no implicit
   * narrowing or sign conversion happens in SetBatchSize(). */
  itkSetMacro(BatchSize, IndexValueType);
  itkGetMacro(BatchSize, IndexValueType);

  /** Streaming on/off. When off, whole patch images are kept in memory. */
  itkSetMacro(UseStreaming, bool);
  itkGetMacro(UseStreaming, bool);

  /** Number of samples (read-only: there is no public setter). */
  itkGetMacro(NumberOfSamples, IndexValueType);

protected:
  TensorflowMultisourceModelLearningBase();
  virtual ~TensorflowMultisourceModelLearningBase() {}

  virtual void GenerateOutputInformation(void);

  virtual void GenerateInputRequestedRegion();

  /** Loops over batches and calls ProcessBatch() for each one. */
  virtual void GenerateData();

  /** Helper for child classes: fills \p inputs with the placeholders built
   * from the input patch images.
   * NOTE(review): sampleStart/batchSize/order presumably select which samples
   * form the batch and in which order — confirm against the .hxx. */
  virtual void PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize, const IndexListType & order);

  /** Process one batch; must be implemented by child classes. */
  virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize) = 0;

private:
  TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  IndexValueType        m_BatchSize;       // Batch size (same type as Set/GetBatchSize)
  bool                  m_UseStreaming;    // Use streaming on/off

  // Read only
  IndexValueType        m_NumberOfSamples; // Number of samples

}; // end class


} // end namespace otb

#include "otbTensorflowMultisourceModelLearningBase.hxx"

#endif