diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index f703503935223509a16118c643c0d523b7dbed27..f5293ef4de1c0f896f11971038004a0f331cdb7b 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -48,7 +48,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() SetMinimumParameterIntValue("classifier.sharkkm.k", 2); // Input centroids - AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "input centroids"); + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "User definied input centroids"); SetParameterDescription("classifier.sharkkm.centroids", "Text file containing input centroids."); MandatoryOff("classifier.sharkkm.centroids"); @@ -58,6 +58,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() "and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application."); MandatoryOff("classifier.sharkkm.centroidstats"); + // output centroids + AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file"); + SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm."); + MandatoryOff("classifier.sharkkm.outcentroids"); + } template<class TInputValue, class TOutputValue> @@ -109,6 +114,9 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetMaximumNumberOfIterations( nbMaxIter ); classifier->Train(); classifier->Save( modelPath ); + + if( HasValue( "classifier.sharkkm.outcentroids")) + classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.outcentroids" )); } } //end namespace wrapper diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h index 4c877a5d50e4d703f8018f8819af3459d0141a04..bf0b89391cd5b205ace3cebc3879f80ff0f1cca2 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h @@ -131,6 +131,8 @@ public: this->Modified(); } + void ExportCentroids(const std::string & filename); + protected: /** Constructor */ SharkKMeansMachineLearningModel(); diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx index f4b69b45927efad0f51b0c5d00b75165177a0938..6554c67d318439e7f3d2f1560f9aa1eda5661957 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx @@ -241,6 +241,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> return true; } +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::ExportCentroids(const std::string & filename) +{ + shark::exportCSV(m_Centroids.centroids(), filename, ' '); +} + template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue>