diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
index 66c0df76222e9e928886639566f8cc78faef755d..2127439fb7af284361117428e12826e27149b444 100644
--- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
+++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
@@ -48,6 +48,7 @@
 #include "shark/Models/Clustering/Centroids.h"
 #include "shark/Models/Clustering/ClusteringModel.h"
 #include "shark/Algorithms/KMeans.h"
+#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
 
 #if defined(__GNUC__) || defined(__clang__)
 #pragma GCC diagnostic pop
@@ -132,6 +133,13 @@ public:
   itkGetMacro( CentroidFilename, std::string );
   itkSetMacro( CentroidFilename, std::string );
 
+  /** Initialize the centroids for the kmeans algorithm */
+  void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
+  {
+    m_Centroids.setCentroids(data);
+    this->Modified();
+  }
+  
 protected:
   /** Constructor */
   SharkKMeansMachineLearningModel();
@@ -148,6 +156,9 @@ protected:
 
   template<typename DataType>
   DataType NormalizeData(const DataType &data) const;
+  
+  template<typename DataType>
+  shark::Normalizer<> TrainNormalizer(const DataType &data) const;
 
   /** PrintSelf method */
   void PrintSelf(std::ostream &os, itk::Indent indent) const override;
diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx
index 7392c6c15a106860b5e5fcdcd0e4e3024e72c95c..1d27ce519c63c86216167fa10e3df2cc79625eec 100644
--- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx
+++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx
@@ -39,7 +39,6 @@
 #include "shark/Algorithms/KMeans.h" //k-means algorithm
 #include "shark/Models/Clustering/HardClusteringModel.h"
 #include "shark/Models/Clustering/SoftClusteringModel.h"
-#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
 #include <shark/Data/Csv.h> //load the csv file
 
 #if defined(__GNUC__) || defined(__clang__)
@@ -67,18 +66,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
 {
 }
 
-template<class TInputValue, class TOutputValue>
-bool 
-SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
-::InitializeCentroids()
-{
-  shark::Data<shark::RealVector> data;
-  shark::importCSV(data, m_CentroidFilename, ' ');
-  m_Centroids.setCentroids(data);
-  std::cout <<m_Centroids.centroids() << std::endl;
-  return 1;
-}
-
 /** Train the machine learning model */
 template<class TInputValue, class TOutputValue>
 void
@@ -90,12 +77,12 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
   otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
   shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
 
-  if (!m_CentroidFilename.empty())
-    InitializeCentroids();
-  
   // Normalized input value if necessary
   if( m_Normalized )
-    data = NormalizeData( data );
+  {
+    auto normalizer = TrainNormalizer(data);
+    data = normalizer(data);
+  }
 
   // Use a Hard Clustering Model for classification
   shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
@@ -114,6 +101,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
   return normalizer( data );
 }
 
+template<class TInputValue, class TOutputValue>
+template<typename DataType>
+shark::Normalizer<>
+SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
+::TrainNormalizer(const DataType &data) const
+{
+  shark::Normalizer<> normalizer;
+  shark::NormalizeComponentsUnitVariance<> normalizingTrainer( true );//zero mean
+  normalizingTrainer.train( normalizer, data );
+  return normalizer;
+}
+
 template<class TInputValue, class TOutputValue>
 typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
 ::TargetSampleType