Commit ef393f13 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

Training app for SOM working, empty som factory header created

No related merge requests found
Showing with 490 additions and 11 deletions
+490 -11
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include "otbImageToVectorImageCastFilter.h" #include "otbImageToVectorImageCastFilter.h"
#include "DimensionalityReductionModelFactory.h" #include "DimensionalityReductionModelFactory.h"
namespace otb namespace otb
{ {
namespace Functor namespace Functor
...@@ -273,7 +272,6 @@ private: ...@@ -273,7 +272,6 @@ private:
ModelPointerType m_Model; ModelPointerType m_Model;
RescalerType::Pointer m_Rescaler; RescalerType::Pointer m_Rescaler;
OutputRescalerType::Pointer m_OutRescaler; OutputRescalerType::Pointer m_OutRescaler;
}; };
......
...@@ -49,14 +49,14 @@ private: ...@@ -49,14 +49,14 @@ private:
/*
template <class TInputValue, class TTargetValue> template <class TInputValue, class TTargetValue>
class ITK_EXPORT AutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> ;
template <class TInputValue, class TTargetValue> template <class TInputValue, class TTargetValue>
class ITK_EXPORT TiedAutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> ;
*/
} //namespace otb } //namespace otb
......
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
/** Mode in which the files is intended to be used */ /** Mode in which the files is intended to be used */
typedef enum { ReadMode, WriteMode } FileModeType; typedef enum { ReadMode, WriteMode } FileModeType;
/** Create the appropriate MachineLearningModel depending on the particulars of the file. */ /** Create the appropriate MachineLearningModel depending on the particulars of the file. */
static DimensionalityReductionModelTypePointer CreateDimensionalityReductionModel(const std::string& path, FileModeType mode); static DimensionalityReductionModelTypePointer CreateDimensionalityReductionModel(const std::string& path, FileModeType mode);
......
...@@ -32,6 +32,15 @@ ...@@ -32,6 +32,15 @@
namespace otb namespace otb
{ {
template <class TInputValue, class TTargetValue>
using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> ;
template <class TInputValue, class TTargetValue>
using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> ;
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
typename DimensionalityReductionModel<TInputValue,TOutputValue>::Pointer typename DimensionalityReductionModel<TInputValue,TOutputValue>::Pointer
DimensionalityReductionModelFactory<TInputValue,TOutputValue> DimensionalityReductionModelFactory<TInputValue,TOutputValue>
...@@ -88,6 +97,13 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> ...@@ -88,6 +97,13 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue>
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
// using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {};
//using TiedAutoencoderModelFactory = public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {};
RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(AutoencoderModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(AutoencoderModelFactory<TInputValue,TOutputValue>::New());
RegisterFactory(TiedAutoencoderModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(TiedAutoencoderModelFactory<TInputValue,TOutputValue>::New());
......
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
itkNewMacro(Self); itkNewMacro(Self);
itkTypeMacro(AutoencoderModel, DimensionalityReductionModel); itkTypeMacro(PCAModel, DimensionalityReductionModel);
unsigned int GetDimension() {return m_Dimension;}; unsigned int GetDimension() {return m_Dimension;};
itkSetMacro(Dimension,unsigned int); itkSetMacro(Dimension,unsigned int);
......
...@@ -38,7 +38,6 @@ void PCAModel<TInputValue>::Train() ...@@ -38,7 +38,6 @@ void PCAModel<TInputValue>::Train()
Shark::ListSampleToSharkVector(this->GetInputListSample(), features); Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange( features ); shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange( features );
//m_pca.train(m_encoder,inputSamples);
m_pca.setData(inputSamples); m_pca.setData(inputSamples);
m_pca.encoder(m_encoder, m_Dimension); m_pca.encoder(m_encoder, m_Dimension);
m_pca.decoder(m_decoder, m_Dimension); m_pca.decoder(m_decoder, m_Dimension);
......
#ifndef SOMModel_h
#define SOMModel_h
#include "DimensionalityReductionModel.h"
#include "otbSOMMap.h"
#include "otbSOM.h"
#include "itkEuclideanDistanceMetric.h" // the distance function
#include "otbCzihoSOMLearningBehaviorFunctor.h"
#include "otbCzihoSOMNeighborhoodBehaviorFunctor.h"
namespace otb
{
template <class TInputValue>
class ITK_EXPORT SOMModel: public DimensionalityReductionModel<TInputValue,TInputValue>
{
public:
typedef SOMModel Self;
typedef DimensionalityReductionModel<TInputValue,TInputValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename InputListSampleType::Pointer ListSamplePointerType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef SOMMap<itk::VariableLengthVector<TInputValue>,itk::Statistics::EuclideanDistanceMetric<itk::VariableLengthVector<TInputValue>>, 2> MapType;
typedef typename MapType::SizeType SizeType;
typedef otb::SOM<InputListSampleType, MapType> EstimatorType;
typedef Functor::CzihoSOMLearningBehaviorFunctor SOMLearningBehaviorFunctorType;
typedef Functor::CzihoSOMNeighborhoodBehaviorFunctor SOMNeighborhoodBehaviorFunctorType;
itkNewMacro(Self);
itkTypeMacro(SOMModel, DimensionalityReductionModel);
/** Accessors */
itkSetMacro(NumberOfIterations, unsigned int);
itkGetMacro(NumberOfIterations, unsigned int);
itkSetMacro(BetaInit, double);
itkGetMacro(BetaInit, double);
itkSetMacro(BetaEnd, double);
itkGetMacro(BetaEnd, double);
itkSetMacro(MinWeight, InputValueType);
itkGetMacro(MinWeight, InputValueType);
itkSetMacro(MaxWeight, InputValueType);
itkGetMacro(MaxWeight, InputValueType);
itkSetMacro(MapSize, SizeType);
itkGetMacro(MapSize, SizeType);
itkSetMacro(NeighborhoodSizeInit, SizeType);
itkGetMacro(NeighborhoodSizeInit, SizeType);
itkSetMacro(RandomInit, bool);
itkGetMacro(RandomInit, bool);
itkSetMacro(Seed, unsigned int);
itkGetMacro(Seed, unsigned int);
itkGetObjectMacro(ListSample, InputListSampleType);
itkSetObjectMacro(ListSample, InputListSampleType);
bool CanReadFile(const std::string & filename);
bool CanWriteFile(const std::string & filename);
void Save(const std::string & filename, const std::string & name="") ITK_OVERRIDE;
void Load(const std::string & filename, const std::string & name="") ITK_OVERRIDE;
void Train() ITK_OVERRIDE;
//void Dimensionality_reduction() {}; // Dimensionality reduction is done by DoPredict
unsigned int GetDimension() { return MapType::ImageDimension;};
protected:
SOMModel();
~SOMModel() ITK_OVERRIDE;
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=ITK_NULLPTR) const ITK_OVERRIDE;
private:
typename MapType::Pointer m_SOMMap;
/** Map Parameters used for training */
SizeType m_MapSize;
/** Number of iterations */
unsigned int m_NumberOfIterations;
/** Initial learning coefficient */
double m_BetaInit;
/** Final learning coefficient */
double m_BetaEnd;
/** Initial neighborhood size */
SizeType m_NeighborhoodSizeInit;
/** Minimum initial neuron weights */
InputValueType m_MinWeight;
/** Maximum initial neuron weights */
InputValueType m_MaxWeight;
/** Random initialization bool */
bool m_RandomInit;
/** Seed for random initialization */
unsigned int m_Seed;
/** The input list sample */
ListSamplePointerType m_ListSample;
/** Behavior of the Learning weightening (link to the beta coefficient) */
SOMLearningBehaviorFunctorType m_BetaFunctor;
/** Behavior of the Neighborhood extent */
SOMNeighborhoodBehaviorFunctorType m_NeighborhoodSizeFunctor;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "SOMModel.txx"
#endif
#endif
#ifndef SOMModel_txx
#define SOMModel_txx
#include "otbImageFileReader.h"
#include "otbImageFileWriter.h"
#include "itkMacro.h"
namespace otb
{
template <class TInputValue>
SOMModel<TInputValue>::SOMModel()
{
}
template <class TInputValue>
SOMModel<TInputValue>::~SOMModel()
{
}
template <class TInputValue>
void SOMModel<TInputValue>::Train()
{
typename EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetListSample(m_ListSample);
estimator->SetMapSize(m_MapSize);
estimator->SetNeighborhoodSizeInit(m_NeighborhoodSizeInit);
estimator->SetNumberOfIterations(m_NumberOfIterations);
estimator->SetBetaInit(m_BetaInit);
estimator->SetBetaEnd(m_BetaEnd);
estimator->SetMaxWeight(m_MaxWeight);
//AddProcess(estimator,"Learning");
std::cout << "list = " << m_ListSample << std::endl;
std::cout << "size = " << m_MapSize << std::endl;
std::cout << "neigsize = " << m_NeighborhoodSizeInit << std::endl;
std::cout << "n iter = " << m_NumberOfIterations << std::endl;
std::cout << "bi = " << m_BetaInit << std::endl;
std::cout << "be = " << m_BetaEnd << std::endl;
std::cout << "mw = " << m_MaxWeight << std::endl;
estimator->Update();
m_SOMMap = estimator->GetOutput();
}
template <class TInputValue>
bool SOMModel<TInputValue>::CanReadFile(const std::string & filename)
{
return true;
}
template <class TInputValue>
bool SOMModel<TInputValue>::CanWriteFile(const std::string & filename)
{
return true;
}
template <class TInputValue>
void SOMModel<TInputValue>::Save(const std::string & filename, const std::string & name)
{
std::cout << m_SOMMap->GetNumberOfComponentsPerPixel() << std::endl;
//Ecriture
auto kwl = m_SOMMap->GetImageKeywordlist();
//kwl.AddKey("MachineLearningModelType", "SOM");
//m_SOMMap->SetImageKeywordList(kwl);
auto writer = otb::ImageFileWriter<MapType>::New();
writer->SetInput(m_SOMMap);
writer->SetFileName(filename);
writer->Update();
}
template <class TInputValue>
void SOMModel<TInputValue>::Load(const std::string & filename, const std::string & name)
{
auto reader = otb::ImageFileReader<MapType>::New();
reader->SetFileName(filename);
reader->Update();
std::cout << reader->GetOutput()->GetImageKeywordlist().GetMetadataByKey("MachineLearningModelType") << '\n';
m_SOMMap = reader->GetOutput();
}
template <class TInputValue>
typename SOMModel<TInputValue>::TargetSampleType
SOMModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const
{
}
} // namespace otb
#endif
#ifndef PCAModelFactory_h
#define PCAModelFactory_h
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT PCAModelFactory : public itk::ObjectFactoryBase
{
public:
/** Standard class typedefs. */
typedef PCAModelFactory Self;
typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Class methods used to interface with the registered factories. */
const char* GetITKSourceVersion(void) const ITK_OVERRIDE;
const char* GetDescription(void) const ITK_OVERRIDE;
/** Method for class instantiation. */
itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(PCAModelFactory, itk::ObjectFactoryBase);
/** Register one factory of this type */
static void RegisterOneFactory(void)
{
Pointer PCAFactory = PCAModelFactory::New();
itk::ObjectFactoryBase::RegisterFactory(PCAFactory);
}
protected:
PCAModelFactory();
~PCAModelFactory() ITK_OVERRIDE;
private:
PCAModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} //namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "PCAModelFactory.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef PCAFactory_txx
#define PCAFactory_txx
#include "PCAModelFactory.h"
#include "itkCreateObjectFunction.h"
#include "PCAModel.h"
//#include <shark/Algorithms/Trainers/PCA.h>
#include "itkVersion.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
PCAModelFactory<TInputValue,TOutputValue>::PCAModelFactory()
{
std::string classOverride = std::string("DimensionalityReductionModel");
std::string subclass = std::string("PCAModel");
this->RegisterOverride(classOverride.c_str(),
subclass.c_str(),
"Shark PCA ML Model",
1,
// itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New());
itk::CreateObjectFunction<PCAModel<TInputValue>>::New());
}
template <class TInputValue, class TOutputValue>
PCAModelFactory<TInputValue,TOutputValue>::~PCAModelFactory()
{
}
template <class TInputValue, class TOutputValue>
const char* PCAModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const
{
return ITK_SOURCE_VERSION;
}
template <class TInputValue, class TOutputValue>
const char* PCAModelFactory<TInputValue,TOutputValue>::GetDescription() const
{
return "PCA model factory";
}
} // end namespace otb
#endif
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
//Estimator //Estimator
#include "DimensionalityReductionModelFactory.h" #include "DimensionalityReductionModelFactory.h"
#include "SOMModel.h"
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
#include "AutoencoderModel.h" #include "AutoencoderModel.h"
#include "PCAModel.h" #include "PCAModel.h"
...@@ -75,7 +77,6 @@ public: ...@@ -75,7 +77,6 @@ public:
typedef otb::VectorImage<InputValueType> SampleImageType; typedef otb::VectorImage<InputValueType> SampleImageType;
typedef typename SampleImageType::PixelType PixelType; typedef typename SampleImageType::PixelType PixelType;
// Machine Learning models
typedef otb::DimensionalityReductionModelFactory< typedef otb::DimensionalityReductionModelFactory<
InputValueType, OutputValueType> ModelFactoryType; InputValueType, OutputValueType> ModelFactoryType;
typedef typename ModelFactoryType::DimensionalityReductionModelTypePointer ModelPointerType; typedef typename ModelFactoryType::DimensionalityReductionModelTypePointer ModelPointerType;
...@@ -84,6 +85,11 @@ public: ...@@ -84,6 +85,11 @@ public:
typedef typename ModelType::InputSampleType SampleType; typedef typename ModelType::InputSampleType SampleType;
typedef typename ModelType::InputListSampleType ListSampleType; typedef typename ModelType::InputListSampleType ListSampleType;
// Dimensionality reduction models
typedef SOMMap<itk::VariableLengthVector<TInputValue>,itk::Statistics::EuclideanDistanceMetric<itk::VariableLengthVector<TInputValue>>, 2> MapType;
typedef otb::SOM<ListSampleType, MapType> EstimatorType;
typedef otb::SOMModel<InputValueType> SOMModelType;
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
typedef shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> AutoencoderType; typedef shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> AutoencoderType;
...@@ -120,9 +126,11 @@ private: ...@@ -120,9 +126,11 @@ private:
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
void InitAutoencoderParams(); void InitAutoencoderParams();
void InitPCAParams(); void InitPCAParams();
void InitSOMParams();
template <class autoencoderchoice> template <class autoencoderchoice>
void TrainAutoencoder(typename ListSampleType::Pointer trainingListSample, std::string modelPath); void TrainAutoencoder(typename ListSampleType::Pointer trainingListSample, std::string modelPath);
void TrainPCA(typename ListSampleType::Pointer trainingListSample, std::string modelPath); void TrainPCA(typename ListSampleType::Pointer trainingListSample, std::string modelPath);
void TrainSOM(typename ListSampleType::Pointer trainingListSample, std::string modelPath);
#endif #endif
//@} //@}
}; };
...@@ -132,6 +140,7 @@ private: ...@@ -132,6 +140,7 @@ private:
#ifndef OTB_MANUAL_INSTANTIATION #ifndef OTB_MANUAL_INSTANTIATION
#include "cbLearningApplicationBaseDR.txx" #include "cbLearningApplicationBaseDR.txx"
#include "cbTrainSOM.txx"
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
#include "cbTrainAutoencoder.txx" #include "cbTrainAutoencoder.txx"
#include "cbTrainPCA.txx" #include "cbTrainPCA.txx"
......
...@@ -48,7 +48,7 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue> ...@@ -48,7 +48,7 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue>
AddParameter(ParameterType_Choice, "model", "moddel to use for the training"); AddParameter(ParameterType_Choice, "model", "moddel to use for the training");
SetParameterDescription("model", "Choice of the dimensionality reduction model to use for the training."); SetParameterDescription("model", "Choice of the dimensionality reduction model to use for the training.");
InitSOMParams();
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
InitAutoencoderParams(); InitAutoencoderParams();
InitPCAParams(); InitPCAParams();
...@@ -98,7 +98,11 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue> ...@@ -98,7 +98,11 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue>
// get the name of the chosen machine learning model // get the name of the chosen machine learning model
const std::string modelName = GetParameterString("model"); const std::string modelName = GetParameterString("model");
// call specific train function // call specific train function
if(modelName == "som")
{
TrainSOM(trainingListSample,modelPath);
}
if(modelName == "autoencoder") if(modelName == "autoencoder")
{ {
#ifdef OTB_USE_SHARK #ifdef OTB_USE_SHARK
......
#ifndef cbTrainSOM_txx
#define cbTrainSOM_txx
#include "cbLearningApplicationBaseDR.h"
namespace otb
{
namespace Wrapper
{
template <class TInputValue, class TOutputValue>
void
cbLearningApplicationBaseDR<TInputValue,TOutputValue>
::InitSOMParams()
{
AddChoice("model.som", "OTB SOM");
SetParameterDescription("model.som",
"This group of parameters allows setting SOM parameters. "
);
AddParameter(ParameterType_Int, "model.som.sx", "SizeX");
SetParameterDescription("model.som.sx", "X size of the SOM map");
MandatoryOff("model.som.sx");
AddParameter(ParameterType_Int, "model.som.sy", "SizeY");
SetParameterDescription("model.som.sy", "Y size of the SOM map");
MandatoryOff("model.som.sy");
AddParameter(ParameterType_Int, "model.som.nx", "NeighborhoodX");
SetParameterDescription("model.som.nx", "X size of the initial neighborhood in the SOM map");
MandatoryOff("model.som.nx");
AddParameter(ParameterType_Int, "model.som.ny", "NeighborhoodY");
SetParameterDescription("model.som.ny", "Y size of the initial neighborhood in the SOM map");
MandatoryOff("model.som.nx");
AddParameter(ParameterType_Int, "model.som.ni", "NumberIteration");
SetParameterDescription("model.som.ni", "Number of iterations for SOM learning");
MandatoryOff("model.som.ni");
AddParameter(ParameterType_Float, "model.som.bi", "BetaInit");
SetParameterDescription("model.som.bi", "Initial learning coefficient");
MandatoryOff("model.som.bi");
AddParameter(ParameterType_Float, "model.som.bf", "BetaFinal");
SetParameterDescription("model.som.bf", "Final learning coefficient");
MandatoryOff("model.som.bf");
AddParameter(ParameterType_Float, "model.som.iv", "InitialValue");
SetParameterDescription("model.som.iv", "Maximum initial neuron weight");
MandatoryOff("model.som.iv");
SetDefaultParameterInt("model.som.sx", 32);
SetDefaultParameterInt("model.som.sy", 32);
SetDefaultParameterInt("model.som.nx", 10);
SetDefaultParameterInt("model.som.ny", 10);
SetDefaultParameterInt("model.som.ni", 5);
SetDefaultParameterFloat("model.som.bi", 1.0);
SetDefaultParameterFloat("model.som.bf", 0.1);
SetDefaultParameterFloat("model.som.iv", 10.0);
}
template <class TInputValue, class TOutputValue>
void cbLearningApplicationBaseDR<TInputValue,TOutputValue>
::TrainSOM(typename ListSampleType::Pointer trainingListSample,std::string modelPath)
{
typename SOMModelType::Pointer dimredTrainer = SOMModelType::New();
dimredTrainer->SetNumberOfIterations(GetParameterInt("model.som.ni"));
dimredTrainer->SetBetaInit(GetParameterFloat("model.som.bi"));
dimredTrainer->SetBetaEnd(GetParameterFloat("model.som.bf"));
dimredTrainer->SetMaxWeight(GetParameterFloat("model.som.iv"));
typename EstimatorType::SizeType size;
size[0]=GetParameterInt("model.som.sx");
size[1]=GetParameterInt("model.som.sy");
dimredTrainer->SetMapSize(size);
typename EstimatorType::SizeType radius;
radius[0] = GetParameterInt("model.som.nx");
radius[1] = GetParameterInt("model.som.ny");
dimredTrainer->SetNeighborhoodSizeInit(radius);
std::cout << trainingListSample << std::endl;
dimredTrainer->SetListSample(trainingListSample);
dimredTrainer->Train();
dimredTrainer->Save(modelPath);
}
} //end namespace wrapper
} //end namespace otb
#endif
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment