diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index d3d8d75ddc4fc9b272ddb547f3d19a0302afc8e6..4c307e22e2ffa59abe3a0f15d966de8fc3dbd143 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -22,6 +22,7 @@ #include "otbLearningApplicationBase.h" #include "otbSharkKMeansMachineLearningModel.h" +#include "otbStatisticsXMLFileReader.h" namespace otb { @@ -45,9 +46,13 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class"); SetMinimumParameterIntValue("classifier.sharkkm.k", 2); + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroidstats", "Statistics file"); + SetParameterDescription("classifier.sharkkm.centroidstats", "A XML file containing mean and standard deviation to center" + "and reduce the centroids before classification, produced by ComputeImagesStatistics application."); + MandatoryOff("classifier.sharkkm.centroidstats"); // Number of classes - AddParameter(ParameterType_String, "classifier.sharkkm.incentroid", "Number of classes for the kmeans algorithm"); + AddParameter(ParameterType_String, "classifier.sharkkm.centroids", "Number of classes for the kmeans algorithm"); SetParameterDescription("classifier.sharkkm.incentroid", "The number of classes used for the kmeans algorithm. Default set to 2 class"); MandatoryOff("classifier.sharkkm.incentroid"); } @@ -66,7 +71,41 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetInputListSample( trainingListSample ); classifier->SetTargetListSample( trainingLabeledListSample ); classifier->SetK( k ); - classifier->SetCentroidFilename( GetParameterString( "classifier.sharkkm.incentroid") ); + + // Initialize centroids from file + if(HasValue("classifier.sharkkm.centroids")) + { + shark::Data<shark::RealVector> centroidData; + shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroidstats"), ' '); + if( HasValue( "classifier.sharkkm.centroids" ) ) + { + auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New(); + statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroidstats" )); + auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); + auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); + + // Convert itk Variable Length Vector to shark Real Vector + shark::RealVector meanMeasurementRV(meanMeasurementVector.Size()); + for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i) + { + // Substract the mean + meanMeasurementRV[i] = - meanMeasurementVector[i]; + } + + shark::RealVector stddevMeasurementRV(stddevMeasurementVector.Size()); + for (unsigned int i = 0; i<stddevMeasurementVector.Size(); ++i) + { + stddevMeasurementRV[i] = stddevMeasurementVector[i]; + } + + shark::Normalizer<> normalizer(stddevMeasurementRV, meanMeasurementRV); + + centroidData = normalizer(centroidData); + } + + classifier->SetCentroidsFromData( centroidData); + } + classifier->SetMaximumNumberOfIterations( nbMaxIter ); classifier->Train(); classifier->Save( modelPath );