Commit b18d93a8 authored by Cédric Traizet
Browse files

ENH: remove normalization from Shark Kmeans machine learning model

No related merge requests found
Showing changes with 1 addition and 28 deletions
+1 -28
...@@ -125,10 +125,6 @@ public: ...@@ -125,10 +125,6 @@ public:
/** Set the number of class for the kMeans algorithm.*/ /** Set the number of class for the kMeans algorithm.*/
itkSetMacro( K, unsigned ); itkSetMacro( K, unsigned );
/** If true, normalized input data sample list */
itkGetMacro( Normalized, bool );
itkSetMacro( Normalized, bool );
/** Initialize the centroids for the kmeans algorithm */ /** Initialize the centroids for the kmeans algorithm */
void SetCentroidsFromData(const shark::Data<shark::RealVector> & data) void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
{ {
...@@ -150,9 +146,6 @@ protected: ...@@ -150,9 +146,6 @@ protected:
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size,
TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override; TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override;
template<typename DataType>
DataType NormalizeData(const DataType &data) const;
template<typename DataType> template<typename DataType>
shark::Normalizer<> TrainNormalizer(const DataType &data) const; shark::Normalizer<> TrainNormalizer(const DataType &data) const;
...@@ -164,7 +157,6 @@ private: ...@@ -164,7 +157,6 @@ private:
void operator=(const Self &) = delete; void operator=(const Self &) = delete;
// Parameters set by the user // Parameters set by the user
bool m_Normalized;
unsigned int m_K; unsigned int m_K;
unsigned int m_MaximumNumberOfIterations; unsigned int m_MaximumNumberOfIterations;
bool m_CanRead; bool m_CanRead;
......
...@@ -52,7 +52,7 @@ namespace otb ...@@ -52,7 +52,7 @@ namespace otb
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
SharkKMeansMachineLearningModel<TInputValue, TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::SharkKMeansMachineLearningModel() : ::SharkKMeansMachineLearningModel() :
m_Normalized( false ), m_K(2), m_MaximumNumberOfIterations( 10 ) m_K(2), m_MaximumNumberOfIterations( 10 )
{ {
// Default set HardClusteringModel // Default set HardClusteringModel
this->m_ConfidenceIndex = true; this->m_ConfidenceIndex = true;
...@@ -77,30 +77,11 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -77,30 +77,11 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
// Normalized input value if necessary
if( m_Normalized )
{
auto normalizer = TrainNormalizer(data);
data = normalizer(data);
}
// Use a Hard Clustering Model for classification // Use a Hard Clustering Model for classification
shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids ); m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids );
} }
template<class TInputValue, class TOutputValue>
template<typename DataType>
DataType
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::NormalizeData(const DataType &data) const
{
shark::Normalizer<> normalizer;
shark::NormalizeComponentsUnitVariance<> normalizingTrainer( true );//zero mean
normalizingTrainer.train( normalizer, data );
return normalizer( data );
}
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
template<typename DataType> template<typename DataType>
shark::Normalizer<> shark::Normalizer<>
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment