Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Commandre Benjamin
OTB
Commits
78d7f56c
Commit
78d7f56c
authored
May 09, 2019
by
Cédric Traizet
Browse files
Merge branch 'regression_refactoring' into 'develop'
Regression refactoring : TrainVectorRegression See merge request orfeotoolbox/otb!488
parents
b340ad5d
5e79e868
Changes
9
Hide whitespace changes
Inline
Side-by-side
Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.1.txt
0 → 100644
View file @
78d7f56c
io.mse: 0.003289417131
Data/Baseline/OTB-Applications/Files/apTvClTrainVectorRegressionModel.txt
0 → 100644
View file @
78d7f56c
io.mse: 0.001359587419
Modules/Applications/AppClassification/app/CMakeLists.txt
View file @
78d7f56c
...
...
@@ -65,6 +65,11 @@ otb_create_application(
SOURCES otbTrainRegression.cxx
LINK_LIBRARIES
${${
otb-module
}
_LIBRARIES
}
)
otb_create_application
(
NAME TrainVectorRegression
SOURCES otbTrainVectorRegression.cxx
LINK_LIBRARIES
${${
otb-module
}
_LIBRARIES
}
)
otb_create_application
(
NAME PredictRegression
SOURCES otbPredictRegression.cxx
...
...
Modules/Applications/AppClassification/app/otbTrainRegression.cxx
View file @
78d7f56c
...
...
@@ -271,8 +271,7 @@ void ParseCSVPredictors(std::string path, ListSampleType* outputList)
elem
.
Fill
(
0.0
);
for
(
unsigned
int
i
=
0
;
i
<
nbCols
;
++
i
)
{
iss
.
str
(
words
[
i
]);
iss
>>
elem
[
i
];
elem
[
i
]
=
std
::
stod
(
words
[
i
]);
}
outputList
->
PushBack
(
elem
);
}
...
...
Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx
View file @
78d7f56c
...
...
@@ -29,11 +29,11 @@ namespace otb
namespace
Wrapper
{
class
TrainVectorClassifier
:
public
TrainVectorBase
class
TrainVectorClassifier
:
public
TrainVectorBase
<
float
,
int
>
{
public:
typedef
TrainVectorClassifier
Self
;
typedef
TrainVectorBase
Superclass
;
typedef
TrainVectorBase
<
float
,
int
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
itkNewMacro
(
Self
)
...
...
@@ -66,13 +66,20 @@ protected:
"Learning (2.3.1 and later), and Shark ML The output of this application "
"is a text model file, whose format corresponds to the ML model type "
"chosen. There is no image nor vector data output."
);
SetDocLimitations
(
""
);
SetDocLimitations
(
"
None
"
);
SetDocAuthors
(
"OTB Team"
);
SetDocSeeAlso
(
" "
);
SetOfficialDocLink
();
Superclass
::
DoInit
();
// Add a new parameter to compute confusion matrix / contingency table
this
->
AddParameter
(
ParameterType_OutputFilename
,
"io.confmatout"
,
"Output confusion matrix or contingency table"
);
this
->
SetParameterDescription
(
"io.confmatout"
,
"Output file containing the confusion matrix or contingency table (.csv format)."
"The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output."
);
this
->
MandatoryOff
(
"io.confmatout"
);
}
void
DoUpdateParameters
()
override
...
...
Modules/Applications/AppClassification/app/otbTrainVectorRegression.cxx
0 → 100644
View file @
78d7f56c
/*
* Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "otbTrainVectorBase.h"
namespace
otb
{
namespace
Wrapper
{
class
TrainVectorRegression
:
public
TrainVectorBase
<
float
,
float
>
{
public:
typedef
TrainVectorRegression
Self
;
typedef
TrainVectorBase
<
float
,
float
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
itkNewMacro
(
Self
)
itkTypeMacro
(
Self
,
Superclass
)
typedef
Superclass
::
SampleType
SampleType
;
typedef
Superclass
::
ListSampleType
ListSampleType
;
typedef
Superclass
::
TargetListSampleType
TargetListSampleType
;
protected:
TrainVectorRegression
()
{
this
->
m_RegressionFlag
=
true
;
}
void
DoInit
()
override
{
SetName
(
"TrainVectorRegression"
);
SetDescription
(
"Train a regression algorithm based on geometries with "
"list of features to consider and a predictor."
);
SetDocLongDescription
(
"This application trains a regression algorithm based on "
"a predictor geometries and a list of features to consider for "
"regression.
\n
This application is based on LibSVM, OpenCV Machine "
"Learning (2.3.1 and later), and Shark ML The output of this application "
"is a text model file, whose format corresponds to the ML model type "
"chosen. There is no image or vector data output."
);
SetDocLimitations
(
"None"
);
SetDocAuthors
(
"OTB Team"
);
SetDocSeeAlso
(
"TrainVectorClassifier"
);
SetOfficialDocLink
();
Superclass
::
DoInit
();
AddParameter
(
ParameterType_Float
,
"io.mse"
,
"Mean Square Error"
);
SetParameterDescription
(
"io.mse"
,
"Mean square error computed with the validation predictors"
);
SetParameterRole
(
"io.mse"
,
Role_Output
);
this
->
MandatoryOff
(
"io.mse"
);
}
void
DoUpdateParameters
()
override
{
Superclass
::
DoUpdateParameters
();
}
double
ComputeMSE
(
const
TargetListSampleType
&
list1
,
const
TargetListSampleType
&
list2
)
{
assert
(
list1
.
Size
()
==
list2
.
Size
());
double
mse
=
0.
;
for
(
TargetListSampleType
::
InstanceIdentifier
i
=
0
;
i
<
list1
.
Size
();
++
i
)
{
auto
elem1
=
list1
.
GetMeasurementVector
(
i
);
auto
elem2
=
list2
.
GetMeasurementVector
(
i
);
mse
+=
(
elem1
[
0
]
-
elem2
[
0
])
*
(
elem1
[
0
]
-
elem2
[
0
]);
}
mse
/=
static_cast
<
double
>
(
list1
.
Size
());
return
mse
;
}
void
DoExecute
()
override
{
m_FeaturesInfo
.
SetClassFieldNames
(
GetChoiceNames
(
"cfield"
),
GetSelectedItems
(
"cfield"
));
if
(
m_FeaturesInfo
.
m_SelectedCFieldIdx
.
empty
()
&&
GetClassifierCategory
()
==
Supervised
)
{
otbAppLogFATAL
(
<<
"No field has been selected for data labelling!"
);
}
Superclass
::
DoExecute
();
otbAppLogINFO
(
"Computing training performances"
);
auto
mse
=
ComputeMSE
(
*
m_ClassificationSamplesWithLabel
.
labeledListSample
,
*
m_PredictedList
);
otbAppLogINFO
(
"Mean Square Error = "
<<
mse
);
this
->
SetParameterFloat
(
"io.mse"
,
mse
);
}
private:
};
}
}
OTB_APPLICATION_EXPORT
(
otb
::
Wrapper
::
TrainVectorRegression
)
Modules/Applications/AppClassification/include/otbTrainVectorBase.h
View file @
78d7f56c
...
...
@@ -49,21 +49,22 @@ bool IsNotAlphaNum(char c)
return
!
std
::
isalnum
(
c
);
}
class
TrainVectorBase
:
public
LearningApplicationBase
<
float
,
int
>
template
<
class
TInputValue
,
class
TOutputValue
>
class
TrainVectorBase
:
public
LearningApplicationBase
<
TInputValue
,
TOutputValue
>
{
public:
/** Standard class typedefs. */
typedef
TrainVectorBase
Self
;
typedef
LearningApplicationBase
<
float
,
int
>
Superclass
;
typedef
LearningApplicationBase
<
TInputValue
,
TOutputValue
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
/** Standard macro */
itkTypeMacro
(
Self
,
Superclass
);
typedef
Superclass
::
SampleType
SampleType
;
typedef
Superclass
::
ListSampleType
ListSampleType
;
typedef
Superclass
::
TargetListSampleType
TargetListSampleType
;
typedef
typename
Superclass
::
SampleType
SampleType
;
typedef
typename
Superclass
::
ListSampleType
ListSampleType
;
typedef
typename
Superclass
::
TargetListSampleType
TargetListSampleType
;
typedef
double
ValueType
;
typedef
itk
::
VariableLengthVector
<
ValueType
>
MeasurementType
;
...
...
@@ -86,8 +87,8 @@ protected:
class
SamplesWithLabel
{
public:
ListSampleType
::
Pointer
listSample
;
TargetListSampleType
::
Pointer
labeledListSample
;
typename
ListSampleType
::
Pointer
listSample
;
typename
TargetListSampleType
::
Pointer
labeledListSample
;
SamplesWithLabel
()
{
listSample
=
ListSampleType
::
New
();
...
...
@@ -178,13 +179,18 @@ protected:
SamplesWithLabel
m_TrainingSamplesWithLabel
;
SamplesWithLabel
m_ClassificationSamplesWithLabel
;
TargetListSampleType
::
Pointer
m_PredictedList
;
typename
TargetListSampleType
::
Pointer
m_PredictedList
;
FeaturesInfo
m_FeaturesInfo
;
void
DoInit
()
override
;
void
DoUpdateParameters
()
override
;
void
DoExecute
()
override
;
private:
/**
* Get the field of the input feature corresponding to the input field
*/
inline
TOutputValue
GetFeatureField
(
const
ogr
::
Feature
&
feature
,
int
field
);
};
}
...
...
Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx
View file @
78d7f56c
...
...
@@ -27,100 +27,98 @@ namespace otb
namespace
Wrapper
{
void
TrainVectorBase
::
DoInit
()
template
<
class
TInputValue
,
class
TOutputValue
>
void
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
DoInit
()
{
// Common Parameters for all Learning Application
AddParameter
(
ParameterType_Group
,
"io"
,
"Input and output data"
);
SetParameterDescription
(
"io"
,
this
->
AddParameter
(
ParameterType_Group
,
"io"
,
"Input and output data"
);
this
->
SetParameterDescription
(
"io"
,
"This group of parameters allows setting input and output data."
);
AddParameter
(
ParameterType_InputVectorDataList
,
"io.vd"
,
"Input Vector Data"
);
SetParameterDescription
(
"io.vd"
,
this
->
AddParameter
(
ParameterType_InputVectorDataList
,
"io.vd"
,
"Input Vector Data"
);
this
->
SetParameterDescription
(
"io.vd"
,
"Input geometries used for training (note: all geometries from the layer will be used)"
);
AddParameter
(
ParameterType_InputFilename
,
"io.stats"
,
"Input XML image statistics file"
);
MandatoryOff
(
"io.stats"
);
SetParameterDescription
(
"io.stats"
,
this
->
AddParameter
(
ParameterType_InputFilename
,
"io.stats"
,
"Input XML image statistics file"
);
this
->
MandatoryOff
(
"io.stats"
);
this
->
SetParameterDescription
(
"io.stats"
,
"XML file containing mean and variance of each feature."
);
AddParameter
(
ParameterType_OutputFilename
,
"io.out"
,
"Output model"
);
SetParameterDescription
(
"io.out"
,
this
->
AddParameter
(
ParameterType_OutputFilename
,
"io.out"
,
"Output model"
);
this
->
SetParameterDescription
(
"io.out"
,
"Output file containing the model estimated (.txt format)."
);
AddParameter
(
ParameterType_Int
,
"layer"
,
"Layer Index"
);
SetParameterDescription
(
"layer"
,
this
->
AddParameter
(
ParameterType_Int
,
"layer"
,
"Layer Index"
);
this
->
SetParameterDescription
(
"layer"
,
"Index of the layer to use in the input vector file."
);
MandatoryOff
(
"layer"
);
SetDefaultParameterInt
(
"layer"
,
0
);
this
->
MandatoryOff
(
"layer"
);
this
->
SetDefaultParameterInt
(
"layer"
,
0
);
AddParameter
(
ParameterType_ListView
,
"feat"
,
"Field names for training features"
);
SetParameterDescription
(
"feat"
,
this
->
AddParameter
(
ParameterType_ListView
,
"feat"
,
"Field names for training features"
);
this
->
SetParameterDescription
(
"feat"
,
"List of field names in the input vector data to be used as features for training."
);
// Add validation data used to compute confusion matrix or contingency table
AddParameter
(
ParameterType_Group
,
"valid"
,
"Validation data"
);
SetParameterDescription
(
"valid"
,
this
->
AddParameter
(
ParameterType_Group
,
"valid"
,
"Validation data"
);
this
->
SetParameterDescription
(
"valid"
,
"This group of parameters defines validation data."
);
AddParameter
(
ParameterType_InputVectorDataList
,
"valid.vd"
,
this
->
AddParameter
(
ParameterType_InputVectorDataList
,
"valid.vd"
,
"Validation Vector Data"
);
SetParameterDescription
(
"valid.vd"
,
"Geometries used for validation "
this
->
SetParameterDescription
(
"valid.vd"
,
"Geometries used for validation "
"(must contain the same fields used for training, all geometries from the layer will be used)"
);
MandatoryOff
(
"valid.vd"
);
this
->
MandatoryOff
(
"valid.vd"
);
AddParameter
(
ParameterType_Int
,
"valid.layer"
,
"Layer Index"
);
SetParameterDescription
(
"valid.layer"
,
this
->
AddParameter
(
ParameterType_Int
,
"valid.layer"
,
"Layer Index"
);
this
->
SetParameterDescription
(
"valid.layer"
,
"Index of the layer to use in the validation vector file."
);
MandatoryOff
(
"valid.layer"
);
SetDefaultParameterInt
(
"valid.layer"
,
0
);
this
->
MandatoryOff
(
"valid.layer"
);
this
->
SetDefaultParameterInt
(
"valid.layer"
,
0
);
// Add class field if we used validation
AddParameter
(
ParameterType_ListView
,
"cfield"
,
this
->
AddParameter
(
ParameterType_ListView
,
"cfield"
,
"Field containing the class integer label for supervision"
);
SetParameterDescription
(
"cfield"
,
this
->
SetParameterDescription
(
"cfield"
,
"Field containing the class id for supervision. "
"The values in this field shall be cast into integers. "
"Only geometries with this field available will be taken into account."
);
SetListViewSingleSelectionMode
(
"cfield"
,
true
);
this
->
SetListViewSingleSelectionMode
(
"cfield"
,
true
);
// Add a new parameter to compute confusion matrix / contingency table
AddParameter
(
ParameterType_OutputFilename
,
"io.confmatout"
,
"Output confusion matrix or contingency table"
);
SetParameterDescription
(
"io.confmatout"
,
"Output file containing the confusion matrix or contingency table (.csv format)."
"The contingency table is output when we unsupervised algorithms is used otherwise the confusion matrix is output."
);
MandatoryOff
(
"io.confmatout"
);
AddParameter
(
ParameterType_Bool
,
"v"
,
"Verbose mode"
);
SetParameterDescription
(
"v"
,
"Verbose mode, display the contingency table result."
);
SetParameterInt
(
"v"
,
1
);
this
->
AddParameter
(
ParameterType_Bool
,
"v"
,
"Verbose mode"
);
this
->
SetParameterDescription
(
"v"
,
"Verbose mode, display the contingency table result."
);
this
->
SetParameterInt
(
"v"
,
1
);
// Doc example parameter settings
SetDocExampleParameterValue
(
"io.vd"
,
"vectorData.shp"
);
SetDocExampleParameterValue
(
"io.stats"
,
"meanVar.xml"
);
SetDocExampleParameterValue
(
"io.out"
,
"svmModel.svm"
);
SetDocExampleParameterValue
(
"feat"
,
"perimeter area width"
);
SetDocExampleParameterValue
(
"cfield"
,
"predicted"
);
this
->
SetDocExampleParameterValue
(
"io.vd"
,
"vectorData.shp"
);
this
->
SetDocExampleParameterValue
(
"io.stats"
,
"meanVar.xml"
);
this
->
SetDocExampleParameterValue
(
"io.out"
,
"svmModel.svm"
);
this
->
SetDocExampleParameterValue
(
"feat"
,
"perimeter area width"
);
this
->
SetDocExampleParameterValue
(
"cfield"
,
"predicted"
);
// Add parameters for the classifier choice
Superclass
::
DoInit
();
AddRANDParameter
();
this
->
AddRANDParameter
();
}
void
TrainVectorBase
::
DoUpdateParameters
()
template
<
class
TInputValue
,
class
TOutputValue
>
void
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
DoUpdateParameters
()
{
// if vector data is present and updated then reload fields
if
(
HasValue
(
"io.vd"
)
)
if
(
this
->
HasValue
(
"io.vd"
)
)
{
std
::
vector
<
std
::
string
>
vectorFileList
=
GetParameterStringList
(
"io.vd"
);
std
::
vector
<
std
::
string
>
vectorFileList
=
this
->
GetParameterStringList
(
"io.vd"
);
ogr
::
DataSource
::
Pointer
ogrDS
=
ogr
::
DataSource
::
New
(
vectorFileList
[
0
],
ogr
::
DataSource
::
Modes
::
Read
);
ogr
::
Layer
layer
=
ogrDS
->
GetLayer
(
static_cast
<
size_t
>
(
this
->
GetParameterInt
(
"layer"
)
)
);
ogr
::
Feature
feature
=
layer
.
ogr
().
GetNextFeature
();
ClearChoices
(
"feat"
);
ClearChoices
(
"cfield"
);
this
->
ClearChoices
(
"feat"
);
this
->
ClearChoices
(
"cfield"
);
for
(
int
iField
=
0
;
iField
<
feature
.
ogr
().
GetFieldCount
();
iField
++
)
{
...
...
@@ -134,20 +132,23 @@ void TrainVectorBase::DoUpdateParameters()
if
(
fieldType
==
OFTInteger
||
fieldType
==
OFTInteger64
||
fieldType
==
OFTReal
)
{
std
::
string
tmpKey
=
"feat."
+
key
.
substr
(
0
,
static_cast
<
unsigned
long
>
(
end
-
key
.
begin
()
)
);
AddChoice
(
tmpKey
,
item
);
this
->
AddChoice
(
tmpKey
,
item
);
}
if
(
fieldType
==
OFTString
||
fieldType
==
OFTInteger
||
fieldType
==
OFTInteger64
)
if
(
fieldType
==
OFTString
||
fieldType
==
OFTInteger
||
fieldType
==
OFTInteger64
||
fieldType
==
OFTReal
)
{
std
::
string
tmpKey
=
"cfield."
+
key
.
substr
(
0
,
static_cast
<
unsigned
long
>
(
end
-
key
.
begin
()
)
);
AddChoice
(
tmpKey
,
item
);
this
->
AddChoice
(
tmpKey
,
item
);
}
}
}
}
void
TrainVectorBase
::
DoExecute
()
template
<
class
TInputValue
,
class
TOutputValue
>
void
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
DoExecute
()
{
m_FeaturesInfo
.
SetFieldNames
(
GetChoiceNames
(
"feat"
),
GetSelectedItems
(
"feat"
));
m_FeaturesInfo
.
SetFieldNames
(
this
->
GetChoiceNames
(
"feat"
),
this
->
GetSelectedItems
(
"feat"
));
// Check input parameters
if
(
m_FeaturesInfo
.
m_SelectedIdx
.
empty
()
)
...
...
@@ -158,29 +159,35 @@ void TrainVectorBase::DoExecute()
ShiftScaleParameters
measurement
=
GetStatistics
(
m_FeaturesInfo
.
m_NbFeatures
);
ExtractAllSamples
(
measurement
);
this
->
Train
(
m_TrainingSamplesWithLabel
.
listSample
,
m_TrainingSamplesWithLabel
.
labeledListSample
,
GetParameterString
(
"io.out"
)
);
this
->
Train
(
m_TrainingSamplesWithLabel
.
listSample
,
m_TrainingSamplesWithLabel
.
labeledListSample
,
this
->
GetParameterString
(
"io.out"
)
);
m_PredictedList
=
this
->
Classify
(
m_ClassificationSamplesWithLabel
.
listSample
,
GetParameterString
(
"io.out"
)
);
this
->
Classify
(
m_ClassificationSamplesWithLabel
.
listSample
,
this
->
GetParameterString
(
"io.out"
)
);
}
void
TrainVectorBase
::
ExtractAllSamples
(
const
ShiftScaleParameters
&
measurement
)
template
<
class
TInputValue
,
class
TOutputValue
>
void
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
ExtractAllSamples
(
const
ShiftScaleParameters
&
measurement
)
{
m_TrainingSamplesWithLabel
=
ExtractTrainingSamplesWithLabel
(
measurement
);
m_ClassificationSamplesWithLabel
=
ExtractClassificationSamplesWithLabel
(
measurement
);
}
TrainVectorBase
::
SamplesWithLabel
TrainVectorBase
::
ExtractTrainingSamplesWithLabel
(
const
ShiftScaleParameters
&
measurement
)
template
<
class
TInputValue
,
class
TOutputValue
>
typename
TrainVectorBase
<
TInputValue
,
TOutputValue
>::
SamplesWithLabel
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
ExtractTrainingSamplesWithLabel
(
const
ShiftScaleParameters
&
measurement
)
{
return
ExtractSamplesWithLabel
(
"io.vd"
,
"layer"
,
measurement
);
}
TrainVectorBase
::
SamplesWithLabel
TrainVectorBase
::
ExtractClassificationSamplesWithLabel
(
const
ShiftScaleParameters
&
measurement
)
template
<
class
TInputValue
,
class
TOutputValue
>
typename
TrainVectorBase
<
TInputValue
,
TOutputValue
>::
SamplesWithLabel
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
ExtractClassificationSamplesWithLabel
(
const
ShiftScaleParameters
&
measurement
)
{
if
(
GetClassifierCategory
()
==
Supervised
)
if
(
this
->
GetClassifierCategory
()
==
Superclass
::
Supervised
)
{
SamplesWithLabel
tmpSamplesWithLabel
;
SamplesWithLabel
validationSamplesWithLabel
=
ExtractSamplesWithLabel
(
"valid.vd"
,
"valid.layer"
,
measurement
);
...
...
@@ -206,15 +213,16 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter
}
}
TrainVectorBase
::
ShiftScaleParameters
TrainVectorBase
::
GetStatistics
(
unsigned
int
nbFeatures
)
template
<
class
TInputValue
,
class
TOutputValue
>
typename
TrainVectorBase
<
TInputValue
,
TOutputValue
>::
ShiftScaleParameters
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
GetStatistics
(
unsigned
int
nbFeatures
)
{
ShiftScaleParameters
measurement
=
ShiftScaleParameters
();
if
(
HasValue
(
"io.stats"
)
&&
IsParameterEnabled
(
"io.stats"
)
)
if
(
this
->
HasValue
(
"io.stats"
)
&&
this
->
IsParameterEnabled
(
"io.stats"
)
)
{
StatisticsReader
::
Pointer
statisticsReader
=
StatisticsReader
::
New
();
std
::
string
XMLfile
=
GetParameterString
(
"io.stats"
);
typename
StatisticsReader
::
Pointer
statisticsReader
=
StatisticsReader
::
New
();
std
::
string
XMLfile
=
this
->
GetParameterString
(
"io.stats"
);
statisticsReader
->
SetFileName
(
XMLfile
);
measurement
.
meanMeasurementVector
=
statisticsReader
->
GetStatisticVectorByName
(
"mean"
);
measurement
.
stddevMeasurementVector
=
statisticsReader
->
GetStatisticVectorByName
(
"stddev"
);
...
...
@@ -229,16 +237,34 @@ TrainVectorBase::GetStatistics(unsigned int nbFeatures)
return
measurement
;
}
// Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer
template
<
>
inline
int
TrainVectorBase
<
float
,
int
>
::
GetFeatureField
(
const
ogr
::
Feature
&
feature
,
int
fieldIndex
)
{
return
(
feature
[
fieldIndex
].
GetValue
<
int
>
());
}
template
<
class
TInputValue
,
class
TOutputValue
>
inline
TOutputValue
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
GetFeatureField
(
const
ogr
::
Feature
&
feature
,
int
fieldIndex
)
{
return
(
feature
[
fieldIndex
].
GetValue
<
double
>
());
}
TrainVectorBase
::
SamplesWithLabel
TrainVectorBase
::
ExtractSamplesWithLabel
(
std
::
string
parameterName
,
std
::
string
parameterLayer
,
template
<
class
TInputValue
,
class
TOutputValue
>
typename
TrainVectorBase
<
TInputValue
,
TOutputValue
>::
SamplesWithLabel
TrainVectorBase
<
TInputValue
,
TOutputValue
>
::
ExtractSamplesWithLabel
(
std
::
string
parameterName
,
std
::
string
parameterLayer
,
const
ShiftScaleParameters
&
measurement
)
{
SamplesWithLabel
samplesWithLabel
;
if
(
HasValue
(
parameterName
)
&&
IsParameterEnabled
(
parameterName
)
)
if
(
this
->
HasValue
(
parameterName
)
&&
this
->
IsParameterEnabled
(
parameterName
)
)
{
ListSampleType
::
Pointer
input
=
ListSampleType
::
New
();
TargetListSampleType
::
Pointer
target
=
TargetListSampleType
::
New
();
typename
ListSampleType
::
Pointer
input
=
ListSampleType
::
New
();
typename
TargetListSampleType
::
Pointer
target
=
TargetListSampleType
::
New
();
input
->
SetMeasurementVectorSize
(
m_FeaturesInfo
.
m_NbFeatures
);
std
::
vector
<
std
::
string
>
fileList
=
this
->
GetParameterStringList
(
parameterName
);
...
...
@@ -251,7 +277,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
bool
goesOn
=
feature
.
addr
()
!=
0
;
if
(
!
goesOn
)
{
otbAppLogWARNING
(
"The layer "
<<
GetParameterInt
(
parameterLayer
)
<<
" of "
<<
fileList
[
k
]
otbAppLogWARNING
(
"The layer "
<<
this
->
GetParameterInt
(
parameterLayer
)
<<
" of "
<<
fileList
[
k
]
<<
" is empty, input is skipped."
);
continue
;
}
...
...
@@ -284,14 +310,14 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
MeasurementType
mv
;
mv
.
SetSize
(
m_FeaturesInfo
.
m_NbFeatures
);
for
(
unsigned
int
idx
=
0
;
idx
<
m_FeaturesInfo
.
m_NbFeatures
;
++
idx
)
mv
[
idx
]
=
feature
.
ogr
().
GetFieldAsDouble
(
featureFieldIndex
[
idx
]
);
mv
[
idx
]
=
feature
[
featureFieldIndex
[
idx
]
].
GetValue
<
double
>
(
);
input
->
PushBack
(
mv
);
if
(
cFieldIndex
>=
0
&&
ogr
::
Field
(
feature
,
cFieldIndex
).
HasBeenSet
())
target
->
PushBack
(
f
eature
.
ogr
().
GetFieldAsInteger
(
cFieldIndex
)
);
target
->
PushBack
(
GetF
eature
Field
(
feature
,
cFieldIndex
)
);
else
target
->
PushBack
(
0
);
target
->
PushBack
(
0
.
);
feature
=
layer
.
ogr
().
GetNextFeature
();
goesOn
=
feature
.
addr
()
!=
0
;
...
...
@@ -300,7 +326,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
ShiftScaleFilterType
::
Pointer
shiftScaleFilter
=
ShiftScaleFilterType
::
New
();
typename
ShiftScaleFilterType
::
Pointer
shiftScaleFilter
=
ShiftScaleFilterType
::
New
();
shiftScaleFilter
->
SetInput
(
input
);
shiftScaleFilter
->
SetShifts
(
measurement
.
meanMeasurementVector
);
shiftScaleFilter
->
SetScales
(
measurement
.
stddevMeasurementVector
);
...
...
Modules/Applications/AppClassification/test/CMakeLists.txt
View file @
78d7f56c
...
...
@@ -837,6 +837,22 @@ if(OTB_USE_OPENCV)
${
TEMP
}
/apTvClTrainVectorClassifierModel.rf
)
endif
()
#----------- TrainVectorRegression TESTS ----------------
if
(
OTB_USE_OPENCV
)
otb_test_application
(
NAME apTvClTrainVectorRegression
APP TrainVectorRegression
OPTIONS -io.vd
${
INPUTDATA
}
/Classification/apTvClSampleExtractionOut.sqlite