Commit 1045ff23 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

Removed VectorClassifier app, now in official OTB release.

parent 8d442202
No related merge requests found
Showing with 2 additions and 288 deletions
+2 -288
OTB_CREATE_APPLICATION(NAME VectorClassifier
SOURCES otbVectorClassifier.cxx
LINK_LIBRARIES ${OTBAppClassification_LIBRARIES}
)
OTB_CREATE_APPLICATION(NAME ComputeVectorFeaturesStatistics
SOURCES otbComputeVectorFeaturesStatistics.cxx
LINK_LIBRARIES ${OTBAppClassification_LIBRARIES}
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbLearningApplicationBase.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbOGRFeatureWrapper.h"
#include "otbStatisticsXMLFileWriter.h"
#include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleSampleListFilter.h"
#include <time.h>
namespace otb
{
namespace Wrapper
{
/** Utility function to negate std::isalnum */
bool IsNotAlphaNum(char c)
{
return !std::isalnum(c);
}
class VectorClassifier : public LearningApplicationBase<float,int>
{
public:
typedef VectorClassifier Self;
typedef LearningApplicationBase<float,int> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
itkNewMacro(Self);
itkTypeMacro(VectorClassifier, Superclass);
typedef Superclass::SampleType SampleType;
typedef Superclass::ListSampleType ListSampleType;
typedef Superclass::TargetListSampleType TargetListSampleType;
typedef double ValueType;
typedef itk::VariableLengthVector<ValueType> MeasurementType;
typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
private:
void DoInit() ITK_OVERRIDE
{
SetName("VectorClassifier");
SetDescription("Classify a vector layer based on a machine learning model and a list of features to consider.");
SetDocName("VectorClassifier");
SetDocLongDescription("This application will apply a trained machine learning model on the selected feature to get a classification of each geometry contained in an OGR layer. The list of feature must match the list used for training. The predicted label is written in the user defined field for each geometry.");
SetDocLimitations("Experimental. Only shapefiles are supported for now.");
SetDocAuthors("Raffaele Gaetano (CIRAD) based on TrainVectorClassifier.");
SetDocSeeAlso("ComputeOGRLayersFeaturesStatistics,TrainVectorClassifier");
AddDocTag(Tags::Segmentation);
//Group IO
AddParameter(ParameterType_Group, "io", "Input and output data.");
SetParameterDescription("io", "This group of parameters allows setting input and output data.");
AddParameter(ParameterType_InputVectorData, "io.vd","Name of the input vector file.");
SetParameterDescription("io.vd","Name of the input vector file to be updated");
AddParameter(ParameterType_InputFilename, "io.stats", "XML file containing mean and variance of each feature.");
MandatoryOff("io.stats");
SetParameterDescription("io.stats", "XML file containing mean and variance of each feature");
AddParameter(ParameterType_InputFilename, "io.model", "Input model filename.");
SetParameterDescription("io.model", "Input model filename");
AddParameter(ParameterType_ListView, "feat", "Features");
SetParameterDescription("feat","Features to be classified");
AddParameter(ParameterType_String,"cfield","Name of the output field containing the predicted class.");
SetParameterDescription("cfield","Output field containing the predicted class");
SetParameterString("cfield","predicted", false);
AddParameter(ParameterType_Int, "layer", "Layer Index");
SetParameterDescription("layer", "Index of the layer to use in the input vector file.");
MandatoryOff("layer");
SetDefaultParameterInt("layer",0);
// Doc example parameter settings
SetDocExampleParameterValue("io.vd", "vectorData.shp");
SetDocExampleParameterValue("io.stats", "meanVar.xml");
SetDocExampleParameterValue("io.model", "svm.model");
SetDocExampleParameterValue("feat", "perimeter");
SetDocExampleParameterValue("cfield", "predicted");
}
void DoUpdateParameters() ITK_OVERRIDE
{
if ( HasValue("io.vd") )
{
std::string vectorFile = GetParameterString("io.vd");
ogr::DataSource::Pointer ogrDS =
ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
ogr::Feature feature = layer.ogr().GetNextFeature();
ClearChoices("feat");
//ClearChoices("cfield");
for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
{
std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
key = item;
std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum);
std::transform(key.begin(), end, key.begin(), tolower);
OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
if(fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType) || fieldType == OFTReal)
{
std::string tmpKey="feat."+key.substr(0, end - key.begin());
AddChoice(tmpKey,item);
}
}
}
}
void DoExecute() ITK_OVERRIDE
{
clock_t tic = clock();
std::string shapefile = GetParameterString("io.vd").c_str();
std::string XMLfile = GetParameterString("io.stats").c_str();
std::string modelfile = GetParameterString("io.model").c_str();
// Prepare selected field names (their position may change between two inputs)
std::vector<int> selectedIdx = GetSelectedItems("feat");
if(selectedIdx.empty())
{
otbAppLogFATAL(<<"No features have been selected to classify!");
}
const unsigned int nbFeatures = selectedIdx.size();
std::vector<std::string> fieldNames = GetChoiceNames("feat");
std::vector<std::string> selectedNames(nbFeatures);
for (unsigned int i=0 ; i<nbFeatures ; i++)
{
selectedNames[i] = fieldNames[selectedIdx[i]];
}
std::vector<int> featureFieldIndex(nbFeatures, -1);
// Statistics for shift/scale
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
if (HasValue("io.stats") && IsParameterEnabled("io.stats"))
{
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
std::string XMLfile = GetParameterString("io.stats");
statisticsReader->SetFileName(XMLfile);
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
}
else
{
meanMeasurementVector.SetSize(nbFeatures);
meanMeasurementVector.Fill(0.);
stddevMeasurementVector.SetSize(nbFeatures);
stddevMeasurementVector.Fill(1.);
}
ListSampleType::Pointer input = ListSampleType::New();
input->SetMeasurementVectorSize(nbFeatures);
std::string vectorFile = GetParameterString("io.vd");
otbAppLogINFO("Reading input vector file...");
ogr::DataSource::Pointer source = ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
ogr::Layer layer = source->GetLayer(this->GetParameterInt("layer"));
ogr::Feature feature = layer.ogr().GetNextFeature();
bool goesOn = feature.addr() != 0;
if (!goesOn)
{
otbAppLogFATAL("The layer "<<GetParameterInt("layer")<<" of "
<<vectorFile<<" is empty!");
}
for (unsigned int i=0 ; i<nbFeatures ; i++)
{
featureFieldIndex[i] = feature.ogr().GetFieldIndex(selectedNames[i].c_str());
if (featureFieldIndex[i] < 0)
otbAppLogFATAL("The field name for feature "<<selectedNames[i]
<<" has not been found in the input vector file "<<vectorFile);
}
while(goesOn)
{
MeasurementType mv;
mv.SetSize(nbFeatures);
for(unsigned int idx=0; idx < nbFeatures; ++idx)
mv[idx] = feature.ogr().GetFieldAsDouble(featureFieldIndex[idx]);
input->PushBack(mv);
feature = layer.ogr().GetNextFeature();
goesOn = feature.addr() != 0;
}
ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New();
shiftScaleFilter->SetInput(input);
shiftScaleFilter->SetShifts(meanMeasurementVector);
shiftScaleFilter->SetScales(stddevMeasurementVector);
shiftScaleFilter->Update();
ListSampleType::Pointer listSample = shiftScaleFilter->GetOutput();
TargetListSampleType::Pointer target = TargetListSampleType::New();
this->Classify(listSample,target,GetParameterString("io.model"));
ogr::DataSource::Pointer source2 = ogr::DataSource::New(shapefile, ogr::DataSource::Modes::Update_LayerUpdate);
ogr::Layer layer2 = source2->GetLayer(0);
OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger);
layer2.CreateField(predictedField, true);
bool goesOn2 = true;
layer2.ogr().ResetReading();
ogr::Feature feature2 = layer2.ogr().GetNextFeature();
unsigned int count=0;
if(feature2.addr())
while(goesOn2)
{
feature2.ogr().SetField(GetParameterString("cfield").c_str(),(int)target->GetMeasurementVector(count)[0]);
layer2.SetFeature(feature2);
feature2 = layer2.ogr().GetNextFeature();
goesOn2 = feature2.addr() != 0;
count++;
}
const OGRErr err = layer2.ogr().CommitTransaction();
if (err != OGRERR_NONE)
{
itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << layer2.ogr().GetName() << ".");
}
source2->SyncToDisk();
clock_t toc = clock();
otbAppLogINFO( "Elapsed: "<< ((double)(toc - tic) / CLOCKS_PER_SEC)<<" seconds.");
}
};
}
}
OTB_APPLICATION_EXPORT(otb::Wrapper::VectorClassifier)
set(DOCUMENTATION "OTB Vector Classification.")
set(DOCUMENTATION "OTB Compute Vector Features Statistics.")
otb_module(OTBAppVectorClassification
otb_module(OTBAppVectorFeaturesStatistics
DEPENDS
OTBAppClassification
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment