Commit 7c1fe8e1 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: added the possibility to initialize the centroids from a file

No related merge requests found
Showing with 25 additions and 0 deletions
+25 -0
...@@ -128,6 +128,10 @@ public: ...@@ -128,6 +128,10 @@ public:
itkGetMacro( Normalized, bool ); itkGetMacro( Normalized, bool );
itkSetMacro( Normalized, bool ); itkSetMacro( Normalized, bool );
/** If true, normalized input data sample list */
itkGetMacro( CentroidFilename, std::string );
itkSetMacro( CentroidFilename, std::string );
protected: protected:
/** Constructor */ /** Constructor */
SharkKMeansMachineLearningModel(); SharkKMeansMachineLearningModel();
...@@ -152,6 +156,9 @@ private: ...@@ -152,6 +156,9 @@ private:
SharkKMeansMachineLearningModel(const Self &) = delete; SharkKMeansMachineLearningModel(const Self &) = delete;
void operator=(const Self &) = delete; void operator=(const Self &) = delete;
bool InitializeCentroids();
// Parameters set by the user // Parameters set by the user
bool m_Normalized; bool m_Normalized;
unsigned int m_K; unsigned int m_K;
...@@ -162,6 +169,8 @@ private: ...@@ -162,6 +169,8 @@ private:
/** Centroids results form kMeans */ /** Centroids results form kMeans */
shark::Centroids m_Centroids; shark::Centroids m_Centroids;
/** Input centroid filename */
std::string m_CentroidFilename;
/** shark Model could be SoftClusteringModel or HardClusteringModel */ /** shark Model could be SoftClusteringModel or HardClusteringModel */
boost::shared_ptr<ClusteringModelType> m_ClusteringModel; boost::shared_ptr<ClusteringModelType> m_ClusteringModel;
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include "shark/Models/Clustering/HardClusteringModel.h" #include "shark/Models/Clustering/HardClusteringModel.h"
#include "shark/Models/Clustering/SoftClusteringModel.h" #include "shark/Models/Clustering/SoftClusteringModel.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" #include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#include <shark/Data/Csv.h> //load the csv file
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
...@@ -66,6 +67,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -66,6 +67,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
{ {
} }
template<class TInputValue, class TOutputValue>
bool
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::InitializeCentroids()
{
shark::Data<shark::RealVector> data;
shark::importCSV(data, m_CentroidFilename, ' ');
m_Centroids.setCentroids(data);
std::cout <<m_Centroids.centroids() << std::endl;
return 1;
}
/** Train the machine learning model */ /** Train the machine learning model */
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
void void
...@@ -77,6 +90,9 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -77,6 +90,9 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
if (!m_CentroidFilename.empty())
InitializeCentroids();
// Normalized input value if necessary // Normalized input value if necessary
if( m_Normalized ) if( m_Normalized )
data = NormalizeData( data ); data = NormalizeData( data );
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment