From d1403a135e20bacbfa8e997152e1404c3b6b0cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr> Date: Fri, 12 Apr 2019 16:47:01 +0200 Subject: [PATCH] ENH: added centroid input file to the parameter of the applications --- .../AppClassification/app/otbKMeansClassification.cxx | 6 ++++++ .../AppClassification/include/otbTrainSharkKMeans.hxx | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx index bcb546e03d..eeb30c2499 100644 --- a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx +++ b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx @@ -80,6 +80,10 @@ protected: SetParameterDescription("maxit", "Maximum number of iterations for the learning step."); SetDefaultParameterInt("maxit", 1000); MandatoryOff("maxit"); + + AddParameter(ParameterType_String, "incentroid", "Maximum number of iterations"); + SetParameterDescription("incentroid", "Maximum number of iterations for the learning step."); + MandatoryOff("incentroid"); AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename"); SetParameterDescription("outmeans", "Output text file containing centroid positions"); @@ -248,6 +252,8 @@ protected: GetParameterInt("maxit")); GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k", GetParameterInt("nc")); + GetInternalApplication("training")->SetParameterString("classifier.sharkkm.incentroid", + GetParameterString("incentroid")); if( IsParameterEnabled("rand")) GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand")); diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index a3c43741c7..d3d8d75ddc 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -44,6 +44,12 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() SetParameterInt("classifier.sharkkm.k", 2); SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class"); SetMinimumParameterIntValue("classifier.sharkkm.k", 2); + + + // Number of classes + AddParameter(ParameterType_String, "classifier.sharkkm.incentroid", "Number of classes for the kmeans algorithm"); + SetParameterDescription("classifier.sharkkm.incentroid", "The number of classes used for the kmeans algorithm. Default set to 2 class"); + MandatoryOff("classifier.sharkkm.incentroid"); } template<class TInputValue, class TOutputValue> @@ -60,6 +66,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetInputListSample( trainingListSample ); classifier->SetTargetListSample( trainingLabeledListSample ); classifier->SetK( k ); + classifier->SetCentroidFilename( GetParameterString( "classifier.sharkkm.incentroid") ); classifier->SetMaximumNumberOfIterations( nbMaxIter ); classifier->Train(); classifier->Save( modelPath ); -- GitLab