Commit 26580acd authored by Cédric Traizet's avatar Cédric Traizet
Browse files

REFAC: TrainVectorBase is now template

No related merge requests found
Showing with 109 additions and 84 deletions
+109 -84
...@@ -29,11 +29,11 @@ namespace otb ...@@ -29,11 +29,11 @@ namespace otb
namespace Wrapper namespace Wrapper
{ {
class TrainVectorClassifier : public TrainVectorBase class TrainVectorClassifier : public TrainVectorBase<float, int>
{ {
public: public:
typedef TrainVectorClassifier Self; typedef TrainVectorClassifier Self;
typedef TrainVectorBase Superclass; typedef TrainVectorBase<float, int> Superclass;
typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
itkNewMacro( Self ) itkNewMacro( Self )
......
...@@ -25,11 +25,11 @@ namespace otb ...@@ -25,11 +25,11 @@ namespace otb
namespace Wrapper namespace Wrapper
{ {
class TrainVectorRegression : public TrainVectorBase class TrainVectorRegression : public TrainVectorBase<float, int>
{ {
public: public:
typedef TrainVectorRegression Self; typedef TrainVectorRegression Self;
typedef TrainVectorBase Superclass; typedef TrainVectorBase<float, int> Superclass;
typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
...@@ -83,6 +83,13 @@ protected: ...@@ -83,6 +83,13 @@ protected:
} }
Superclass::DoExecute(); Superclass::DoExecute();
/*
std::cout << m_PredictedList << std::endl;
std::cout << m_ClassificationSamplesWithLabel.labeledListSample << std::endl;
*/
} }
private: private:
......
...@@ -49,21 +49,22 @@ bool IsNotAlphaNum(char c) ...@@ -49,21 +49,22 @@ bool IsNotAlphaNum(char c)
return !std::isalnum( c ); return !std::isalnum( c );
} }
class TrainVectorBase : public LearningApplicationBase<float, int> template <class TInputValue, class TOutputValue>
class TrainVectorBase : public LearningApplicationBase<TInputValue, TOutputValue>
{ {
public: public:
/** Standard class typedefs. */ /** Standard class typedefs. */
typedef TrainVectorBase Self; typedef TrainVectorBase Self;
typedef LearningApplicationBase<float, int> Superclass; typedef LearningApplicationBase<TInputValue, TOutputValue> Superclass;
typedef itk::SmartPointer <Self> Pointer; typedef itk::SmartPointer <Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */ /** Standard macro */
itkTypeMacro(Self, Superclass); itkTypeMacro(Self, Superclass);
typedef Superclass::SampleType SampleType; typedef typename Superclass::SampleType SampleType;
typedef Superclass::ListSampleType ListSampleType; typedef typename Superclass::ListSampleType ListSampleType;
typedef Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef double ValueType; typedef double ValueType;
typedef itk::VariableLengthVector <ValueType> MeasurementType; typedef itk::VariableLengthVector <ValueType> MeasurementType;
...@@ -86,8 +87,8 @@ protected: ...@@ -86,8 +87,8 @@ protected:
class SamplesWithLabel class SamplesWithLabel
{ {
public: public:
ListSampleType::Pointer listSample; typename ListSampleType::Pointer listSample;
TargetListSampleType::Pointer labeledListSample; typename TargetListSampleType::Pointer labeledListSample;
SamplesWithLabel() SamplesWithLabel()
{ {
listSample = ListSampleType::New(); listSample = ListSampleType::New();
...@@ -178,7 +179,7 @@ protected: ...@@ -178,7 +179,7 @@ protected:
SamplesWithLabel m_TrainingSamplesWithLabel; SamplesWithLabel m_TrainingSamplesWithLabel;
SamplesWithLabel m_ClassificationSamplesWithLabel; SamplesWithLabel m_ClassificationSamplesWithLabel;
TargetListSampleType::Pointer m_PredictedList; typename TargetListSampleType::Pointer m_PredictedList;
FeaturesInfo m_FeaturesInfo; FeaturesInfo m_FeaturesInfo;
void DoInit() override; void DoInit() override;
......
...@@ -27,100 +27,106 @@ namespace otb ...@@ -27,100 +27,106 @@ namespace otb
namespace Wrapper namespace Wrapper
{ {
void TrainVectorBase::DoInit() template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoInit()
{ {
// Common Parameters for all Learning Application // Common Parameters for all Learning Application
AddParameter( ParameterType_Group, "io", "Input and output data" ); this->AddParameter( ParameterType_Group, "io", "Input and output data" );
SetParameterDescription( "io", this->SetParameterDescription( "io",
"This group of parameters allows setting input and output data." ); "This group of parameters allows setting input and output data." );
AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" ); this->AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" );
SetParameterDescription( "io.vd", this->SetParameterDescription( "io.vd",
"Input geometries used for training (note: all geometries from the layer will be used)" ); "Input geometries used for training (note: all geometries from the layer will be used)" );
AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" ); this->AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" );
MandatoryOff( "io.stats" ); this->MandatoryOff( "io.stats" );
SetParameterDescription( "io.stats", this->SetParameterDescription( "io.stats",
"XML file containing mean and variance of each feature." ); "XML file containing mean and variance of each feature." );
AddParameter( ParameterType_OutputFilename, "io.out", "Output model" ); this->AddParameter( ParameterType_OutputFilename, "io.out", "Output model" );
SetParameterDescription( "io.out", this->SetParameterDescription( "io.out",
"Output file containing the model estimated (.txt format)." ); "Output file containing the model estimated (.txt format)." );
AddParameter( ParameterType_Int, "layer", "Layer Index" ); this->AddParameter( ParameterType_Int, "layer", "Layer Index" );
SetParameterDescription( "layer", this->SetParameterDescription( "layer",
"Index of the layer to use in the input vector file." ); "Index of the layer to use in the input vector file." );
MandatoryOff( "layer" ); this->MandatoryOff( "layer" );
SetDefaultParameterInt( "layer", 0 ); this->SetDefaultParameterInt( "layer", 0 );
AddParameter(ParameterType_ListView, "feat", "Field names for training features"); this->AddParameter(ParameterType_ListView, "feat", "Field names for training features");
SetParameterDescription("feat", this->SetParameterDescription("feat",
"List of field names in the input vector data to be used as features for training."); "List of field names in the input vector data to be used as features for training.");
// Add validation data used to compute confusion matrix or contingency table // Add validation data used to compute confusion matrix or contingency table
AddParameter( ParameterType_Group, "valid", "Validation data" ); this->AddParameter( ParameterType_Group, "valid", "Validation data" );
SetParameterDescription( "valid", this->SetParameterDescription( "valid",
"This group of parameters defines validation data." ); "This group of parameters defines validation data." );
AddParameter( ParameterType_InputVectorDataList, "valid.vd", this->AddParameter( ParameterType_InputVectorDataList, "valid.vd",
"Validation Vector Data" ); "Validation Vector Data" );
SetParameterDescription( "valid.vd", "Geometries used for validation " this->SetParameterDescription( "valid.vd", "Geometries used for validation "
"(must contain the same fields used for training, all geometries from the layer will be used)" ); "(must contain the same fields used for training, all geometries from the layer will be used)" );
MandatoryOff( "valid.vd" ); this->MandatoryOff( "valid.vd" );
AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); this->AddParameter( ParameterType_Int, "valid.layer", "Layer Index" );
SetParameterDescription( "valid.layer", this->SetParameterDescription( "valid.layer",
"Index of the layer to use in the validation vector file." ); "Index of the layer to use in the validation vector file." );
MandatoryOff( "valid.layer" ); this->MandatoryOff( "valid.layer" );
SetDefaultParameterInt( "valid.layer", 0 ); this->SetDefaultParameterInt( "valid.layer", 0 );
// Add class field if we used validation // Add class field if we used validation
AddParameter( ParameterType_ListView, "cfield", this->AddParameter( ParameterType_ListView, "cfield",
"Field containing the class integer label for supervision" ); "Field containing the class integer label for supervision" );
SetParameterDescription( "cfield", this->SetParameterDescription( "cfield",
"Field containing the class id for supervision. " "Field containing the class id for supervision. "
"The values in this field shall be cast into integers. " "The values in this field shall be cast into integers. "
"Only geometries with this field available will be taken into account." ); "Only geometries with this field available will be taken into account." );
SetListViewSingleSelectionMode( "cfield", true ); this->SetListViewSingleSelectionMode( "cfield", true );
// Add a new parameter to compute confusion matrix / contingency table // Add a new parameter to compute confusion matrix / contingency table
AddParameter( ParameterType_OutputFilename, "io.confmatout", this->AddParameter( ParameterType_OutputFilename, "io.confmatout",
"Output confusion matrix or contingency table" ); "Output confusion matrix or contingency table" );
SetParameterDescription( "io.confmatout", this->SetParameterDescription( "io.confmatout",
"Output file containing the confusion matrix or contingency table (.csv format)." "Output file containing the confusion matrix or contingency table (.csv format)."
"The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output." ); "The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output." );
MandatoryOff( "io.confmatout" ); this->MandatoryOff( "io.confmatout" );
AddParameter(ParameterType_Bool, "v", "Verbose mode"); this->AddParameter(ParameterType_Bool, "v", "Verbose mode");
SetParameterDescription("v", "Verbose mode, display the contingency table result."); this->SetParameterDescription("v", "Verbose mode, display the contingency table result.");
SetParameterInt("v", 1); this->SetParameterInt("v", 1);
// Doc example parameter settings // Doc example parameter settings
SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); this->SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); this->SetDocExampleParameterValue( "io.stats", "meanVar.xml" );
SetDocExampleParameterValue( "io.out", "svmModel.svm" ); this->SetDocExampleParameterValue( "io.out", "svmModel.svm" );
SetDocExampleParameterValue( "feat", "perimeter area width" ); this->SetDocExampleParameterValue( "feat", "perimeter area width" );
SetDocExampleParameterValue( "cfield", "predicted" ); this->SetDocExampleParameterValue( "cfield", "predicted" );
// Add parameters for the classifier choice // Add parameters for the classifier choice
Superclass::DoInit(); Superclass::DoInit();
AddRANDParameter(); this->AddRANDParameter();
} }
void TrainVectorBase::DoUpdateParameters() template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoUpdateParameters()
{ {
// if vector data is present and updated then reload fields // if vector data is present and updated then reload fields
if( HasValue( "io.vd" ) ) if( this->HasValue( "io.vd" ) )
{ {
std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); std::vector<std::string> vectorFileList = this->GetParameterStringList( "io.vd" );
ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read ); ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read );
ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) ); ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) );
ogr::Feature feature = layer.ogr().GetNextFeature(); ogr::Feature feature = layer.ogr().GetNextFeature();
ClearChoices( "feat" ); this->ClearChoices( "feat" );
ClearChoices( "cfield" ); this->ClearChoices( "cfield" );
for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ ) for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ )
{ {
...@@ -134,20 +140,23 @@ void TrainVectorBase::DoUpdateParameters() ...@@ -134,20 +140,23 @@ void TrainVectorBase::DoUpdateParameters()
if( fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal ) if( fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal )
{ {
std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
AddChoice( tmpKey, item ); this->AddChoice( tmpKey, item );
} }
if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 ) if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 )
{ {
std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
AddChoice( tmpKey, item ); this->AddChoice( tmpKey, item );
} }
} }
} }
} }
void TrainVectorBase::DoExecute() template <class TInputValue, class TOutputValue>
void
TrainVectorBase<TInputValue, TOutputValue>
::DoExecute()
{ {
m_FeaturesInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); m_FeaturesInfo.SetFieldNames( this->GetChoiceNames( "feat" ), this->GetSelectedItems( "feat" ));
// Check input parameters // Check input parameters
if( m_FeaturesInfo.m_SelectedIdx.empty() ) if( m_FeaturesInfo.m_SelectedIdx.empty() )
...@@ -158,29 +167,35 @@ void TrainVectorBase::DoExecute() ...@@ -158,29 +167,35 @@ void TrainVectorBase::DoExecute()
ShiftScaleParameters measurement = GetStatistics( m_FeaturesInfo.m_NbFeatures ); ShiftScaleParameters measurement = GetStatistics( m_FeaturesInfo.m_NbFeatures );
ExtractAllSamples( measurement ); ExtractAllSamples( measurement );
this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) ); this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString( "io.out" ) );
m_PredictedList = m_PredictedList =
this->Classify( m_ClassificationSamplesWithLabel.listSample, GetParameterString( "io.out" ) ); this->Classify( m_ClassificationSamplesWithLabel.listSample, this->GetParameterString( "io.out" ) );
} }
template <class TInputValue, class TOutputValue>
void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement) void
TrainVectorBase<TInputValue, TOutputValue>
::ExtractAllSamples(const ShiftScaleParameters &measurement)
{ {
m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement); m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement); m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
} }
TrainVectorBase::SamplesWithLabel template <class TInputValue, class TOutputValue>
TrainVectorBase::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase<TInputValue, TOutputValue>
::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement)
{ {
return ExtractSamplesWithLabel( "io.vd", "layer", measurement); return ExtractSamplesWithLabel( "io.vd", "layer", measurement);
} }
TrainVectorBase::SamplesWithLabel template <class TInputValue, class TOutputValue>
TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase<TInputValue, TOutputValue>
::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement)
{ {
if(GetClassifierCategory() == Supervised) if(this->GetClassifierCategory() == Superclass::Supervised)
{ {
SamplesWithLabel tmpSamplesWithLabel; SamplesWithLabel tmpSamplesWithLabel;
SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement ); SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement );
...@@ -206,15 +221,16 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter ...@@ -206,15 +221,16 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter
} }
} }
template <class TInputValue, class TOutputValue>
TrainVectorBase::ShiftScaleParameters typename TrainVectorBase<TInputValue, TOutputValue>::ShiftScaleParameters
TrainVectorBase::GetStatistics(unsigned int nbFeatures) TrainVectorBase<TInputValue, TOutputValue>
::GetStatistics(unsigned int nbFeatures)
{ {
ShiftScaleParameters measurement = ShiftScaleParameters(); ShiftScaleParameters measurement = ShiftScaleParameters();
if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) ) if( this->HasValue( "io.stats" ) && this->IsParameterEnabled( "io.stats" ) )
{ {
StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
std::string XMLfile = GetParameterString( "io.stats" ); std::string XMLfile = this->GetParameterString( "io.stats" );
statisticsReader->SetFileName( XMLfile ); statisticsReader->SetFileName( XMLfile );
measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" ); measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" );
measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" ); measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" );
...@@ -229,16 +245,17 @@ TrainVectorBase::GetStatistics(unsigned int nbFeatures) ...@@ -229,16 +245,17 @@ TrainVectorBase::GetStatistics(unsigned int nbFeatures)
return measurement; return measurement;
} }
template <class TInputValue, class TOutputValue>
TrainVectorBase::SamplesWithLabel typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, TrainVectorBase<TInputValue, TOutputValue>
::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer,
const ShiftScaleParameters &measurement) const ShiftScaleParameters &measurement)
{ {
SamplesWithLabel samplesWithLabel; SamplesWithLabel samplesWithLabel;
if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) ) if( this->HasValue( parameterName ) && this->IsParameterEnabled( parameterName ) )
{ {
ListSampleType::Pointer input = ListSampleType::New(); typename ListSampleType::Pointer input = ListSampleType::New();
TargetListSampleType::Pointer target = TargetListSampleType::New(); typename TargetListSampleType::Pointer target = TargetListSampleType::New();
input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures ); input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures );
std::vector<std::string> fileList = this->GetParameterStringList( parameterName ); std::vector<std::string> fileList = this->GetParameterStringList( parameterName );
...@@ -251,7 +268,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string ...@@ -251,7 +268,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
bool goesOn = feature.addr() != 0; bool goesOn = feature.addr() != 0;
if( !goesOn ) if( !goesOn )
{ {
otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << fileList[k] otbAppLogWARNING( "The layer " << this->GetParameterInt( parameterLayer ) << " of " << fileList[k]
<< " is empty, input is skipped." ); << " is empty, input is skipped." );
continue; continue;
} }
...@@ -300,7 +317,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string ...@@ -300,7 +317,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New();
shiftScaleFilter->SetInput( input ); shiftScaleFilter->SetInput( input );
shiftScaleFilter->SetShifts( measurement.meanMeasurementVector ); shiftScaleFilter->SetShifts( measurement.meanMeasurementVector );
shiftScaleFilter->SetScales( measurement.stddevMeasurementVector ); shiftScaleFilter->SetScales( measurement.stddevMeasurementVector );
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment