Commit 63376264 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

DOC: rename parameters in TrainVectorClassifier

No related merge requests found
Showing with 24 additions and 19 deletions
+24 -19
......@@ -103,7 +103,7 @@ protected:
{
ShareParameter("ram", "polystats.ram");
ShareParameter("sampler", "select.sampler");
ShareParameter("centroids.out", "training.classifier.sharkkm.outcentroids");
ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out");
ShareParameter("vm", "polystats.mask", "Validity Mask",
"Validity mask, only non-zero pixels will be used to estimate KMeans modes.");
}
......@@ -255,10 +255,10 @@ protected:
GetParameterInt("nc"));
if(IsParameterEnabled("centroids.in") && HasValue("centroids.in"))
{
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids",
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in",
GetParameterString("centroids.in"));
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroidstats",
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.stats",
GetInternalApplication("imgstats")->GetParameterString("out"));
}
......
......@@ -47,21 +47,26 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
AddParameter( ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters" );
SetParameterDescription( "classifier.sharkkm.centroids", "Group of parameters for centroids IO." );
// Input centroids
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "User definied input centroids");
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User definied input centroids");
SetParameterDescription("classifier.sharkkm.centroids", "Text file containing input centroids.");
MandatoryOff("classifier.sharkkm.centroids");
// Centroid statistics
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroidstats", "Statistics file");
SetParameterDescription("classifier.sharkkm.centroidstats", "A XML file containing mean and standard deviation to center"
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file");
SetParameterDescription("classifier.sharkkm.centroids.stats", "A XML file containing mean and standard deviation to center"
"and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroidstats");
MandatoryOff("classifier.sharkkm.centroids.stats");
// output centroids
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.outcentroids");
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.centroids.out");
}
......@@ -81,30 +86,30 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->SetK( k );
// Initialize centroids from file
if(IsParameterEnabled("classifier.sharkkm.centroids") && HasValue("classifier.sharkkm.centroids"))
if(IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in"))
{
shark::Data<shark::RealVector> centroidData;
shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids"), ' ');
if( HasValue( "classifier.sharkkm.centroidstats" ) )
shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids.in"), ' ');
if( HasValue( "classifier.sharkkm.centroids.stats" ) )
{
auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New();
statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroidstats" ));
statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroids.stats" ));
auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
// Convert itk Variable Length Vector to shark Real Vector
shark::RealVector offsetRV(meanMeasurementVector.Size());
shark::RealVector stddevMeasurementRV(stddevMeasurementVector.Size());
shark::RealVector scaleRV(stddevMeasurementVector.Size());
assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size());
for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i)
{
stddevMeasurementRV[i] = 1/stddevMeasurementVector[i];
scaleRV[i] = 1/stddevMeasurementVector[i];
// Substract the normalized mean
offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i];
}
shark::Normalizer<> normalizer(stddevMeasurementRV, offsetRV);
shark::Normalizer<> normalizer(scaleRV, offsetRV);
centroidData = normalizer(centroidData);
}
......@@ -115,8 +120,8 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->Train();
classifier->Save( modelPath );
if( HasValue( "classifier.sharkkm.outcentroids"))
classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.outcentroids" ));
if( HasValue( "classifier.sharkkm.centroids.out"))
classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.centroids.out" ));
}
} //end namespace wrapper
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment