An error occurred while loading the file. Please try again.
-
Cédric Traizet authored7607af28
/*
* Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSharkKMeansMachineLearningModel_h
#define otbSharkKMeansMachineLearningModel_h
#include "boost/shared_ptr.hpp"
#include "itkLightObject.h"
#include "otbMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wsign-compare"
#pragma GCC diagnostic ignored "-Wcast-align"
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#if defined(__clang__)
#pragma clang diagnostic ignored "-Wheader-guard"
#pragma clang diagnostic ignored "-Wexpansion-to-defined"
#else
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#include "otb_shark.h"
#include "shark/Models/Clustering/HardClusteringModel.h"
#include "shark/Models/Clustering/SoftClusteringModel.h"
#include "shark/Models/Clustering/Centroids.h"
#include "shark/Models/Clustering/ClusteringModel.h"
#include "shark/Algorithms/KMeans.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
/** \class SharkKMeansMachineLearningModel
* \brief Shark version of Random Forests algorithm
*
* This is a specialization of MachineLearningModel class allowing to
* use Shark implementation of the Random Forests algorithm.
*
* It is noteworthy that training step is parallel.
*
* For more information, see
* http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html
*
* \ingroup OTBUnsupervised
*/
namespace otb
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
{
template<class TInputValue, class TTargetValue>
class ITK_EXPORT SharkKMeansMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef SharkKMeansMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
typedef shark::HardClusteringModel<shark::RealVector> ClusteringModelType;
typedef ClusteringModelType::OutputType ClusteringOutputType;
/** Run-time type information (and related methods). */
itkNewMacro( Self );
itkTypeMacro( SharkKMeansMachineLearningModel, MachineLearningModel );
/** Train the machine learning model */
virtual void Train() override;
/** Save the model to file */
virtual void Save(const std::string &filename, const std::string &name = "") override;
/** Load the model from file */
virtual void Load(const std::string &filename, const std::string &name = "") override;
/**\name Classification model file compatibility tests */
//@{
/** Is the input model file readable and compatible with the corresponding classifier ? */
virtual bool CanReadFile(const std::string &) override;
/** Is the input model file writable and compatible with the corresponding classifier ? */
virtual bool CanWriteFile(const std::string &) override;
//@}
/** Get the maximum number of iteration for the kMeans algorithm.*/
itkGetMacro( MaximumNumberOfIterations, unsigned );
/** Set the maximum number of iteration for the kMeans algorithm.*/
itkSetMacro( MaximumNumberOfIterations, unsigned );
/** Get the number of class for the kMeans algorithm.*/
itkGetMacro( K, unsigned );
/** Set the number of class for the kMeans algorithm.*/
itkSetMacro( K, unsigned );
/** If true, normalized input data sample list */
itkGetMacro( Normalized, bool );
itkSetMacro( Normalized, bool );
/** If true, normalized input data sample list */
itkGetMacro( CentroidFilename, std::string );
itkSetMacro( CentroidFilename, std::string );
/** Initialize the centroids for the kmeans algorithm */
void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
{
m_Centroids.setCentroids(data);
this->Modified();
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
}
protected:
/** Constructor */
SharkKMeansMachineLearningModel();
/** Destructor */
virtual ~SharkKMeansMachineLearningModel();
/** Predict values using the model */
virtual TargetSampleType
DoPredict(const InputSampleType &input, ConfidenceValueType *quality = nullptr, ProbaSampleType *proba=nullptr) const override;
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size,
TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override;
template<typename DataType>
DataType NormalizeData(const DataType &data) const;
template<typename DataType>
shark::Normalizer<> TrainNormalizer(const DataType &data) const;
/** PrintSelf method */
void PrintSelf(std::ostream &os, itk::Indent indent) const override;
private:
SharkKMeansMachineLearningModel(const Self &) = delete;
void operator=(const Self &) = delete;
bool InitializeCentroids();
// Parameters set by the user
bool m_Normalized;
unsigned int m_K;
unsigned int m_MaximumNumberOfIterations;
bool m_CanRead;
/** Centroids results form kMeans */
shark::Centroids m_Centroids;
/** Input centroid filename */
std::string m_CentroidFilename;
/** shark Model could be SoftClusteringModel or HardClusteringModel */
boost::shared_ptr<ClusteringModelType> m_ClusteringModel;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSharkKMeansMachineLearningModel.hxx"
#endif
#endif