diff --git a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx index bcb546e03d11118b8e5c98a29d1cc7dbd824873e..eeb30c249965e3a17697a1a34e5e528199515152 100644 --- a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx +++ b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx @@ -80,6 +80,10 @@ protected: SetParameterDescription("maxit", "Maximum number of iterations for the learning step."); SetDefaultParameterInt("maxit", 1000); MandatoryOff("maxit"); + + AddParameter(ParameterType_String, "incentroid", "Maximum number of iterations"); + SetParameterDescription("incentroid", "Maximum number of iterations for the learning step."); + MandatoryOff("incentroid"); AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename"); SetParameterDescription("outmeans", "Output text file containing centroid positions"); @@ -248,6 +252,8 @@ protected: GetParameterInt("maxit")); GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k", GetParameterInt("nc")); + GetInternalApplication("training")->SetParameterString("classifier.sharkkm.incentroid", + GetParameterString("incentroid")); if( IsParameterEnabled("rand")) GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand")); diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index a3c43741c70ace5c453ae45212abedb313b30b32..d3d8d75ddc4fc9b272ddb547f3d19a0302afc8e6 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -44,6 +44,12 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() SetParameterInt("classifier.sharkkm.k", 2); SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class"); SetMinimumParameterIntValue("classifier.sharkkm.k", 2); + + + // Number of classes + AddParameter(ParameterType_String, "classifier.sharkkm.incentroid", "Number of classes for the kmeans algorithm"); + SetParameterDescription("classifier.sharkkm.incentroid", "The number of classes used for the kmeans algorithm. Default set to 2 class"); + MandatoryOff("classifier.sharkkm.incentroid"); } template<class TInputValue, class TOutputValue> @@ -60,6 +66,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetInputListSample( trainingListSample ); classifier->SetTargetListSample( trainingLabeledListSample ); classifier->SetK( k ); + classifier->SetCentroidFilename( GetParameterString( "classifier.sharkkm.incentroid") ); classifier->SetMaximumNumberOfIterations( nbMaxIter ); classifier->Train(); classifier->Save( modelPath );