Commit 610c65b2 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: add a method to export centroids in a text file

No related merge requests found
Showing with 19 additions and 1 deletion
+19 -1
...@@ -48,7 +48,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() ...@@ -48,7 +48,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetMinimumParameterIntValue("classifier.sharkkm.k", 2); SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
// Input centroids // Input centroids
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "input centroids"); AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "User definied input centroids");
SetParameterDescription("classifier.sharkkm.centroids", "Text file containing input centroids."); SetParameterDescription("classifier.sharkkm.centroids", "Text file containing input centroids.");
MandatoryOff("classifier.sharkkm.centroids"); MandatoryOff("classifier.sharkkm.centroids");
...@@ -58,6 +58,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() ...@@ -58,6 +58,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
"and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application."); "and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroidstats"); MandatoryOff("classifier.sharkkm.centroidstats");
// output centroids
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.outcentroids");
} }
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
...@@ -109,6 +114,9 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( ...@@ -109,6 +114,9 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->SetMaximumNumberOfIterations( nbMaxIter ); classifier->SetMaximumNumberOfIterations( nbMaxIter );
classifier->Train(); classifier->Train();
classifier->Save( modelPath ); classifier->Save( modelPath );
if( HasValue( "classifier.sharkkm.outcentroids"))
classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.outcentroids" ));
} }
} //end namespace wrapper } //end namespace wrapper
......
...@@ -131,6 +131,8 @@ public: ...@@ -131,6 +131,8 @@ public:
this->Modified(); this->Modified();
} }
void ExportCentroids(const std::string & filename);
protected: protected:
/** Constructor */ /** Constructor */
SharkKMeansMachineLearningModel(); SharkKMeansMachineLearningModel();
......
...@@ -241,6 +241,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ...@@ -241,6 +241,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
return true; return true;
} }
template<class TInputValue, class TOutputValue>
void
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::ExportCentroids(const std::string & filename)
{
shark::exportCSV(m_Centroids.centroids(), filename, ' ');
}
template<class TInputValue, class TOutputValue> template<class TInputValue, class TOutputValue>
void void
SharkKMeansMachineLearningModel<TInputValue, TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment