// Source snapshot authored by Cédric Traizet (commit a5a1c34a).
/*
* Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "otbWrapperCompositeApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbOGRDataToSamplePositionFilter.h"
namespace otb
{
namespace Wrapper
{
class KMeansApplicationBase : public CompositeApplication
{
public:
/** Standard class aliases. */
using Self         = KMeansApplicationBase;
using Superclass   = CompositeApplication;
using Pointer      = itk::SmartPointer<Self>;
using ConstPointer = itk::SmartPointer<const Self>;

/** Standard type macro (this base class declares no itkNewMacro). */
itkTypeMacro( KMeansApplicationBase, Superclass )
protected:
/**
 * Declares the internal applications of the KMeans pipeline, exposes the
 * main input/output parameters, initializes the sampling and classification
 * parameters, and finally adds the "cleanup" switch (enabled by default).
 */
void InitKMParams()
{
  // Internal applications: (application type, internal key, description).
  AddApplication("ImageEnvelope", "imgenvelop", "Image envelope");
  AddApplication("PolygonClassStatistics", "polystats", "Polygon Class Statistics");
  AddApplication("SampleSelection", "select", "Sample selection");
  AddApplication("SampleExtraction", "extraction", "Sample extraction");
  AddApplication("TrainVectorClassifier", "training", "Model training");
  AddApplication("ComputeImagesStatistics", "imgstats", "Compute Images second order statistics");
  AddApplication("ImageClassifier", "classif", "Performs a classification of the input image");

  // The composite input is the envelope input; its output is the classifier output.
  ShareParameter("in", "imgenvelop.in");
  ShareParameter("out", "classif.out");

  InitKMSampling();
  InitKMClassification();

  // The cleanup parameter is declared last.
  AddParameter( ParameterType_Bool, "cleanup", "Temporary files cleaning" );
  SetParameterDescription( "cleanup",
    "If activated, the application will try to clean all temporary files it created" );
  SetParameterInt("cleanup", 1);
}
/**
 * Declares the KMeans-specific parameters ("nc", "ts", "maxit" and the
 * "centroids" group with its "centroids.in" entry), then shares and connects
 * the sampling-related parameters of the internal applications.
 */
void InitKMSampling()
{
  AddParameter(ParameterType_Int, "nc", "Number of classes");
  SetParameterDescription("nc", "Number of modes, which will be used to generate class membership.");
  SetDefaultParameterInt("nc", 5);

  AddParameter(ParameterType_Int, "ts", "Training set size");
  SetParameterDescription("ts", "Size of the training set (in pixels).");
  SetDefaultParameterInt("ts", 100);
  MandatoryOff("ts");

  AddParameter(ParameterType_Int, "maxit", "Maximum number of iterations");
  SetParameterDescription("maxit",
    "Maximum number of iterations for the learning step."
    " If this parameter is set to 0, the KMeans algorithm will not stop until convergence");
  SetDefaultParameterInt("maxit", 1000);
  MandatoryOff("maxit");

  AddParameter(ParameterType_Group, "centroids", "Centroids IO parameters");
  SetParameterDescription("centroids", "Group of parameters for centroids IO.");

  AddParameter(ParameterType_InputFilename, "centroids.in", "input centroids text file");
  SetParameterDescription("centroids.in",
    "Input text file containing centroid positions used to initialize the algorithm."
    " The file must contain one centroid per line, and each centroid value must be separated by a space. The number of"
    " centroids in this file must match the number of classes (nc parameter).");
  MandatoryOff("centroids.in");

  ShareKMSamplingParameters();
  ConnectKMSamplingParams();
}
void InitKMClassification()
{
ShareKMClassificationParams();
ConnectKMClassificationParams();
}
/**
 * Exposes, at composite level, the sampling-related parameters owned by the
 * internal applications: "ram", "sampler", "centroids.out" and the validity
 * mask "vm".
 */
void ShareKMSamplingParameters()
{
  const char* const validityMaskName = "Validity Mask";
  const char* const validityMaskDescription =
    "Validity mask, only non-zero pixels will be used to estimate KMeans modes.";

  ShareParameter("ram", "polystats.ram");
  ShareParameter("sampler", "select.sampler");
  ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out");
  ShareParameter("vm", "polystats.mask", validityMaskName, validityMaskDescription);
}
/**
 * Exposes the classifier "nodatalabel" parameter at composite level, with a
 * name and description specific to this application.
 */
void ShareKMClassificationParams()
{
  const char* const labelName = "Label mask value";
  const char* const labelDescription =
    "By default, hidden pixels will have the assigned label 0 in the output image. "
    "It's possible to define the label mask by another value, "
    "but be careful to not take a label from another class. "
    "This application initialize the labels from 0 to N-1, "
    "N is the number of class (defined by 'nc' parameter).";

  ShareParameter("nodatalabel", "classif.nodatalabel", labelName, labelDescription);
}
/**
 * Connects the parameters of the sampling chain (polystats, select,
 * extraction) to each other, in the order listed in the table below.
 */
void ConnectKMSamplingParams()
{
  // Each entry is passed as-is to Connect(first, second).
  struct KeyPair
  {
    const char* first;
    const char* second;
  };

  static const KeyPair connections[] = {
    {"polystats.in", "imgenvelop.in"},
    {"select.in", "polystats.in"},
    {"select.vec", "polystats.vec"},
    {"select.ram", "polystats.ram"},
    {"extraction.in", "select.in"},
    {"extraction.field", "select.field"},
    {"extraction.vec", "select.out"},
    {"extraction.ram", "polystats.ram"},
  };

  for (const KeyPair& keys : connections)
  {
    Connect(keys.first, keys.second);
  }
}
/**
 * Connects the parameters of the training and classification applications
 * to the outputs of the earlier steps (extraction, image statistics,
 * envelope input and polygon statistics RAM setting).
 */
void ConnectKMClassificationParams()
{
  // Training step.
  Connect("training.cfield", "extraction.field");
  Connect("training.io.stats","imgstats.out");

  // Classification step.
  Connect("classif.in", "imgenvelop.in");
  Connect("classif.model", "training.io.out");
  Connect("classif.ram", "polystats.ram");
  Connect("classif.imstat", "imgstats.out");
}
void ConnectKMClassificationMask()
{
otbAppLogINFO("Using input mask ...");
Connect("select.mask", "polystats.mask");
Connect("classif.mask", "select.mask");
}
/**
 * Runs the internal "imgenvelop" application and writes its output to
 * \p vectorFileName.
 */
void ComputeImageEnvelope(const std::string &vectorFileName)
{
  auto envelopeApp = GetInternalApplication("imgenvelop");

  envelopeApp->SetParameterString("out", vectorFileName);
  envelopeApp->ExecuteAndWriteOutput();
}
/**
 * Adds an integer field named \p fieldName to the first layer of the vector
 * file \p vectorFileName and sets it to 0 on every feature of that layer.
 *
 * An exception is raised through itkExceptionMacro when committing the layer
 * transaction does not return OGRERR_NONE.
 */
void ComputeAddField(const std::string &vectorFileName,
                     const std::string &fieldName)
{
  otbAppLogINFO("add field in the layer ...");

  // The data source is opened in update mode because the layer is modified in place.
  otb::ogr::DataSource::Pointer ogrDS =
    otb::ogr::DataSource::New(vectorFileName, otb::ogr::DataSource::Modes::Update_LayerUpdate);
  otb::ogr::Layer layer = ogrDS->GetLayer(0);

  // The new field keeps the width and precision it is constructed with.
  OGRFieldDefn classField(fieldName.c_str(), OFTInteger);
  ogr::FieldDefn classFieldDefn(classField);
  layer.CreateField(classFieldDefn);

  // Rewrite every feature with the new field set to 0, keeping its FID.
  otb::ogr::Layer::const_iterator it = layer.cbegin();
  otb::ogr::Layer::const_iterator itEnd = layer.cend();
  for( ; it!=itEnd ; ++it)
  {
    ogr::Feature dstFeature(layer.GetLayerDefn());
    dstFeature.SetFrom( *it, TRUE);
    dstFeature.SetFID(it->GetFID());
    dstFeature[fieldName].SetValue<int>(0);
    layer.SetFeature(dstFeature);
  }

  const OGRErr err = layer.ogr().CommitTransaction();
  if (err != OGRERR_NONE)
  {
    itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << layer.ogr().GetName() << ".");
  }

  ogrDS->SyncToDisk();
}
/**
 * Configures the internal "polystats" application with the single field
 * \p fieldName and the output file \p statisticsFileName, then executes it.
 */
void ComputePolygonStatistics(const std::string &statisticsFileName,
                              const std::string &fieldName)
{
  auto polyStatsApp = GetInternalApplication("polystats");

  const std::vector<std::string> singleField(1, fieldName);
  polyStatsApp->SetParameterStringList("field", singleField);
  polyStatsApp->SetParameterString("out", statisticsFileName);

  ExecuteInternal("polystats");
}
/**
 * Runs the sample selection step, then the sample extraction step.
 *
 * \param statisticsFileName  file passed to the selection "instats" parameter
 * \param fieldName           value of the selection "field" parameter
 * \param sampleFileName      output file of the selection step
 * \param NBSamples           sample count passed to the "constant" strategy
 *
 * When the "rand" parameter is enabled, its value is forwarded to the
 * selection application.
 */
void SelectAndExtractSamples(const std::string &statisticsFileName,
                             const std::string &fieldName,
                             const std::string &sampleFileName,
                             int NBSamples)
{
  /* SampleSelection */
  GetInternalApplication("select")->SetParameterString("out", sampleFileName);
  UpdateInternalParameters("select");
  GetInternalApplication("select")->SetParameterString("instats", statisticsFileName);
  GetInternalApplication("select")->SetParameterString("field", fieldName);
  GetInternalApplication("select")->SetParameterString("strategy", "constant");
  GetInternalApplication("select")->SetParameterInt("strategy.constant.nb", NBSamples);
  if( IsParameterEnabled("rand"))
  {
    GetInternalApplication("select")->SetParameterInt("rand", GetParameterInt("rand"));
  }

  // select sample positions
  ExecuteInternal("select");

  /* SampleExtraction */
  UpdateInternalParameters("extraction");
  // Extracted values are stored in fields named "value_" followed by a suffix.
  GetInternalApplication("extraction")->SetParameterString("outfield", "prefix");
  GetInternalApplication("extraction")->SetParameterString("outfield.prefix.name", "value_");

  // extract sample descriptors
  GetInternalApplication("extraction")->ExecuteAndWriteOutput();
}
/**
 * Configures and executes the internal "training" application with the
 * Shark KMeans classifier.
 *
 * \param image                image whose number of components per pixel gives
 *                             the number of feature fields
 * \param sampleTrainFileName  vector file holding the training samples
 * \param modelFileName        file the trained model is written to
 *
 * Feature field names are the extraction prefix followed by the band index
 * (0-based). When "centroids.in" is enabled and has a value, it is forwarded
 * together with the current "out" value of the "imgstats" application.
 */
void TrainKMModel(FloatVectorImageType *image,
                  const std::string &sampleTrainFileName,
                  const std::string &modelFileName)
{
  std::vector<std::string> extractOutputList = {sampleTrainFileName};
  GetInternalApplication("training")->SetParameterStringList("io.vd", extractOutputList);
  UpdateInternalParameters("training");

  // set field names
  const std::string selectPrefix = GetInternalApplication("extraction")->GetParameterString("outfield.prefix.name");
  const unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
  std::vector<std::string> selectedNames;
  selectedNames.reserve(nbBands);
  for( unsigned int i = 0; i < nbBands; ++i )
  {
    // std::to_string does not depend on the stream locale, so the index is
    // always written as plain decimal digits.
    selectedNames.push_back( selectPrefix + std::to_string(i) );
  }
  GetInternalApplication("training")->SetParameterStringList("feat", selectedNames);

  GetInternalApplication("training")->SetParameterString("classifier", "sharkkm");
  GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.maxiter",
                                                      GetParameterInt("maxit"));
  GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k",
                                                      GetParameterInt("nc"));

  if (IsParameterEnabled("centroids.in") && HasValue("centroids.in"))
  {
    GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in", GetParameterString("centroids.in"));
    GetInternalApplication("training")
      ->SetParameterString("classifier.sharkkm.centroids.stats", GetInternalApplication("imgstats")->GetParameterString("out"));
  }

  if( IsParameterEnabled("rand"))
  {
    GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand"));
  }

  // The "v" parameter of the training application is deactivated.
  GetInternalApplication("training")->GetParameterByKey("v")->SetActive(false);

  GetInternalApplication("training")->SetParameterString("io.out", modelFileName);

  ExecuteInternal( "training" );
  otbAppLogINFO("output model: " << GetInternalApplication("training")->GetParameterString("io.out"));
}
/**
 * Executes the internal "imgstats" application on \p img and writes the
 * statistics to \p imagesStatsFileName; the resulting file name is logged.
 */
void ComputeImageStatistics( ImageBaseType * img,
                             const std::string &imagesStatsFileName)
{
  GetInternalApplication("imgstats")->SetParameterImageBase("il", img);
  GetInternalApplication("imgstats")->SetParameterString("out", imagesStatsFileName);

  ExecuteInternal( "imgstats" );

  otbAppLogINFO("image statistics file: " << GetInternalApplication("imgstats")->GetParameterString("out"));
}
void KMeansClassif()
{
ExecuteInternal( "classif" );
}
/**
 * Holds the names of the temporary files used by the KMeans pipeline and
 * deletes them on request.
 *
 * Every name is the path given at construction followed by a fixed suffix.
 */
class KMeansFileNamesHandler
{
public:
  /** Builds all temporary file names from the common prefix \p outPath. */
  explicit KMeansFileNamesHandler(const std::string &outPath)
    : tmpVectorFile(outPath + "_imgEnvelope.shp")
    , polyStatOutput(outPath + "_polyStats.xml")
    , sampleOutput(outPath + "_sampleSelect.shp")
    , modelFile(outPath + "_model.txt")
    , imgStatOutput(outPath + "_imgstats.xml")
  {
  }

  /** Removes every temporary file that exists; failures are not reported. */
  void clear()
  {
    RemoveFile(tmpVectorFile);
    RemoveFile(polyStatOutput);
    RemoveFile(sampleOutput);
    RemoveFile(modelFile);
    RemoveFile(imgStatOutput);
  }

  std::string tmpVectorFile;
  std::string polyStatOutput;
  std::string sampleOutput;
  std::string modelFile;
  std::string imgStatOutput;

private:
  /**
   * Removes \p filePath if it exists. For a ".shp" path the sibling ".shx",
   * ".dbf" and ".prj" files are removed first.
   *
   * \return true when the file does not exist or was removed; the result of
   *         removing the sibling files is ignored.
   */
  bool RemoveFile(const std::string &filePath)
  {
    if( !itksys::SystemTools::FileExists( filePath ) )
    {
      return true;
    }

    const size_t posExt = filePath.rfind( '.' );
    if( posExt != std::string::npos && filePath.compare( posExt, std::string::npos, ".shp" ) == 0 )
    {
      const std::string basePath = filePath.substr( 0, posExt );
      RemoveFile( basePath + ".shx" );
      RemoveFile( basePath + ".dbf" );
      RemoveFile( basePath + ".prj" );
    }

    const bool removed = itksys::SystemTools::RemoveFile( filePath );
    return removed;
  }
};
};
/** Composite application performing an unsupervised KMeans image classification. */
class KMeansClassification: public KMeansApplicationBase
{
public:
/** Standard class typedefs. */
typedef KMeansClassification Self;
typedef KMeansApplicationBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
// itkTypeMacro turns its first argument into the class-name string, so the
// real class name is spelled out here instead of the `Self` alias.
itkTypeMacro(KMeansClassification, KMeansApplicationBase);
private:
/**
 * Sets the application name and documentation, declares the internal
 * applications and parameters through InitKMParams(), adds the "rand"
 * parameter and registers the documentation example values.
 */
void DoInit() override
{
  SetName("KMeansClassification");
  SetDescription("Unsupervised KMeans image classification");

  SetDocName("Unsupervised KMeans image classification");
  SetDocLongDescription("Unsupervised KMeans image classification. "
    "This is a composite application, using existing training and classification applications. "
    "The SharkKMeans model is used.\n\n"
    "This application is only available if OTB is compiled with Shark support "
    "(CMake option :code:`OTB_USE_SHARK=ON`).\n\n"
    "The steps of this composite application:\n\n"
    "1) ImageEnvelope: create a shapefile (1 polygon),\n"
    "2) PolygonClassStatistics: compute the statistics,\n"
    "3) SampleSelection: select the samples by constant strategy in the shapefile "
    "(1000000 samples max),\n"
    "4) SampleExtraction: extract the samples descriptors (update of SampleSelection output file),\n"
    "5) ComputeImagesStatistics: compute images second order statistics,\n"
    "6) TrainVectorClassifier: train the SharkKMeans model,\n"
    "7) ImageClassifier: perform the classification of the input image "
    "according to a model file.\n\n"
    "It's possible to choose the random/periodic modes of the SampleSelection application.\n"
    "If you want to keep the temporary files (selected samples, model file, ...), "
    "set the cleanup parameter to false.\n"
    "For more information on the Shark KMeans algorithm, see [1].");

  SetDocLimitations("The application doesn't support NaN in the input image");
  SetDocAuthors("OTB-Team");
  SetDocSeeAlso("ImageEnvelope, PolygonClassStatistics, SampleSelection, SampleExtraction, "
    "ComputeImagesStatistics, TrainVectorClassifier, ImageClassifier.\n\n"
    "[1] http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html");

  AddDocTag(Tags::Learning);
  AddDocTag(Tags::Segmentation);

  // Perform initialization
  ClearApplications();

  // Declare the internal applications and synchronize their parameters.
  Superclass::InitKMParams();

  AddRANDParameter();

  // Doc example parameter settings
  SetDocExampleParameterValue("in", "QB_1_ortho.tif");
  SetDocExampleParameterValue("ts", "1000");
  SetDocExampleParameterValue("nc", "5");
  SetDocExampleParameterValue("maxit", "1000");
  SetDocExampleParameterValue("out", "ClassificationFilterOutput.tif uint8");

  SetOfficialDocLink();
}
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
void DoUpdateParameters() override
{
}
void DoExecute() override
{
if (IsParameterEnabled("vm") && HasValue("vm")) Superclass::ConnectKMClassificationMask();
KMeansFileNamesHandler fileNames(GetParameterString("out"));
const std::string fieldName = "field";
// Create an image envelope
Superclass::ComputeImageEnvelope(fileNames.tmpVectorFile);
// Add a new field at the ImageEnvelope output file
Superclass::ComputeAddField(fileNames.tmpVectorFile, fieldName);
// Compute PolygonStatistics app
UpdateKMPolygonClassStatisticsParameters(fileNames.tmpVectorFile);
Superclass::ComputePolygonStatistics(fileNames.polyStatOutput, fieldName);
// Compute number of sample max for KMeans
const int theoricNBSamplesForKMeans = GetParameterInt("ts");
const int upperThresholdNBSamplesForKMeans = 1000 * 1000;
const int actualNBSamplesForKMeans = std::min(theoricNBSamplesForKMeans,
upperThresholdNBSamplesForKMeans);
otbAppLogINFO(<< actualNBSamplesForKMeans << " is the maximum sample size that will be used." \
<< std::endl);
// Compute SampleSelection and SampleExtraction app
Superclass::SelectAndExtractSamples(fileNames.polyStatOutput, fieldName,
fileNames.sampleOutput,
actualNBSamplesForKMeans);
// Compute Images second order statistics
Superclass::ComputeImageStatistics(GetParameterImageBase("in"), fileNames.imgStatOutput);
// Compute a train model with TrainVectorClassifier app
Superclass::TrainKMModel(GetParameterImage("in"), fileNames.sampleOutput,
fileNames.modelFile);
// Compute a classification of the input image according to a model file
Superclass::KMeansClassif();
// Remove all tempory files
if( GetParameterInt( "cleanup" ) )
{
otbAppLogINFO( <<"Final clean-up ..." );
fileNames.clear();
}
}
/**
 * Sets the "vec" parameter of the internal "polystats" application to
 * \p vectorFileName, then refreshes that application's parameters.
 */
void UpdateKMPolygonClassStatisticsParameters(const std::string &vectorFileName)
{
  auto polyStatsApp = GetInternalApplication( "polystats" );
  polyStatsApp->SetParameterString( "vec", vectorFileName);

  UpdateInternalParameters( "polystats" );
}
};
}
}
// Exports KMeansClassification as an OTB application.
// NOTE(review): the macro is presumably provided by otbWrapperApplicationFactory.h — confirm.
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)