/*
 * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef otbSharkKMeansMachineLearningModel_hxx
#define otbSharkKMeansMachineLearningModel_hxx
#include <fstream>
#include "boost/make_shared.hpp"
#include "itkMacro.h"
#include "otbSharkKMeansMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#endif
#include "otb_shark.h"
#include "otbSharkUtils.h"
#include "shark/Algorithms/KMeans.h" //k-means algorithm
#include "shark/Models/Clustering/HardClusteringModel.h"
#include "shark/Models/Clustering/SoftClusteringModel.h"
#include <shark/Data/Csv.h> //load the csv file
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
namespace otb
/** Constructor.
 *
 * Sets the k-means defaults (m_K = 2 clusters, at most 10 iterations),
 * flags the model as able to report a confidence value, and creates the
 * default clustering model (a HardClusteringModel, per the original comment)
 * wrapping a pointer to the m_Centroids member.
 */
template<class TInputValue, class TOutputValue>
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::SharkKMeansMachineLearningModel() :
        m_K(2), m_MaximumNumberOfIterations( 10 )
{
  // Confidence output is supported. NOTE(review): m_ConfidenceIndex is
  // presumably read by the base class to decide whether confidence can be
  // requested — confirm in the base-class header.
  this->m_ConfidenceIndex = true;

  // Default set HardClusteringModel; it keeps a pointer to m_Centroids.
  m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids );
}
/** Destructor.
 *
 * Nothing to release explicitly: m_ClusteringModel is a boost shared
 * pointer (created with boost::make_shared) and frees itself.
 */
template<class TInputValue, class TOutputValue>
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::~SharkKMeansMachineLearningModel()
{
}
/** Train the machine learning model */
template<class TInputValue, class TOutputValue>
void
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Train() { // Parse input data and convert to Shark Data std::vector<shark::RealVector> vector_data; otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); // Use a Hard Clustering Model for classification shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids ); } template<class TInputValue, class TOutputValue> typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::TargetSampleType SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::DoPredict(const InputSampleType &value, ConfidenceValueType *quality, ProbaSampleType *proba) const { shark::RealVector data( value.Size()); for( size_t i = 0; i < value.Size(); i++ ) { data.push_back( value[i] ); } // Change quality measurement only if SoftClustering or other clustering method is used. 
if( quality != nullptr ) { //unsigned int probas = (*m_ClusteringModel)( data ); ( *quality ) = ConfidenceValueType( 1.); } if (proba != nullptr) { if (!this->m_ProbaIndex) { itkExceptionMacro("Probability per class not available for this classifier !"); } } TargetSampleType target; ClusteringOutputType predictedValue = (*m_ClusteringModel)( data ); target[0] = static_cast<TOutputValue>(predictedValue); return target; } template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::DoPredictBatch(const InputListSampleType *input, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *targets, ConfidenceListSampleType *quality, ProbaListSampleType * proba) const { // Perform check on input values assert( input != nullptr ); assert( targets != nullptr ); // input list sample and target list sample should be initialized and without assert( input->Size() == targets->Size() && "Input sample list and target label list do not have the same size." ); assert( ( ( quality == nullptr ) || ( quality->Size() == input->Size() ) ) && "Quality samples list is not null and does not have the same size as input samples list" ); if( startIndex + size > input->Size() ) { itkExceptionMacro( <<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[" ); }