Commit 610c65b2 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: add a method to export centroids in a text file

No related merge requests found
Showing with 19 additions and 1 deletion
+19 -1
......@@ -48,7 +48,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
// Input centroids
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "input centroids");
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "User definied input centroids");
SetParameterDescription("classifier.sharkkm.centroids", "Text file containing input centroids.");
MandatoryOff("classifier.sharkkm.centroids");
......@@ -58,6 +58,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
"and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroidstats");
// output centroids
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.outcentroids");
}
template<class TInputValue, class TOutputValue>
......@@ -109,6 +114,9 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->SetMaximumNumberOfIterations( nbMaxIter );
classifier->Train();
classifier->Save( modelPath );
if( HasValue( "classifier.sharkkm.outcentroids"))
classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.outcentroids" ));
}
} //end namespace wrapper
......
......@@ -131,6 +131,8 @@ public:
this->Modified();
}
void ExportCentroids(const std::string & filename);
protected:
/** Constructor */
SharkKMeansMachineLearningModel();
......
......@@ -241,6 +241,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
return true;
}
template<class TInputValue, class TOutputValue>
void
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::ExportCentroids(const std::string & filename)
{
shark::exportCSV(m_Centroids.centroids(), filename, ' ');
}
template<class TInputValue, class TOutputValue>
void
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment