diff --git a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx index 79d3262dca8cf7b81a2664aa63854944c5ae13a0..9284616399bb38d0c01401ef5af36f9f2feac229 100644 --- a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx +++ b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx @@ -103,7 +103,7 @@ protected: { ShareParameter("ram", "polystats.ram"); ShareParameter("sampler", "select.sampler"); - ShareParameter("centroids.out", "training.classifier.sharkkm.outcentroids"); + ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out"); ShareParameter("vm", "polystats.mask", "Validity Mask", "Validity mask, only non-zero pixels will be used to estimate KMeans modes."); } @@ -255,10 +255,10 @@ protected: GetParameterInt("nc")); if(IsParameterEnabled("centroids.in") && HasValue("centroids.in")) { - GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids", + GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in", GetParameterString("centroids.in")); - GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroidstats", + GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.stats", GetInternalApplication("imgstats")->GetParameterString("out")); } diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index f5293ef4de1c0f896f11971038004a0f331cdb7b..44f5faf8fb7b3683604ea2b88f9a0fab6ff76bfb 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -47,21 +47,26 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class"); SetMinimumParameterIntValue("classifier.sharkkm.k", 2); + + AddParameter( ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters" ); + SetParameterDescription( "classifier.sharkkm.centroids", "Group of parameters for centroids IO." ); + + // Input centroids - AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "User definied input centroids"); + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User definied input centroids"); SetParameterDescription("classifier.sharkkm.centroids", "Text file containing input centroids."); MandatoryOff("classifier.sharkkm.centroids"); // Centroid statistics - AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroidstats", "Statistics file"); - SetParameterDescription("classifier.sharkkm.centroidstats", "A XML file containing mean and standard deviation to center" + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file"); + SetParameterDescription("classifier.sharkkm.centroids.stats", "A XML file containing mean and standard deviation to center" "and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application."); - MandatoryOff("classifier.sharkkm.centroidstats"); + MandatoryOff("classifier.sharkkm.centroids.stats"); // output centroids - AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file"); - SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm."); - MandatoryOff("classifier.sharkkm.outcentroids"); + AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file"); + SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm."); + MandatoryOff("classifier.sharkkm.centroids.out"); } @@ -81,30 +86,30 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetK( k ); // Initialize centroids from file - if(IsParameterEnabled("classifier.sharkkm.centroids") && HasValue("classifier.sharkkm.centroids")) + if(IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in")) { shark::Data<shark::RealVector> centroidData; - shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids"), ' '); - if( HasValue( "classifier.sharkkm.centroidstats" ) ) + shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids.in"), ' '); + if( HasValue( "classifier.sharkkm.centroids.stats" ) ) { auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New(); - statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroidstats" )); + statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroids.stats" )); auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); // Convert itk Variable Length Vector to shark Real Vector shark::RealVector offsetRV(meanMeasurementVector.Size()); - shark::RealVector stddevMeasurementRV(stddevMeasurementVector.Size()); + shark::RealVector scaleRV(stddevMeasurementVector.Size()); assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size()); for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i) { - stddevMeasurementRV[i] = 1/stddevMeasurementVector[i]; + scaleRV[i] = 1/stddevMeasurementVector[i]; // Substract the normalized mean offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i]; } - shark::Normalizer<> normalizer(stddevMeasurementRV, offsetRV); + shark::Normalizer<> normalizer(scaleRV, offsetRV); centroidData = normalizer(centroidData); } @@ -115,8 +120,8 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->Train(); classifier->Save( modelPath ); - if( HasValue( "classifier.sharkkm.outcentroids")) - classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.outcentroids" )); + if( HasValue( "classifier.sharkkm.centroids.out")) + classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.centroids.out" )); } } //end namespace wrapper