// NOTE: the lines below are residue from the repository web viewer, kept as comments.
// Failed to fetch fork details. Try again later.
// Cresson Remi authored 22981466
// Forked from Cresson Remi / otbtf
// Source project has a limited visibility.
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_h
#define otbTensorflowMultisourceModelLearningBase_h
#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"
// Base
#include "otbTensorflowMultisourceModelBase.h"
namespace otb
{
/**
* \class TensorflowMultisourceModelLearningBase
* \brief This filter is the base class for all filters that take patches images as input.
*
* One input patches image consists of an image of size (pszx, pszy*n, nbands) where:
* -pszx : is the width of one patch
* -pszy : is the height of one patch
* -n : is the number of patches in the patches image
* -nbands : is the number of channels in the patches image
*
* This filter verifies that all patches images are consistent.
*
* The batch size can be set using the SetBatchSize() method.
* The streaming can be activated to allow the processing of huge datasets.
* However, it should be noted that the process is significantly slower due to
* multiple reads of the input patches. When streaming is deactivated, the whole
* patches images are read and kept in memory, guaranteeing fast access to patches.
*
* GenerateData() implements a loop over batches that calls the ProcessBatch()
* method for each one.
* The ProcessBatch() function is a pure virtual method that must be implemented in
* child classes.
*
* The PopulateInputTensors() method converts input patches images into placeholders
* that will be fed to the model. It is common to all learning filters and is
* intended to be used in child classes as a helper.
*
* \ingroup OTBTensorflow
*/
template <class TInputImage>
class ITK_EXPORT TensorflowMultisourceModelLearningBase :
  public TensorflowMultisourceModelBase<TInputImage>
{
public:

  /** Standard class typedefs. */
  typedef TensorflowMultisourceModelLearningBase       Self;
  typedef TensorflowMultisourceModelBase<TInputImage>  Superclass;
  typedef itk::SmartPointer<Self>                      Pointer;
  typedef itk::SmartPointer<const Self>                ConstPointer;

  /** Run-time type information (and related methods). */
  itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase);

  /** Images typedefs (re-exported from the superclass). */
  typedef typename Superclass::ImageType         ImageType;
  typedef typename Superclass::ImagePointerType  ImagePointerType;
  typedef typename Superclass::RegionType        RegionType;
  typedef typename Superclass::SizeType          SizeType;
  typedef typename Superclass::IndexType         IndexType;

  /** Typedefs for parameters (re-exported from the superclass). */
  typedef typename Superclass::DictType          DictType;
  typedef typename Superclass::DictElementType   DictElementType;
  typedef typename Superclass::StringList        StringList;
  typedef typename Superclass::SizeListType      SizeListType;
  typedef typename Superclass::TensorListType    TensorListType;

  /** Typedefs for sample indexing. */
  typedef typename ImageType::IndexValueType     IndexValueType;
  typedef std::vector<IndexValueType>            IndexListType;

  /** Set/Get the batch size (number of samples handed to one ProcessBatch() call). */
  itkSetMacro(BatchSize, IndexValueType);
  itkGetMacro(BatchSize, IndexValueType);

  /** Set/Get the streaming mode (see the class documentation for the trade-off). */
  itkSetMacro(UseStreaming, bool);
  itkGetMacro(UseStreaming, bool);

  /** Get the number of samples. Read only: no setter is exposed. */
  itkGetMacro(NumberOfSamples, IndexValueType);

protected:
  TensorflowMultisourceModelLearningBase();
  virtual ~TensorflowMultisourceModelLearningBase() {}

  /** Pipeline methods; their definitions live in the companion .hxx file. */
  virtual void GenerateOutputInformation(void);
  virtual void GenerateInputRequestedRegion();

  /** Loops over the batches and calls ProcessBatch() for each one. */
  virtual void GenerateData();

  /**
   * Helper for child classes: converts the input patches images into the
   * placeholders that will be fed to the model.
   *
   * \param inputs      dictionary receiving the placeholders
   * \param sampleStart index of the first sample of the batch
   * \param batchSize   number of samples in the batch
   * \param order       list of sample indices
   *                    (NOTE(review): presumably the order in which samples
   *                    are read; confirm against the .hxx implementation)
   */
  virtual void PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize, const IndexListType & order);

  /**
   * Processes one batch. Pure virtual: must be implemented in child classes.
   *
   * \param inputs      dictionary of input placeholders for the batch
   * \param sampleStart index of the first sample of the batch
   * \param batchSize   number of samples in the batch
   */
  virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
      const IndexValueType & batchSize) = 0;

private:
  TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  // Stored as IndexValueType so that the member matches the type used by the
  // Set/GetBatchSize() accessors and by the batchSize arguments above; this
  // avoids signed/unsigned comparisons and implicit conversions.
  IndexValueType        m_BatchSize;       // Batch size
  bool                  m_UseStreaming;    // Use streaming on/off

  // Read only
  IndexValueType        m_NumberOfSamples; // Number of samples

}; // end class
} // end namespace otb
#include "otbTensorflowMultisourceModelLearningBase.hxx"
#endif