diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index 4b10b50533c36aca20ca169bf8f85f1515c142de..86cd23e5193ec94a2f0abec18324879a5dc241de 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -52,7 +52,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() MandatoryOff("classifier.sharkkm.centroidstats"); // Number of classes - AddParameter(ParameterType_String, "classifier.sharkkm.centroids", "Number of classes for the kmeans algorithm"); + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "Number of classes for the kmeans algorithm"); SetParameterDescription("classifier.sharkkm.centroids", "The number of classes used for the kmeans algorithm. Default set to 2 class"); MandatoryOff("classifier.sharkkm.centroids"); } @@ -73,7 +73,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetK( k ); // Initialize centroids from file - if(HasValue("classifier.sharkkm.centroids")) + if(IsParameterEnabled("classifier.sharkkm.centroids") && HasValue("classifier.sharkkm.centroids")) { shark::Data<shark::RealVector> centroidData; shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids"), ' '); @@ -91,7 +91,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size()); for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i) { - stddevMeasurementRV[i] = stddevMeasurementVector[i]; + stddevMeasurementRV[i] = 1/stddevMeasurementVector[i]; // Substract the normalized mean offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i]; }