diff --git a/Data/Baseline/OTB-Applications/Images/apTvClKMeansImageClassificationInputCentroids.tif b/Data/Baseline/OTB-Applications/Images/apTvClKMeansImageClassificationInputCentroids.tif new file mode 100644 index 0000000000000000000000000000000000000000..ce0ecbe5e6e8ea678c4ce3a6e7701d725de1658a --- /dev/null +++ b/Data/Baseline/OTB-Applications/Images/apTvClKMeansImageClassificationInputCentroids.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28427a2e951d1b56636568b966922a2d99f68c8e3f477aab99904fa5971fc42d +size 66540 diff --git a/Data/Input/Classification/KMeansInputCentroids.txt b/Data/Input/Classification/KMeansInputCentroids.txt new file mode 100644 index 0000000000000000000000000000000000000000..d7302a51166ea6468ac049ad970582370605c10a --- /dev/null +++ b/Data/Input/Classification/KMeansInputCentroids.txt @@ -0,0 +1,5 @@ +148.1360412249 176.9065574064 79.2367424483 275.6865470422 +180.3646315623 255.4157568188 138.2565634726 656.5357728603 +187.5074713392 256.7055784897 121.8671939978 115.8660938389 +220.0887858502 326.8933399989 229.672560688 434.3589597278 +515.191687488 834.8626368509 642.6102022528 814.8945435557 diff --git a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx index bcb546e03d11118b8e5c98a29d1cc7dbd824873e..d1c7fcd52f0e88e315536ec6afc7343f1ca9d5bd 100644 --- a/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx +++ b/Modules/Applications/AppClassification/app/otbKMeansClassification.cxx @@ -77,13 +77,22 @@ protected: MandatoryOff("ts"); AddParameter(ParameterType_Int, "maxit", "Maximum number of iterations"); - SetParameterDescription("maxit", "Maximum number of iterations for the learning step."); + SetParameterDescription("maxit", + "Maximum number of iterations for the learning step." 
+ " If this parameter is set to 0, the KMeans algorithm will not stop until convergence"); SetDefaultParameterInt("maxit", 1000); MandatoryOff("maxit"); - AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename"); - SetParameterDescription("outmeans", "Output text file containing centroid positions"); - MandatoryOff("outmeans"); + AddParameter(ParameterType_Group, "centroids", "Centroids IO parameters"); + SetParameterDescription("centroids", "Group of parameters for centroids IO."); + + AddParameter(ParameterType_InputFilename, "centroids.in", "input centroids text file"); + SetParameterDescription("centroids.in", + "Input text file containing centroid positions used to initialize the algorithm. " + "Each centroid must be described by p parameters, p being the number of bands in " + "the input image, and the number of centroids must be equal to the number of classes " + "(one centroid per line with values separated by spaces)."); + MandatoryOff("centroids.in"); ShareKMSamplingParameters(); ConnectKMSamplingParams(); @@ -99,6 +108,7 @@ protected: { ShareParameter("ram", "polystats.ram"); ShareParameter("sampler", "select.sampler"); + ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out"); ShareParameter("vm", "polystats.mask", "Validity Mask", "Validity mask, only non-zero pixels will be used to estimate KMeans modes."); } @@ -248,6 +258,14 @@ protected: GetParameterInt("maxit")); GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k", GetParameterInt("nc")); + if (IsParameterEnabled("centroids.in") && HasValue("centroids.in")) + { + GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in", GetParameterString("centroids.in")); + + GetInternalApplication("training") + ->SetParameterString("classifier.sharkkm.centroids.stats", GetInternalApplication("imgstats")->GetParameterString("out")); + } + if( IsParameterEnabled("rand")) 
GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand")); @@ -276,55 +294,6 @@ protected: ExecuteInternal( "classif" ); } - void CreateOutMeansFile(FloatVectorImageType *image, - const std::string &modelFileName, - unsigned int nbClasses) - { - if (IsParameterEnabled("outmeans")) - { - unsigned int nbBands = image->GetNumberOfComponentsPerPixel(); - unsigned int nbElements = nbClasses * nbBands; - // get the line in model file that contains the centroids positions - std::ifstream infile(modelFileName); - if(!infile) - { - itkExceptionMacro(<< "File: " << modelFileName << " couldn't be opened"); - } - - // get the line with the centroids (starts with "2 ") - std::string line, centroidLine; - while(std::getline(infile,line)) - { - if (line.size() > 2 && line[0] == '2' && line[1] == ' ') - { - centroidLine = line; - break; - } - } - - std::vector<std::string> centroidElm; - boost::split(centroidElm,centroidLine,boost::is_any_of(" ")); - - // remove the first elements, not the centroids positions - int nbWord = centroidElm.size(); - int beginCentroid = nbWord-nbElements; - centroidElm.erase(centroidElm.begin(), centroidElm.begin()+beginCentroid); - - // write in the output file - std::ofstream outfile; - outfile.open(GetParameterString("outmeans")); - - for (unsigned int i = 0; i < nbClasses; i++) - { - for (unsigned int j = 0; j < nbBands; j++) - { - outfile << std::setw(8) << centroidElm[i * nbBands + j] << " "; - } - outfile << std::endl; - } - } - } - class KMeansFileNamesHandler { public: @@ -495,9 +464,6 @@ private: // Compute a classification of the input image according to a model file Superclass::KMeansClassif(); - // Create the output text file containing centroids positions - Superclass::CreateOutMeansFile(GetParameterImage("in"), fileNames.modelFile, GetParameterInt("nc")); - // Remove all tempory files if( GetParameterInt( "cleanup" ) ) { diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.hxx 
b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.hxx index 8d76ee548aff28eb61f5febe1b7b248b96ec47da..0e56b0c6d14ea0c6665461d76ed24b5c2170dc83 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.hxx +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.hxx @@ -122,7 +122,10 @@ LearningApplicationBase<TInputValue,TOutputValue> ::InitUnsupervisedClassifierParams() { #ifdef OTB_USE_SHARK - InitSharkKMeansParams(); + if (!m_RegressionFlag) + { + InitSharkKMeansParams(); // Regression not supported + } #endif } diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx index a3c43741c70ace5c453ae45212abedb313b30b32..87ff3a948edc75658b5223780aec8e60788e0afa 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.hxx @@ -22,6 +22,7 @@ #include "otbLearningApplicationBase.h" #include "otbSharkKMeansMachineLearningModel.h" +#include "otbStatisticsXMLFileReader.h" namespace otb { @@ -44,6 +45,30 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() SetParameterInt("classifier.sharkkm.k", 2); SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class"); SetMinimumParameterIntValue("classifier.sharkkm.k", 2); + + // Centroid IO + AddParameter( ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters" ); + SetParameterDescription( "classifier.sharkkm.centroids", "Group of parameters for centroids IO." ); + + // Input centroids + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User defined input centroids"); + SetParameterDescription("classifier.sharkkm.centroids.in", "Input text file containing centroid positions used to initialize the algorithm. 
" + "Each centroid must be described by p parameters, p being the number of features in " + "the input vector data, and the number of centroids must be equal to the number of classes " + "(one centroid per line with values separated by spaces)."); + MandatoryOff("classifier.sharkkm.centroids.in"); // optional input: only read by the training step when enabled and set + + // Centroid statistics + AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file"); + SetParameterDescription("classifier.sharkkm.centroids.stats", "A XML file containing mean and standard deviation to center " + "and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application."); + MandatoryOff("classifier.sharkkm.centroids.stats"); + + // Output centroids + AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file"); + SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm."); + MandatoryOff("classifier.sharkkm.centroids.out"); + } template<class TInputValue, class TOutputValue> @@ -60,9 +85,48 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( classifier->SetInputListSample( trainingListSample ); classifier->SetTargetListSample( trainingLabeledListSample ); classifier->SetK( k ); + + // Initialize centroids from file + if(IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in")) + { + shark::Data<shark::RealVector> centroidData; + shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids.in"), ' '); + if( HasValue( "classifier.sharkkm.centroids.stats" ) ) + { + auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New(); + statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroids.stats" )); + auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); + auto stddevMeasurementVector = 
statisticsReader->GetStatisticVectorByName("stddev"); + + // Convert itk Variable Length Vector to shark Real Vector + shark::RealVector offsetRV(meanMeasurementVector.Size()); + shark::RealVector scaleRV(stddevMeasurementVector.Size()); + + assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size()); + for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i) + { + scaleRV[i] = 1/stddevMeasurementVector[i]; // NOTE(review): no visible guard against a zero stddev here; confirm the statistics file cannot contain one + // Subtract the normalized mean + offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i]; + } + + shark::Normalizer<> normalizer(scaleRV, offsetRV); + centroidData = normalizer(centroidData); + } + + if (centroidData.numberOfElements() != k) + otbAppLogWARNING( "The input centroid file will not be used because it contains " << centroidData.numberOfElements() << + " points, which is different than from the requested number of class: " << k <<"."); + + classifier->SetCentroidsFromData( centroidData); // NOTE(review): still called when the count check above only warns; presumably the k-means step discards a mismatched set, to be confirmed + } + classifier->SetMaximumNumberOfIterations( nbMaxIter ); classifier->Train(); classifier->Save( modelPath ); + + if( HasValue( "classifier.sharkkm.centroids.out")) + classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.centroids.out" )); } } //end namespace wrapper diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 6a05178e64311574862ceaf555fcff1310111815..556cc57987594132598622a258bc9019e23037f7 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -673,7 +673,7 @@ if(OTB_USE_SHARK) -sampler periodic -rand 121212 -nodatalabel 255 - -outmeans ${TEMP}/apTvClKMeansImageClassificationFilterOutMeans.txt + -centroids.out ${TEMP}/apTvClKMeansImageClassificationFilterOutMeans.txt -out ${TEMP}/apTvClKMeansImageClassificationFilterOutput.tif uint8 -cleanup 0 VALID --compare-image ${NOTOL} @@ -681,6 +681,25 @@ if(OTB_USE_SHARK) 
${TEMP}/apTvClKMeansImageClassificationFilterOutput.tif ) endif() +if(OTB_USE_SHARK) + otb_test_application(NAME apTvClKMeansImageClassification_inputCentroids + APP KMeansClassification + OPTIONS -in ${INPUTDATA}/qb_RoadExtract.img + -ts 30000 + -nc 5 + -maxit 10000 + -sampler periodic + -nodatalabel 255 + -rand 121212 + -centroids.in ${INPUTDATA}/Classification/KMeansInputCentroids.txt + -out ${TEMP}/apTvClKMeansImageClassificationInputCentroids.tif uint8 + -cleanup 0 + + VALID --compare-image ${NOTOL} + ${OTBAPP_BASELINE}/apTvClKMeansImageClassificationInputCentroids.tif + ${TEMP}/apTvClKMeansImageClassificationInputCentroids.tif ) +endif() + #----------- TrainImagesClassifier TESTS ---------------- if(OTB_USE_LIBSVM) otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1_allOpt_InXML diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h index 77e8691cf6eb6298639156171113bec261d0e6de..69574043779cf5b6144a34f5b6e446721047ec2e 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h @@ -48,6 +48,7 @@ #include "shark/Models/Clustering/Centroids.h" #include "shark/Models/Clustering/ClusteringModel.h" #include "shark/Algorithms/KMeans.h" +#include "shark/Models/Normalizer.h" #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop @@ -124,9 +125,14 @@ public: /** Set the number of class for the kMeans algorithm.*/ itkSetMacro( K, unsigned ); - /** If true, normalized input data sample list */ - itkGetMacro( Normalized, bool ); - itkSetMacro( Normalized, bool ); + /** Initialize the centroids for the kmeans algorithm */ + void SetCentroidsFromData(const shark::Data<shark::RealVector>& data) + { + m_Centroids.setCentroids(data); + this->Modified(); + } + + void ExportCentroids(const std::string& filename); protected: /** 
Constructor */ @@ -142,9 +148,6 @@ protected: virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override; - template<typename DataType> - DataType NormalizeData(const DataType &data) const; - /** PrintSelf method */ void PrintSelf(std::ostream &os, itk::Indent indent) const override; @@ -153,16 +156,13 @@ private: void operator=(const Self &) = delete; // Parameters set by the user - bool m_Normalized; unsigned int m_K; unsigned int m_MaximumNumberOfIterations; bool m_CanRead; - /** Centroids results form kMeans */ shark::Centroids m_Centroids; - /** shark Model could be SoftClusteringModel or HardClusteringModel */ boost::shared_ptr<ClusteringModelType> m_ClusteringModel; diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx index 3838d38aaa5e56bd4f69531711a7d6e1f45cc880..6554c67d318439e7f3d2f1560f9aa1eda5661957 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx @@ -35,11 +35,10 @@ #include "otb_shark.h" #include "otbSharkUtils.h" -#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" //normalize #include "shark/Algorithms/KMeans.h" //k-means algorithm #include "shark/Models/Clustering/HardClusteringModel.h" #include "shark/Models/Clustering/SoftClusteringModel.h" -#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" +#include <shark/Data/Csv.h> //load the csv file #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop @@ -52,7 +51,7 @@ namespace otb template<class TInputValue, class TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::SharkKMeansMachineLearningModel() : - m_Normalized( false ), 
m_K(2), m_MaximumNumberOfIterations( 10 ) + m_K(2), m_MaximumNumberOfIterations( 10 ) { // Default set HardClusteringModel this->m_ConfidenceIndex = true; @@ -77,27 +76,11 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); - // Normalized input value if necessary - if( m_Normalized ) - data = NormalizeData( data ); - // Use a Hard Clustering Model for classification shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids ); } -template<class TInputValue, class TOutputValue> -template<typename DataType> -DataType -SharkKMeansMachineLearningModel<TInputValue, TOutputValue> -::NormalizeData(const DataType &data) const -{ - shark::Normalizer<> normalizer; - shark::NormalizeComponentsUnitVariance<> normalizingTrainer( true );//zero mean - normalizingTrainer.train( normalizer, data ); - return normalizer( data ); -} - template<class TInputValue, class TOutputValue> typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::TargetSampleType @@ -258,6 +241,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> return true; } +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::ExportCentroids(const std::string & filename) +{ + shark::exportCSV(m_Centroids.centroids(), filename, ' '); +} + template<class TInputValue, class TOutputValue> void SharkKMeansMachineLearningModel<TInputValue, TOutputValue>