/* * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES) * * This file is part of Orfeo Toolbox * * https://www.orfeo-toolbox.org/ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef otbSharkKMeansMachineLearningModel_hxx #define otbSharkKMeansMachineLearningModel_hxx #include <fstream> #include "boost/make_shared.hpp" #include "itkMacro.h" #include "otbSharkKMeansMachineLearningModel.h" #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wshadow" #pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Woverloaded-virtual" #pragma GCC diagnostic ignored "-Wignored-qualifiers" #endif #include "otb_shark.h" #include "otbSharkUtils.h" #include "shark/Algorithms/KMeans.h" //k-means algorithm #include "shark/Models/Clustering/HardClusteringModel.h" #include "shark/Models/Clustering/SoftClusteringModel.h" #include <shark/Data/Csv.h> //load the csv file #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif namespace otb { template<class TInputValue, class TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::SharkKMeansMachineLearningModel() : m_K(2), m_MaximumNumberOfIterations( 10 ) { // Default set HardClusteringModel this->m_ConfidenceIndex = true; m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids ); } template<class TInputValue, class TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::~SharkKMeansMachineLearningModel() { } /** Train the machine learning model */ template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Train() { // Parse input data and convert to Shark Data std::vector<shark::RealVector> vector_data; otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); // Use a Hard Clustering Model for classification shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids ); } template<class TInputValue, class TOutputValue> typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::TargetSampleType SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::DoPredict(const InputSampleType &value, ConfidenceValueType *quality, ProbaSampleType *proba) const { shark::RealVector data( value.Size()); for( size_t i = 0; i < value.Size(); i++ ) { data.push_back( value[i] ); } // Change quality measurement only if SoftClustering or other clustering method is used. if( quality != nullptr ) { //unsigned int probas = (*m_ClusteringModel)( data ); ( *quality ) = ConfidenceValueType( 1.); } if (proba != nullptr) { if (!this->m_ProbaIndex) { itkExceptionMacro("Probability per class not available for this classifier !"); } } TargetSampleType target; ClusteringOutputType predictedValue = (*m_ClusteringModel)( data ); target[0] = static_cast<TOutputValue>(predictedValue); return target; } template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::DoPredictBatch(const InputListSampleType *input, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *targets, ConfidenceListSampleType *quality, ProbaListSampleType * proba) const { // Perform check on input values assert( input != nullptr ); assert( targets != nullptr ); // input list sample and target list sample should be initialized and without assert( input->Size() == targets->Size() && "Input sample list and target label list do not have the same size." ); assert( ( ( quality == nullptr ) || ( quality->Size() == input->Size() ) ) && "Quality samples list is not null and does not have the same size as input samples list" ); if( startIndex + size > input->Size() ) { itkExceptionMacro( <<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[" ); } // Convert input list of features to shark data format std::vector<shark::RealVector> features; otb::Shark::ListSampleRangeToSharkVector( input, features, startIndex, size ); shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange( features ); shark::Data<ClusteringOutputType> clusters; try { clusters = ( *m_ClusteringModel )( inputSamples ); } catch( ... ) { itkExceptionMacro( "Failed to run clustering classification. " "The number of features of input samples and the model could differ."); } unsigned int id = startIndex; for( const auto &p : clusters.elements() ) { TargetSampleType target; target[0] = static_cast<TOutputValue>(p); targets->SetMeasurementVector( id, target ); ++id; } // Change quality measurement only if SoftClustering or other clustering method is used. if( quality != nullptr ) { for( unsigned int qid = startIndex; qid < startIndex+size; ++qid ) { quality->SetMeasurementVector( qid, static_cast<ConfidenceValueType>(1.) ); } } if (proba !=nullptr && !this->m_ProbaIndex) { itkExceptionMacro("Probability per class not available for this classifier !"); } } template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Save(const std::string &filename, const std::string & itkNotUsed( name )) { std::ofstream ofs( filename); if( !ofs ) { itkExceptionMacro( << "Error opening " << filename.c_str()); } ofs << "#" << m_ClusteringModel->name() << std::endl; shark::TextOutArchive oa( ofs ); m_ClusteringModel->save( oa, 1 ); } template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Load(const std::string &filename, const std::string & itkNotUsed( name )) { m_CanRead = false; std::ifstream ifs( filename); if(ifs.good()) { // Check if first line contains model name std::string line; std::getline(ifs, line); m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos; } if(!m_CanRead) return; shark::TextInArchive ia( ifs ); m_ClusteringModel->load( ia, 0 ); ifs.close(); } template<class TInputValue, class TOutputValue> bool SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::CanReadFile(const std::string &file) { try { m_CanRead = true; this->Load( file ); } catch( ... ) { return false; } return m_CanRead; } template<class TInputValue, class TOutputValue> bool SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::CanWriteFile(const std::string & itkNotUsed( file )) { return true; } template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::ExportCentroids(const std::string & filename) { shark::exportCSV(m_Centroids.centroids(), filename, ' '); } template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::PrintSelf(std::ostream &os, itk::Indent indent) const { // Call superclass implementation Superclass::PrintSelf( os, indent ); } } //end namespace otb #endif