Commit d5ed9ebb authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

STYLE: use OTB coding style

No related merge requests found
Showing with 41 additions and 104 deletions
+41 -104
...@@ -20,15 +20,12 @@ ...@@ -20,15 +20,12 @@
#ifndef otbAutoencoderModelFactory_h #ifndef otbAutoencoderModelFactory_h
#define otbAutoencoderModelFactory_h #define otbAutoencoderModelFactory_h
//#include <shark/Models/TiedAutoencoder.h>
//#include <shark/Models/Autoencoder.h>
#include "itkObjectFactoryBase.h" #include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h" #include "itkImageIOBase.h"
namespace otb namespace otb
{ {
template <class TInputValue, class TTargetValue, class NeuronType> template <class TInputValue, class TTargetValue, class NeuronType>
class ITK_EXPORT AutoencoderModelFactory : public itk::ObjectFactoryBase class ITK_EXPORT AutoencoderModelFactory : public itk::ObjectFactoryBase
{ {
...@@ -63,27 +60,12 @@ protected: ...@@ -63,27 +60,12 @@ protected:
private: private:
AutoencoderModelFactory(const Self &); //purposely not implemented AutoencoderModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented void operator =(const Self&); //purposely not implemented
}; };
/*
template <class TInputValue, class TTargetValue>
using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> ;
template <class TInputValue, class TTargetValue>
using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> ;
*/
} //namespace otb } //namespace otb
#ifndef OTB_MANUAL_INSTANTIATION #ifndef OTB_MANUAL_INSTANTIATION
#include "otbAutoencoderModelFactory.txx" #include "otbAutoencoderModelFactory.txx"
#endif #endif
#endif #endif
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#ifndef otbAutoencoderModelFactory_txx #ifndef otbAutoencoderModelFactory_txx
#define otbAutoencoderModelFactory_txx #define otbAutoencoderModelFactory_txx
#include "otbAutoencoderModelFactory.h" #include "otbAutoencoderModelFactory.h"
#include "otbAutoencoderModel.h" #include "otbAutoencoderModel.h"
...@@ -32,16 +31,15 @@ namespace otb ...@@ -32,16 +31,15 @@ namespace otb
template <class TInputValue, class TOutputValue, class NeuronType> template <class TInputValue, class TOutputValue, class NeuronType>
AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::AutoencoderModelFactory() AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::AutoencoderModelFactory()
{ {
std::string classOverride = std::string("DimensionalityReductionModel"); std::string classOverride = std::string("DimensionalityReductionModel");
std::string subclass = std::string("AutoencoderModel"); std::string subclass = std::string("AutoencoderModel");
this->RegisterOverride(classOverride.c_str(), this->RegisterOverride(
subclass.c_str(), classOverride.c_str(),
"Shark AE ML Model", subclass.c_str(),
1, "Shark AE ML Model",
// itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New()); 1,
itk::CreateObjectFunction<AutoencoderModel<TInputValue,NeuronType > >::New()); itk::CreateObjectFunction<AutoencoderModel<TInputValue,NeuronType > >::New());
} }
template <class TInputValue, class TOutputValue, class NeuronType> template <class TInputValue, class TOutputValue, class NeuronType>
...@@ -50,13 +48,15 @@ AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::~AutoencoderModel ...@@ -50,13 +48,15 @@ AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::~AutoencoderModel
} }
template <class TInputValue, class TOutputValue, class NeuronType> template <class TInputValue, class TOutputValue, class NeuronType>
const char* AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::GetITKSourceVersion(void) const const char*
AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::GetITKSourceVersion(void) const
{ {
return ITK_SOURCE_VERSION; return ITK_SOURCE_VERSION;
} }
template <class TInputValue, class TOutputValue, class NeuronType> template <class TInputValue, class TOutputValue, class NeuronType>
const char* AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::GetDescription() const const char*
AutoencoderModelFactory<TInputValue,TOutputValue, NeuronType>::GetDescription() const
{ {
return "Autoencoder model factory"; return "Autoencoder model factory";
} }
......
...@@ -20,9 +20,7 @@ ...@@ -20,9 +20,7 @@
#ifndef otbDimensionalityReductionModelFactory_h #ifndef otbDimensionalityReductionModelFactory_h
#define otbDimensionalityReductionModelFactory_h #define otbDimensionalityReductionModelFactory_h
//#include "DimensionalityReductionModel.h"
#include "otbMachineLearningModelFactoryBase.h" #include "otbMachineLearningModelFactoryBase.h"
#include "otbMachineLearningModel.h" #include "otbMachineLearningModel.h"
namespace otb namespace otb
...@@ -54,7 +52,6 @@ public: ...@@ -54,7 +52,6 @@ public:
/** Mode in which the files is intended to be used */ /** Mode in which the files is intended to be used */
typedef enum { ReadMode, WriteMode } FileModeType; typedef enum { ReadMode, WriteMode } FileModeType;
/** Create the appropriate MachineLearningModel depending on the particulars of the file. */ /** Create the appropriate MachineLearningModel depending on the particulars of the file. */
static DimensionalityReductionModelTypePointer CreateDimensionalityReductionModel(const std::string& path, FileModeType mode); static DimensionalityReductionModelTypePointer CreateDimensionalityReductionModel(const std::string& path, FileModeType mode);
...@@ -74,7 +71,6 @@ private: ...@@ -74,7 +71,6 @@ private:
/** Register a single factory, ensuring it has not been registered /** Register a single factory, ensuring it has not been registered
* twice */ * twice */
static void RegisterFactory(itk::ObjectFactoryBase * factory); static void RegisterFactory(itk::ObjectFactoryBase * factory);
}; };
} // end namespace otb } // end namespace otb
......
...@@ -32,15 +32,12 @@ ...@@ -32,15 +32,12 @@
#include "itkMutexLockHolder.h" #include "itkMutexLockHolder.h"
namespace otb namespace otb
{ {
template <class TInputValue, class TTargetValue> template <class TInputValue, class TTargetValue>
using LogAutoencoderModelFactory = AutoencoderModelFactory<TInputValue, TTargetValue, shark::LogisticNeuron> ; using LogAutoencoderModelFactory = AutoencoderModelFactory<TInputValue, TTargetValue, shark::LogisticNeuron> ;
template <class TInputValue, class TTargetValue> template <class TInputValue, class TTargetValue>
using SOM2DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 2> ; using SOM2DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 2> ;
...@@ -55,7 +52,7 @@ using SOM5DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 5> ; ...@@ -55,7 +52,7 @@ using SOM5DModelFactory = SOMModelFactory<TInputValue, TTargetValue, 5> ;
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
typename MachineLearningModel<itk::VariableLengthVector< TInputValue> , itk::VariableLengthVector< TOutputValue>>::Pointer typename MachineLearningModel<itk::VariableLengthVector< TInputValue>, itk::VariableLengthVector< TOutputValue> >::Pointer
DimensionalityReductionModelFactory<TInputValue,TOutputValue> DimensionalityReductionModelFactory<TInputValue,TOutputValue>
::CreateDimensionalityReductionModel(const std::string& path, FileModeType mode) ::CreateDimensionalityReductionModel(const std::string& path, FileModeType mode)
{ {
...@@ -64,7 +61,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -64,7 +61,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
std::list<DimensionalityReductionModelTypePointer> possibleDimensionalityReductionModel; std::list<DimensionalityReductionModelTypePointer> possibleDimensionalityReductionModel;
std::list<LightObject::Pointer> allobjects = std::list<LightObject::Pointer> allobjects =
itk::ObjectFactoryBase::CreateAllInstance("DimensionalityReductionModel"); itk::ObjectFactoryBase::CreateAllInstance("DimensionalityReductionModel");
for(std::list<LightObject::Pointer>::iterator i = allobjects.begin(); for(std::list<LightObject::Pointer>::iterator i = allobjects.begin();
i != allobjects.end(); ++i) i != allobjects.end(); ++i)
{ {
...@@ -75,19 +72,17 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -75,19 +72,17 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
} }
else else
{ {
std::cerr << "Error DimensionalityReductionModel Factory did not return an DimensionalityReductionModel: " std::cerr << "Error DimensionalityReductionModel Factory did not return an DimensionalityReductionModel: "
<< (*i)->GetNameOfClass() << (*i)->GetNameOfClass()
<< std::endl; << std::endl;
} }
} }
for(typename std::list<DimensionalityReductionModelTypePointer>::iterator k = possibleDimensionalityReductionModel.begin(); for(typename std::list<DimensionalityReductionModelTypePointer>::iterator k = possibleDimensionalityReductionModel.begin();
k != possibleDimensionalityReductionModel.end(); ++k) k != possibleDimensionalityReductionModel.end(); ++k)
{ {
if( mode == ReadMode ) if( mode == ReadMode )
{ {
if((*k)->CanReadFile(path)) if((*k)->CanReadFile(path))
{ {
return *k; return *k;
...@@ -99,7 +94,6 @@ for(typename std::list<DimensionalityReductionModelTypePointer>::iterator k = po ...@@ -99,7 +94,6 @@ for(typename std::list<DimensionalityReductionModelTypePointer>::iterator k = po
{ {
return *k; return *k;
} }
} }
} }
return ITK_NULLPTR; return ITK_NULLPTR;
...@@ -111,20 +105,16 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -111,20 +105,16 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
::RegisterBuiltInFactories() ::RegisterBuiltInFactories()
{ {
itk::MutexLockHolder<itk::SimpleMutexLock> lockHolder(mutex); itk::MutexLockHolder<itk::SimpleMutexLock> lockHolder(mutex);
RegisterFactory(SOM2DModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(SOM2DModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(SOM3DModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(SOM3DModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(SOM4DModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(SOM4DModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(SOM5DModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(SOM5DModelFactory<TInputValue,TOutputValue>::New());
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(LogAutoencoderModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(LogAutoencoderModelFactory<TInputValue,TOutputValue>::New());
// RegisterFactory(TiedAutoencoderModelFactory<TInputValue,TOutputValue>::New());
#endif #endif
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
...@@ -151,17 +141,15 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -151,17 +141,15 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
for (itFac = factories.begin(); itFac != factories.end() ; ++itFac) for (itFac = factories.begin(); itFac != factories.end() ; ++itFac)
{ {
// SOM 5D
// SOM SOM5DModelFactory<TInputValue,TOutputValue> *som5dFactory =
SOM5DModelFactory<TInputValue,TOutputValue> *som5dFactory =
dynamic_cast<SOM5DModelFactory<TInputValue,TOutputValue> *>(*itFac); dynamic_cast<SOM5DModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (som5dFactory) if (som5dFactory)
{ {
itk::ObjectFactoryBase::UnRegisterFactory(som5dFactory); itk::ObjectFactoryBase::UnRegisterFactory(som5dFactory);
continue; continue;
} }
// SOM 4D
SOM4DModelFactory<TInputValue,TOutputValue> *som4dFactory = SOM4DModelFactory<TInputValue,TOutputValue> *som4dFactory =
dynamic_cast<SOM4DModelFactory<TInputValue,TOutputValue> *>(*itFac); dynamic_cast<SOM4DModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (som4dFactory) if (som4dFactory)
...@@ -169,7 +157,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -169,7 +157,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
itk::ObjectFactoryBase::UnRegisterFactory(som4dFactory); itk::ObjectFactoryBase::UnRegisterFactory(som4dFactory);
continue; continue;
} }
// SOM 3D
SOM3DModelFactory<TInputValue,TOutputValue> *som3dFactory = SOM3DModelFactory<TInputValue,TOutputValue> *som3dFactory =
dynamic_cast<SOM3DModelFactory<TInputValue,TOutputValue> *>(*itFac); dynamic_cast<SOM3DModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (som3dFactory) if (som3dFactory)
...@@ -177,7 +165,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -177,7 +165,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
itk::ObjectFactoryBase::UnRegisterFactory(som3dFactory); itk::ObjectFactoryBase::UnRegisterFactory(som3dFactory);
continue; continue;
} }
// SOM 2D
SOM2DModelFactory<TInputValue,TOutputValue> *som2dFactory = SOM2DModelFactory<TInputValue,TOutputValue> *som2dFactory =
dynamic_cast<SOM2DModelFactory<TInputValue,TOutputValue> *>(*itFac); dynamic_cast<SOM2DModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (som2dFactory) if (som2dFactory)
...@@ -185,9 +173,8 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -185,9 +173,8 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
itk::ObjectFactoryBase::UnRegisterFactory(som2dFactory); itk::ObjectFactoryBase::UnRegisterFactory(som2dFactory);
continue; continue;
} }
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
// Autoencoder
LogAutoencoderModelFactory<TInputValue,TOutputValue> *aeFactory = LogAutoencoderModelFactory<TInputValue,TOutputValue> *aeFactory =
dynamic_cast<LogAutoencoderModelFactory<TInputValue,TOutputValue> *>(*itFac); dynamic_cast<LogAutoencoderModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (aeFactory) if (aeFactory)
...@@ -195,17 +182,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -195,17 +182,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
itk::ObjectFactoryBase::UnRegisterFactory(aeFactory); itk::ObjectFactoryBase::UnRegisterFactory(aeFactory);
continue; continue;
} }
// PCA
/*
TiedAutoencoderModelFactory<TInputValue,TOutputValue> *taeFactory =
dynamic_cast<TiedAutoencoderModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (taeFactory)
{
itk::ObjectFactoryBase::UnRegisterFactory(taeFactory);
continue;
}
*/
// PCA
PCAModelFactory<TInputValue,TOutputValue> *pcaFactory = PCAModelFactory<TInputValue,TOutputValue> *pcaFactory =
dynamic_cast<PCAModelFactory<TInputValue,TOutputValue> *>(*itFac); dynamic_cast<PCAModelFactory<TInputValue,TOutputValue> *>(*itFac);
if (pcaFactory) if (pcaFactory)
...@@ -214,9 +191,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -214,9 +191,7 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
continue; continue;
} }
#endif #endif
} }
} }
} // end namespace otb } // end namespace otb
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#define otbImageDimensionalityReduction_h #define otbImageDimensionalityReduction_h
#include "itkImageToImageFilter.h" #include "itkImageToImageFilter.h"
//#include "DimensionalityReductionModel.h"
#include "otbMachineLearningModel.h" #include "otbMachineLearningModel.h"
#include "otbImage.h" #include "otbImage.h"
......
...@@ -36,8 +36,6 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ...@@ -36,8 +36,6 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
this->SetNumberOfIndexedInputs(2); this->SetNumberOfIndexedInputs(2);
this->SetNumberOfRequiredInputs(1); this->SetNumberOfRequiredInputs(1);
//m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
this->SetNumberOfRequiredOutputs(2); this->SetNumberOfRequiredOutputs(2);
this->SetNthOutput(0,TOutputImage::New()); this->SetNthOutput(0,TOutputImage::New());
this->SetNthOutput(1,ConfidenceImageType::New()); this->SetNthOutput(1,ConfidenceImageType::New());
...@@ -113,9 +111,7 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ...@@ -113,9 +111,7 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
// Define iterators // Define iterators
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
//typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
//typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread); InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread); OutputIteratorType outIt(outputPtr, outputRegionForThread);
...@@ -123,43 +119,36 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ...@@ -123,43 +119,36 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
// Walk the part of the image // Walk the part of the image
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt) for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
{ {
// Classifify // Classifify
outIt.Set(m_Model->Predict(inIt.Get()));
outIt.Set(m_Model->Predict(inIt.Get()));
progress.CompletedPixel(); progress.CompletedPixel();
} }
} }
template <class TInputImage, class TOutputImage, class TMaskImage> template <class TInputImage, class TOutputImage, class TMaskImage>
void ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>::GenerateOutputInformation() void ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>::GenerateOutputInformation()
{ {
Superclass::GenerateOutputInformation(); Superclass::GenerateOutputInformation();
this->GetOutput()->SetNumberOfComponentsPerPixel( m_Model->GetDimension() ); this->GetOutput()->SetNumberOfComponentsPerPixel( m_Model->GetDimension() );
} }
template <class TInputImage, class TOutputImage, class TMaskImage> template <class TInputImage, class TOutputImage, class TMaskImage>
void void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId) ::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{ {
// Get the input pointers // Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput(); InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput(); OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence(); ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
// Progress reporting // Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
// Define iterators // Define iterators
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
//typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
//typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread); InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread); OutputIteratorType outIt(outputPtr, outputRegionForThread);
...@@ -168,45 +157,40 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ...@@ -168,45 +157,40 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
typedef typename ModelType::InputListSampleType InputListSampleType; typedef typename ModelType::InputListSampleType InputListSampleType;
typedef typename ModelType::TargetValueType TargetValueType; typedef typename ModelType::TargetValueType TargetValueType;
typedef typename ModelType::TargetListSampleType TargetListSampleType; typedef typename ModelType::TargetListSampleType TargetListSampleType;
typename InputListSampleType::Pointer samples = InputListSampleType::New(); typename InputListSampleType::Pointer samples = InputListSampleType::New();
unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel(); unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
samples->SetMeasurementVectorSize(num_features); samples->SetMeasurementVectorSize(num_features);
InputSampleType sample(num_features); InputSampleType sample(num_features);
// Fill the samples // Fill the samples
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
{ {
typename InputImageType::PixelType pix = inIt.Get();
typename InputImageType::PixelType pix = inIt.Get(); for(size_t feat=0; feat<num_features; ++feat)
for(size_t feat=0; feat<num_features; ++feat) {
{ sample[feat]=pix[feat];
sample[feat]=pix[feat]; }
} samples->PushBack(sample);
samples->PushBack(sample);
} }
//Make the batch prediction //Make the batch prediction
typename TargetListSampleType::Pointer labels; typename TargetListSampleType::Pointer labels;
// This call is threadsafe // This call is threadsafe
labels = m_Model->PredictBatch(samples); labels = m_Model->PredictBatch(samples);
// Set the output values // Set the output values
typename TargetListSampleType::ConstIterator labIt = labels->Begin(); typename TargetListSampleType::ConstIterator labIt = labels->Begin();
for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt) for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
{ {
itk::VariableLengthVector<TargetValueType> labelValue;
itk::VariableLengthVector<TargetValueType> labelValue;
labelValue = labIt.GetMeasurementVector(); labelValue = labIt.GetMeasurementVector();
++labIt; ++labIt;
outIt.Set(labelValue); outIt.Set(labelValue);
progress.CompletedPixel(); progress.CompletedPixel();
} }
} }
template <class TInputImage, class TOutputImage, class TMaskImage> template <class TInputImage, class TOutputImage, class TMaskImage>
void void
ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
...@@ -220,8 +204,8 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ...@@ -220,8 +204,8 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
{ {
this->ClassicThreadedGenerateData(outputRegionForThread, threadId); this->ClassicThreadedGenerateData(outputRegionForThread, threadId);
} }
} }
/** /**
* PrintSelf Method * PrintSelf Method
*/ */
...@@ -232,5 +216,6 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage> ...@@ -232,5 +216,6 @@ ImageDimensionalityReductionFilter<TInputImage, TOutputImage, TMaskImage>
{ {
Superclass::PrintSelf(os, indent); Superclass::PrintSelf(os, indent);
} }
} // End namespace otb } // End namespace otb
#endif #endif
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment