Commit 1fd2ba4d authored by Cédric Traizet's avatar Cédric Traizet
Browse files

added image dr filter

No related merge requests found
Showing with 500 additions and 25 deletions
+500 -25
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#include "otbStandardWriterWatcher.h" #include "otbStandardWriterWatcher.h"
#include "otbStatisticsXMLFileReader.h" #include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleVectorImageFilter.h" #include "otbShiftScaleVectorImageFilter.h"
#include "otbImageClassificationFilter.h" #include "ImageDimensionalityReductionFilter.h"
#include "otbMultiToMonoChannelExtractROI.h" #include "otbMultiToMonoChannelExtractROI.h"
#include "otbImageToVectorImageCastFilter.h" #include "otbImageToVectorImageCastFilter.h"
#include "otbMachineLearningModelFactory.h" #include "DimensionalityReductionModelFactory.h"
namespace otb namespace otb
{ {
...@@ -94,19 +94,19 @@ public: ...@@ -94,19 +94,19 @@ public:
FloatImageType, FloatImageType,
FloatImageType, FloatImageType,
otb::Functor::AffineFunctor<float,float> > OutputRescalerType; otb::Functor::AffineFunctor<float,float> > OutputRescalerType;
typedef otb::ImageClassificationFilter<FloatVectorImageType, FloatImageType, MaskImageType> ClassificationFilterType; typedef otb::ImageDimensionalityReductionFilter<FloatVectorImageType, FloatVectorImageType, MaskImageType> DimensionalityReductionFilterType;
typedef ClassificationFilterType::Pointer ClassificationFilterPointerType; typedef DimensionalityReductionFilterType::Pointer DimensionalityReductionFilterPointerType;
typedef ClassificationFilterType::ModelType ModelType; typedef DimensionalityReductionFilterType::ModelType ModelType;
typedef ModelType::Pointer ModelPointerType; typedef ModelType::Pointer ModelPointerType;
typedef ClassificationFilterType::ValueType ValueType; typedef DimensionalityReductionFilterType::ValueType ValueType;
typedef ClassificationFilterType::LabelType LabelType; typedef DimensionalityReductionFilterType::LabelType LabelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; typedef otb::DimensionalityReductionModelFactory<ValueType, LabelType> DimensionalityReductionModelFactoryType;
protected: protected:
~CbDimensionalityReduction() ITK_OVERRIDE ~CbDimensionalityReduction() ITK_OVERRIDE
{ {
MachineLearningModelFactoryType::CleanFactories(); DimensionalityReductionModelFactoryType::CleanFactories();
} }
private: private:
...@@ -191,8 +191,8 @@ private: ...@@ -191,8 +191,8 @@ private:
// Load svm model // Load svm model
otbAppLogINFO("Loading model"); otbAppLogINFO("Loading model");
m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), m_Model = DimensionalityReductionModelFactoryType::CreateDimensionalityReductionModel(GetParameterString("model"),
MachineLearningModelFactoryType::ReadMode); DimensionalityReductionModelFactoryType::ReadMode);
if (m_Model.IsNull()) if (m_Model.IsNull())
{ {
...@@ -204,10 +204,10 @@ private: ...@@ -204,10 +204,10 @@ private:
otbAppLogINFO("Model loaded"); otbAppLogINFO("Model loaded");
// Classify // Classify
m_ClassificationFilter = ClassificationFilterType::New(); m_ClassificationFilter = DimensionalityReductionFilterType::New();
m_ClassificationFilter->SetModel(m_Model); m_ClassificationFilter->SetModel(m_Model);
FloatImageType::Pointer outputImage = m_ClassificationFilter->GetOutput(); FloatVectorImageType::Pointer outputImage = m_ClassificationFilter->GetOutput();
// Normalize input image if asked // Normalize input image if asked
if(IsParameterEnabled("imstat") ) if(IsParameterEnabled("imstat") )
...@@ -224,7 +224,7 @@ private: ...@@ -224,7 +224,7 @@ private:
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
otbAppLogINFO( "mean used: " << meanMeasurementVector ); otbAppLogINFO( "mean used: " << meanMeasurementVector );
otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector ); otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector );
if (meanMeasurementVector.Size() == nbFeatures + 1) /*if (meanMeasurementVector.Size() == nbFeatures + 1)
{ {
double outMean = meanMeasurementVector[nbFeatures]; double outMean = meanMeasurementVector[nbFeatures];
double outStdDev = stddevMeasurementVector[nbFeatures]; double outStdDev = stddevMeasurementVector[nbFeatures];
...@@ -236,7 +236,7 @@ private: ...@@ -236,7 +236,7 @@ private:
m_OutRescaler->GetFunctor().SetB(outMean); m_OutRescaler->GetFunctor().SetB(outMean);
outputImage = m_OutRescaler->GetOutput(); outputImage = m_OutRescaler->GetOutput();
} }
else if (meanMeasurementVector.Size() != nbFeatures) else*/ if (meanMeasurementVector.Size() != nbFeatures)
{ {
otbAppLogFATAL("Wrong number of components in statistics file : "<<meanMeasurementVector.Size()); otbAppLogFATAL("Wrong number of components in statistics file : "<<meanMeasurementVector.Size());
} }
...@@ -264,11 +264,11 @@ private: ...@@ -264,11 +264,11 @@ private:
m_ClassificationFilter->SetInputMask(inMask); m_ClassificationFilter->SetInputMask(inMask);
} }
SetParameterOutputImage<FloatImageType>("out", outputImage); SetParameterOutputImage<FloatVectorImageType>("out", outputImage);
} }
ClassificationFilterType::Pointer m_ClassificationFilter; DimensionalityReductionFilterType::Pointer m_ClassificationFilter;
ModelPointerType m_Model; ModelPointerType m_Model;
RescalerType::Pointer m_Rescaler; RescalerType::Pointer m_Rescaler;
OutputRescalerType::Pointer m_OutRescaler; OutputRescalerType::Pointer m_OutRescaler;
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include <fstream> // write the model file #include <fstream> // write the model file
#include "otbMachineLearningModelFactory.h" #include "DimensionalityReductionModelFactory.h"
#include "cbLearningApplicationBaseDR.h" #include "cbLearningApplicationBaseDR.h"
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType; typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
typedef otb::MachineLearningModelFactory<ValueType, ValueType> ModelFactoryType; typedef otb::DimensionalityReductionModelFactory<ValueType, ValueType> ModelFactoryType;
private: private:
void DoInit() void DoInit()
......
...@@ -111,6 +111,7 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Load(const std::string & fil ...@@ -111,6 +111,7 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Load(const std::string & fil
m_net.read(ia); m_net.read(ia);
ifs.close(); ifs.close();
m_NumberOfHiddenNeurons = m_net.numberOfHiddenNeurons(); m_NumberOfHiddenNeurons = m_net.numberOfHiddenNeurons();
//this->m_Size = m_NumberOfHiddenNeurons;
} }
...@@ -154,7 +155,7 @@ void AutoencoderModel<TInputValue,AutoencoderType> ...@@ -154,7 +155,7 @@ void AutoencoderModel<TInputValue,AutoencoderType>
//target[a]=p[a]; //target[a]=p[a];
target=p[a]; target=p[a];
} }
//std::cout << p << std::endl; std::cout << p << std::endl;
targets->SetMeasurementVector(id,target); targets->SetMeasurementVector(id,target);
++id; ++id;
} }
......
...@@ -201,7 +201,13 @@ protected: ...@@ -201,7 +201,13 @@ protected:
/** Is DoPredictBatch multi-threaded ? */ /** Is DoPredictBatch multi-threaded ? */
bool m_IsDoPredictBatchMultiThreaded; bool m_IsDoPredictBatchMultiThreaded;
/** Size of the output after dimensionality reduction */
//unsigned int m_Size;
private: private:
/** Actual implementation of BatchPredicition /** Actual implementation of BatchPredicition
* Default implementation will call DoPredict iteratively * Default implementation will call DoPredict iteratively
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef ImageDimensionalityReduction_h
#define ImageDimensionalityReduction_h
#include "itkImageToImageFilter.h"
#include "DimensionalityReductionModel.h"
#include "otbImage.h"
namespace otb
{
/** \class ImageClassificationFilter
* \brief This filter performs the classification of a VectorImage using a Model.
*
* This filter is streamed and threaded, allowing to classify huge images
* while fully using several core.
*
* \sa Classifier
* \ingroup Streamed
* \ingroup Threaded
*
* \ingroup OTBSupervised
*/
template <class TInputImage, class TOutputImage, class TMaskImage = TOutputImage>
class ITK_EXPORT ImageDimensionalityReductionFilter
: public itk::ImageToImageFilter<TInputImage, TOutputImage>
{
public:
/** Standard typedefs */
typedef ImageDimensionalityReductionFilter Self;
typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Type macro */
itkNewMacro(Self);
/** Creation through object factory macro */
itkTypeMacro(ImageDimensionalityReductionFilter, ImageToImageFilter);
typedef TInputImage InputImageType;
typedef typename InputImageType::ConstPointer InputImageConstPointerType;
typedef typename InputImageType::InternalPixelType ValueType;
typedef TMaskImage MaskImageType;
typedef typename MaskImageType::ConstPointer MaskImageConstPointerType;
typedef typename MaskImageType::Pointer MaskImagePointerType;
typedef TOutputImage OutputImageType;
typedef typename OutputImageType::Pointer OutputImagePointerType;
typedef typename OutputImageType::RegionType OutputImageRegionType;
typedef typename OutputImageType::InternalPixelType LabelType;
typedef DimensionalityReductionModel<ValueType, LabelType> ModelType;
typedef typename ModelType::Pointer ModelPointerType;
typedef otb::Image<double> ConfidenceImageType;
typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType;
/** Set/Get the model */
itkSetObjectMacro(Model, ModelType);
itkGetObjectMacro(Model, ModelType);
/** Set/Get the default label */
itkSetMacro(DefaultLabel, LabelType);
itkGetMacro(DefaultLabel, LabelType);
/** Set/Get the confidence map flag */
itkSetMacro(UseConfidenceMap, bool);
itkGetMacro(UseConfidenceMap, bool);
itkSetMacro(BatchMode, bool);
itkGetMacro(BatchMode, bool);
itkBooleanMacro(BatchMode);
/**
* If set, only pixels within the mask will be classified.
* All pixels with a value greater than 0 in the mask, will be classified.
* \param mask The input mask.
*/
void SetInputMask(const MaskImageType * mask);
/**
* Get the input mask.
* \return The mask.
*/
const MaskImageType * GetInputMask(void);
/**
* Get the output confidence map
*/
ConfidenceImageType * GetOutputConfidence(void);
protected:
/** Constructor */
ImageDimensionalityReductionFilter();
/** Destructor */
~ImageDimensionalityReductionFilter() ITK_OVERRIDE {}
/** Threaded generate data */
void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId) ITK_OVERRIDE;
void ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId);
void BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId);
/** Before threaded generate data */
void BeforeThreadedGenerateData() ITK_OVERRIDE;
/**PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const ITK_OVERRIDE;
private:
ImageDimensionalityReductionFilter(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
/** The model used for classification */
ModelPointerType m_Model;
/** Default label for invalid pixels (when using a mask) */
LabelType m_DefaultLabel;
/** Flag to produce the confidence map (if the model supports it) */
bool m_UseConfidenceMap;
bool m_BatchMode;
};
} // End namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "ImageDimensionalityReductionFilter.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbImageClassificationFilter_txx
#define otbImageClassificationFilter_txx
#include "ImageDimensionalityReductionFilter.h"
#include "itkImageRegionIterator.h"
#include "itkProgressReporter.h"
namespace otb
{
/**
* Constructor
*/
template <class TInputImage, class TOutputImage, class TMaskImage>
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::ImageDimensionalityReductionFilter()
{
this->SetNumberOfIndexedInputs(2);
this->SetNumberOfRequiredInputs(1);
LabelType empty_vect;
empty_vect.SetSize(1);
m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue(empty_vect);
this->SetNumberOfRequiredOutputs(2);
this->SetNthOutput(0,TOutputImage::New());
this->SetNthOutput(1,ConfidenceImageType::New());
m_UseConfidenceMap = false;
m_BatchMode = true;
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::SetInputMask(const MaskImageType * mask)
{
this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType *>(mask));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
const typename ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::MaskImageType *
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::GetInputMask()
{
if (this->GetNumberOfInputs() < 2)
{
return ITK_NULLPTR;
}
return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
typename ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::ConfidenceImageType *
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::GetOutputConfidence()
{
if (this->GetNumberOfOutputs() < 2)
{
return ITK_NULLPTR;
}
return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::BeforeThreadedGenerateData()
{
if (!m_Model)
{
itkGenericExceptionMacro(<< "No model for classification");
}
if(m_BatchMode)
{
#ifdef _OPENMP
// OpenMP will take care of threading
this->SetNumberOfThreads(1);
#endif
}
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{
// Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
// Define iterators
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<InputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
// Eventually iterate on masks
MaskIteratorType maskIt;
if (inputMaskPtr)
{
maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
maskIt.GoToBegin();
}
// setup iterator for confidence map
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
ConfidenceMapIteratorType confidenceIt;
if (computeConfidenceMap)
{
confidenceIt = ConfidenceMapIteratorType(confidencePtr,outputRegionForThread);
confidenceIt.GoToBegin();
}
bool validPoint = true;
double confidenceIndex = 0.0;
// Walk the part of the image
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
{
// Check pixel validity
if (inputMaskPtr)
{
validPoint = maskIt.Get() > 0;
++maskIt;
}
// If point is valid
if (validPoint)
{
// Classifify
if (computeConfidenceMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]);
}
else
{
outIt.Set(m_Model->Predict(inIt.Get())[0]);
}
}
else
{
// else, set default value
outIt.Set(m_DefaultLabel);
confidenceIndex = 0.0;
}
if (computeConfidenceMap)
{
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
progress.CompletedPixel();
}
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex()
&& !m_Model->GetRegressionMode());
// Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
// Define iterators
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
MaskIteratorType maskIt;
if (inputMaskPtr)
{
maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
maskIt.GoToBegin();
}
// typedef typename ModelType::InputValueType InputValueType;
typedef typename ModelType::InputSampleType InputSampleType;
typedef typename ModelType::InputListSampleType InputListSampleType;
typedef typename ModelType::TargetValueType TargetValueType;
// typedef typename ModelType::TargetSampleType TargetSampleType;
typedef typename ModelType::TargetListSampleType TargetListSampleType;
// typedef typename ModelType::ConfidenceValueType ConfidenceValueType;
// typedef typename ModelType::ConfidenceSampleType ConfidenceSampleType;
typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
typename InputListSampleType::Pointer samples = InputListSampleType::New();
unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
samples->SetMeasurementVectorSize(num_features);
InputSampleType sample(num_features);
// Fill the samples
bool validPoint = true;
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
{
// Check pixel validity
if (inputMaskPtr)
{
validPoint = maskIt.Get() > 0;
++maskIt;
}
if(validPoint)
{
typename InputImageType::PixelType pix = inIt.Get();
for(size_t feat=0; feat<num_features; ++feat)
{
sample[feat]=pix[feat];
}
samples->PushBack(sample);
}
}
//Make the batch prediction
typename TargetListSampleType::Pointer labels;
typename ConfidenceListSampleType::Pointer confidences;
if(computeConfidenceMap)
confidences = ConfidenceListSampleType::New();
// This call is threadsafe
labels = m_Model->PredictBatch(samples,confidences);
// Set the output values
ConfidenceMapIteratorType confidenceIt;
if (computeConfidenceMap)
{
confidenceIt = ConfidenceMapIteratorType(confidencePtr,outputRegionForThread);
confidenceIt.GoToBegin();
}
typename TargetListSampleType::ConstIterator labIt = labels->Begin();
maskIt.GoToBegin();
for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
{
double confidenceIndex = 0.0;
TargetValueType labelValue(m_DefaultLabel);
if (inputMaskPtr)
{
validPoint = maskIt.Get() > 0;
++maskIt;
}
if (validPoint && labIt!=labels->End())
{
labelValue = labIt.GetMeasurementVector()[0];
if(computeConfidenceMap)
{
confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
}
++labIt;
}
else
{
labelValue = m_DefaultLabel;
}
outIt.Set(labelValue);
if(computeConfidenceMap)
{
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
progress.CompletedPixel();
}
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{
if(m_BatchMode)
{
this->BatchThreadedGenerateData(outputRegionForThread, threadId);
}
else
{
this->ClassicThreadedGenerateData(outputRegionForThread, threadId);
}
}
/**
* PrintSelf Method
*/
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
Superclass::PrintSelf(os, indent);
}
} // End namespace otb
#endif
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "itkVariableLengthVector.h" #include "itkVariableLengthVector.h"
//Estimator //Estimator
#include "otbMachineLearningModelFactory.h" #include "DimensionalityReductionModelFactory.h"
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
#include "AutoencoderModel.h" #include "AutoencoderModel.h"
...@@ -75,10 +75,10 @@ public: ...@@ -75,10 +75,10 @@ public:
typedef typename SampleImageType::PixelType PixelType; typedef typename SampleImageType::PixelType PixelType;
// Machine Learning models // Machine Learning models
typedef otb::MachineLearningModelFactory< typedef otb::DimensionalityReductionModelFactory<
InputValueType, OutputValueType> ModelFactoryType; InputValueType, OutputValueType> ModelFactoryType;
typedef typename ModelFactoryType::MachineLearningModelTypePointer ModelPointerType; typedef typename ModelFactoryType::DimensionalityReductionModelTypePointer ModelPointerType;
typedef typename ModelFactoryType::MachineLearningModelType ModelType; typedef typename ModelFactoryType::DimensionalityReductionModelType ModelType;
typedef typename ModelType::InputSampleType SampleType; typedef typename ModelType::InputSampleType SampleType;
typedef typename ModelType::InputListSampleType ListSampleType; typedef typename ModelType::InputListSampleType ListSampleType;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment