Commit 7607af28 authored by Cédric Traizet
Browse files

ENH: input centroids can now be set from outside of the model class

No related merge requests found
Showing with 27 additions and 17 deletions
+27 -17
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include "shark/Models/Clustering/Centroids.h" #include "shark/Models/Clustering/Centroids.h"
#include "shark/Models/Clustering/ClusteringModel.h" #include "shark/Models/Clustering/ClusteringModel.h"
#include "shark/Algorithms/KMeans.h" #include "shark/Algorithms/KMeans.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
...@@ -132,6 +133,13 @@ public: ...@@ -132,6 +133,13 @@ public:
itkGetMacro( CentroidFilename, std::string ); itkGetMacro( CentroidFilename, std::string );
itkSetMacro( CentroidFilename, std::string ); itkSetMacro( CentroidFilename, std::string );
/** Set the initial centroids of the kmeans algorithm from externally
 *  provided data, then flag this object as modified. */
void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
{
  // Hand the caller-supplied data to the centroid container.
  this->m_Centroids.setCentroids( data );

  // Signal the change through the object's Modified() hook.
  this->Modified();
}
protected: protected:
/** Constructor */ /** Constructor */
SharkKMeansMachineLearningModel(); SharkKMeansMachineLearningModel();
...@@ -148,6 +156,9 @@ protected: ...@@ -148,6 +156,9 @@ protected:
template<typename DataType> template<typename DataType>
DataType NormalizeData(const DataType &data) const; DataType NormalizeData(const DataType &data) const;
/** Train a normalizer on the given data and return it; the data itself is
 *  taken by const reference and is not normalized by this call. */
template<typename DataType>
shark::Normalizer<> TrainNormalizer(const DataType &data) const;
/** PrintSelf method */ /** PrintSelf method */
void PrintSelf(std::ostream &os, itk::Indent indent) const override; void PrintSelf(std::ostream &os, itk::Indent indent) const override;
......
...@@ -39,7 +39,6 @@ ...@@ -39,7 +39,6 @@
#include "shark/Algorithms/KMeans.h" //k-means algorithm #include "shark/Algorithms/KMeans.h" //k-means algorithm
#include "shark/Models/Clustering/HardClusteringModel.h" #include "shark/Models/Clustering/HardClusteringModel.h"
#include "shark/Models/Clustering/SoftClusteringModel.h" #include "shark/Models/Clustering/SoftClusteringModel.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#include <shark/Data/Csv.h> //load the csv file #include <shark/Data/Csv.h> //load the csv file
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
...@@ -67,18 +66,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -67,18 +66,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
{ {
} }
/** Load centroids from the file named by m_CentroidFilename and install
 *  them in m_Centroids. Always returns true. */
template<class TInputValue, class TOutputValue>
bool
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::InitializeCentroids()
{
  // Import the file content; ' ' is passed as the separator argument.
  shark::Data<shark::RealVector> loadedCentroids;
  shark::importCSV( loadedCentroids, m_CentroidFilename, ' ' );

  // Store the imported data as the model centroids and echo them on stdout.
  m_Centroids.setCentroids( loadedCentroids );
  std::cout << m_Centroids.centroids() << std::endl;

  // No failure is reported through the return value.
  return true;
}
/** Train the machine learning model */ /** Train the machine learning model */
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
void void
...@@ -90,12 +77,12 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -90,12 +77,12 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
if (!m_CentroidFilename.empty())
InitializeCentroids();
// Normalized input value if necessary // Normalized input value if necessary
if( m_Normalized ) if( m_Normalized )
data = NormalizeData( data ); {
auto normalizer = TrainNormalizer(data);
data = normalizer(data);
}
// Use a Hard Clustering Model for classification // Use a Hard Clustering Model for classification
shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
...@@ -114,6 +101,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -114,6 +101,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
return normalizer( data ); return normalizer( data );
} }
/** Fit a normalizer on the given data and return it without applying it,
 *  so the caller decides what to normalize with the result. */
template<class TInputValue, class TOutputValue>
template<typename DataType>
shark::Normalizer<>
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::TrainNormalizer(const DataType &data) const
{
  // Trainer constructed with `true`, i.e. the zero-mean option.
  shark::NormalizeComponentsUnitVariance<> unitVarianceTrainer( true );

  // Model that receives the trained normalization parameters.
  shark::Normalizer<> fittedNormalizer;
  unitVarianceTrainer.train( fittedNormalizer, data );

  return fittedNormalizer;
}
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::TargetSampleType ::TargetSampleType
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment