diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
index 66c0df76222e9e928886639566f8cc78faef755d..2127439fb7af284361117428e12826e27149b444 100644
--- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
+++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
@@ -48,6 +48,7 @@
 #include "shark/Models/Clustering/Centroids.h"
 #include "shark/Models/Clustering/ClusteringModel.h"
 #include "shark/Algorithms/KMeans.h"
+#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
 
 #if defined(__GNUC__) || defined(__clang__)
 #pragma GCC diagnostic pop
@@ -132,6 +133,13 @@ public:
   itkGetMacro( CentroidFilename, std::string );
   itkSetMacro( CentroidFilename, std::string );
 
+  /** Initialize the centroids for the kmeans algorithm */
+  void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
+  {
+    m_Centroids.setCentroids(data);
+    this->Modified();
+  }
+
 protected:
   /** Constructor */
   SharkKMeansMachineLearningModel();
@@ -148,6 +156,9 @@ protected:
   template<typename DataType>
   DataType NormalizeData(const DataType &data) const;
+
+  template<typename DataType>
+  shark::Normalizer<> TrainNormalizer(const DataType &data) const;
 
   /** PrintSelf method */
   void PrintSelf(std::ostream &os, itk::Indent indent) const override;
diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx
index 7392c6c15a106860b5e5fcdcd0e4e3024e72c95c..1d27ce519c63c86216167fa10e3df2cc79625eec 100644
--- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx
+++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx
@@ -39,7 +39,6 @@
 #include "shark/Algorithms/KMeans.h" //k-means algorithm
 #include "shark/Models/Clustering/HardClusteringModel.h"
 #include "shark/Models/Clustering/SoftClusteringModel.h"
-#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
 #include <shark/Data/Csv.h> //load the csv file
 
 #if defined(__GNUC__) || defined(__clang__)
@@ -67,18 +66,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
 {
 }
 
-template<class TInputValue, class TOutputValue>
-bool
-SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
-::InitializeCentroids()
-{
-  shark::Data<shark::RealVector> data;
-  shark::importCSV(data, m_CentroidFilename, ' ');
-  m_Centroids.setCentroids(data);
-  std::cout <<m_Centroids.centroids() << std::endl;
-  return 1;
-}
-
 /** Train the machine learning model */
 template<class TInputValue, class TOutputValue>
 void
@@ -90,12 +77,12 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
   otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
   shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
 
-  if (!m_CentroidFilename.empty())
-    InitializeCentroids();
-
   // Normalized input value if necessary
   if( m_Normalized )
-    data = NormalizeData( data );
+    {
+    auto normalizer = TrainNormalizer(data);
+    data = normalizer(data);
+    }
 
   // Use a Hard Clustering Model for classification
   shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
@@ -114,6 +101,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
   return normalizer( data );
 }
 
+template<class TInputValue, class TOutputValue>
+template<typename DataType>
+shark::Normalizer<>
+SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
+::TrainNormalizer(const DataType &data) const
+{
+  shark::Normalizer<> normalizer;
+  shark::NormalizeComponentsUnitVariance<> normalizingTrainer( true );//zero mean
+  normalizingTrainer.train( normalizer, data );
+  return normalizer;
+}
+
 template<class TInputValue, class TOutputValue>
 typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
 ::TargetSampleType
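
Note: this patch removes the CSV-based InitializeCentroids() helper and instead exposes the public SetCentroidsFromData() accessor. A minimal sketch of how a caller could reproduce the old behaviour on its side, assuming a space-separated centroid file and a model instantiated as otb::SharkKMeansMachineLearningModel<float, unsigned int>; the template parameters and the SeedCentroids helper name are illustrative, not part of the patch:

// Sketch only: re-creates the removed InitializeCentroids() logic on the caller side.
#include <shark/Data/Csv.h>
#include "otbSharkKMeansMachineLearningModel.h"

using ModelType = otb::SharkKMeansMachineLearningModel<float, unsigned int>;

void SeedCentroids(ModelType::Pointer model, const std::string &centroidFile)
{
  shark::Data<shark::RealVector> centroids;
  shark::importCSV(centroids, centroidFile, ' ');  // space-separated values, as in the removed code
  model->SetCentroidsFromData(centroids);          // public setter introduced by this patch
}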