Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Lozac'h Loic
otbtf
Commits
ddf221b3
Commit
ddf221b3
authored
Aug 30, 2018
by
Cresson Remi
Browse files
REFAC: wip#2
parent
e42a2d60
Changes
10
Hide whitespace changes
Inline
Side-by-side
app/otbTensorflowModelServe.cxx
View file @
ddf221b3
...
...
@@ -248,16 +248,16 @@ public:
m_TFFilter
=
TFModelFilterType
::
New
();
m_TFFilter
->
SetGraph
(
m_SavedModel
.
meta_graph_def
.
graph_def
());
m_TFFilter
->
SetSession
(
m_SavedModel
.
session
.
get
());
m_TFFilter
->
SetOutputTensors
Names
(
GetParameterStringList
(
"output.names"
));
m_TFFilter
->
SetOutputTensors
(
GetParameterStringList
(
"output.names"
));
m_TFFilter
->
SetOutputSpacingScale
(
GetParameterFloat
(
"output.spcscale"
));
otbAppLogINFO
(
"Output spacing ratio: "
<<
m_TFFilter
->
GetOutputSpacingScale
());
// Get user placeholders
TFModelFilterType
::
DictListType
dict
;
TFModelFilterType
::
StringList
expressions
=
GetParameterStringList
(
"model.userplaceholders"
);
TFModelFilterType
::
DictType
dict
;
for
(
auto
&
exp
:
expressions
)
{
TFModelFilterType
::
DictType
entry
=
tf
::
ExpressionToTensor
(
exp
);
TFModelFilterType
::
Dict
Element
Type
entry
=
tf
::
ExpressionToTensor
(
exp
);
dict
.
push_back
(
entry
);
otbAppLogINFO
(
"Using placeholder "
<<
entry
.
first
<<
" with "
<<
tf
::
PrintTensorInfos
(
entry
.
second
));
...
...
@@ -267,7 +267,7 @@ public:
// Input sources
for
(
auto
&
bundle
:
m_Bundles
)
{
m_TFFilter
->
PushBackInputBundle
(
bundle
.
m_Placeholder
,
bundle
.
m_PatchSize
,
bundle
.
m_ImageSource
.
Get
());
m_TFFilter
->
PushBackInput
Tensor
Bundle
(
bundle
.
m_Placeholder
,
bundle
.
m_PatchSize
,
bundle
.
m_ImageSource
.
Get
());
}
// Fully convolutional mode on/off
...
...
@@ -281,9 +281,9 @@ public:
FloatVectorImageType
::
SizeType
foe
;
foe
[
0
]
=
GetParameterInt
(
"output.foex"
);
foe
[
1
]
=
GetParameterInt
(
"output.foey"
);
m_TFFilter
->
SetOutput
FOESize
(
foe
);
m_TFFilter
->
SetOutput
ExpressionFields
({
foe
}
);
otbAppLogINFO
(
"Output field of expression: "
<<
m_TFFilter
->
GetOutput
FOESize
()
);
otbAppLogINFO
(
"Output field of expression: "
<<
m_TFFilter
->
GetOutput
ExpressionFields
()[
0
]
);
// Streaming
if
(
GetParameterInt
(
"finetuning.disabletiling"
)
!=
1
)
...
...
app/otbTensorflowModelTrain.cxx
View file @
ddf221b3
...
...
@@ -53,7 +53,7 @@ public:
itkNewMacro
(
Self
);
itkTypeMacro
(
TensorflowModelTrain
,
Application
);
/** Typedefs for
t
ensor
f
low */
/** Typedefs for
T
ensor
F
low */
typedef
otb
::
TensorflowMultisourceModelTrain
<
FloatVectorImageType
>
TrainModelFilterType
;
typedef
otb
::
TensorflowMultisourceModelValidate
<
FloatVectorImageType
>
ValidateModelFilterType
;
typedef
otb
::
TensorflowSource
<
FloatVectorImageType
>
TFSource
;
...
...
@@ -75,8 +75,8 @@ public:
// Parameters keys
std
::
string
m_KeyInForTrain
;
// Key of input image list (training)
std
::
string
m_KeyInForValid
;
// Key of input image list (validation)
std
::
string
m_KeyPHNameForTrain
;
// Key for placeholder name in the
t
ensor
f
low model (training)
std
::
string
m_KeyPHNameForValid
;
// Key for placeholder name in the
t
ensor
f
low model (validation)
std
::
string
m_KeyPHNameForTrain
;
// Key for placeholder name in the
T
ensor
F
low model (training)
std
::
string
m_KeyPHNameForValid
;
// Key for placeholder name in the
T
ensor
F
low model (validation)
std
::
string
m_KeyPszX
;
// Key for samples sizes X
std
::
string
m_KeyPszY
;
// Key for samples sizes Y
};
...
...
@@ -194,10 +194,10 @@ public:
AddParameter
(
ParameterType_StringList
,
"training.userplaceholders"
,
"Additional single-valued placeholders for training. Supported types: int, float, bool."
);
MandatoryOff
(
"training.userplaceholders"
);
AddParameter
(
ParameterType_StringList
,
"training.targetnodes
names
"
,
"Names of the target nodes"
);
MandatoryOn
(
"training.targetnodes
names
"
);
AddParameter
(
ParameterType_StringList
,
"training.outputtensors
names
"
,
"Names of the output tensors to display"
);
MandatoryOff
(
"training.outputtensors
names
"
);
AddParameter
(
ParameterType_StringList
,
"training.targetnodes"
,
"Names of the target nodes"
);
MandatoryOn
(
"training.targetnodes"
);
AddParameter
(
ParameterType_StringList
,
"training.outputtensors"
,
"Names of the output tensors to display"
);
MandatoryOff
(
"training.outputtensors"
);
// Metrics
AddParameter
(
ParameterType_Group
,
"validation"
,
"Validation parameters"
);
...
...
@@ -228,7 +228,7 @@ public:
SetDocExampleParameterValue
(
"source2.fovy"
,
"1"
);
SetDocExampleParameterValue
(
"model.dir"
,
"/tmp/my_saved_model/"
);
SetDocExampleParameterValue
(
"training.userplaceholders"
,
"is_training=true dropout=0.2"
);
SetDocExampleParameterValue
(
"training.targetnode
names"
,
"optimizer"
);
SetDocExampleParameterValue
(
"training.targetnode
s"
,
"optimizer"
);
SetDocExampleParameterValue
(
"model.saveto"
,
"/tmp/my_saved_model_vars1"
);
}
...
...
@@ -350,13 +350,13 @@ public:
//
// Get user placeholders
//
TrainModelFilterType
::
Dict
List
Type
GetUserPlaceholders
(
const
std
::
string
key
)
TrainModelFilterType
::
DictType
GetUserPlaceholders
(
const
std
::
string
key
)
{
TrainModelFilterType
::
Dict
List
Type
dict
;
TrainModelFilterType
::
DictType
dict
;
TrainModelFilterType
::
StringList
expressions
=
GetParameterStringList
(
key
);
for
(
auto
&
exp
:
expressions
)
{
TrainModelFilterType
::
DictType
entry
=
tf
::
ExpressionToTensor
(
exp
);
TrainModelFilterType
::
Dict
Element
Type
entry
=
tf
::
ExpressionToTensor
(
exp
);
dict
.
push_back
(
entry
);
otbAppLogINFO
(
"Using placeholder "
<<
entry
.
first
<<
" with "
<<
tf
::
PrintTensorInfos
(
entry
.
second
));
...
...
@@ -414,15 +414,15 @@ public:
m_TrainModelFilter
=
TrainModelFilterType
::
New
();
m_TrainModelFilter
->
SetGraph
(
m_SavedModel
.
meta_graph_def
.
graph_def
());
m_TrainModelFilter
->
SetSession
(
m_SavedModel
.
session
.
get
());
m_TrainModelFilter
->
SetOutputTensors
Names
(
GetParameterStringList
(
"training.outputtensors
names
"
));
m_TrainModelFilter
->
SetTargetNodesNames
(
GetParameterStringList
(
"training.targetnodes
names
"
));
m_TrainModelFilter
->
SetOutputTensors
(
GetParameterStringList
(
"training.outputtensors"
));
m_TrainModelFilter
->
SetTargetNodesNames
(
GetParameterStringList
(
"training.targetnodes"
));
m_TrainModelFilter
->
SetBatchSize
(
GetParameterInt
(
"training.batchsize"
));
m_TrainModelFilter
->
SetUserPlaceholders
(
GetUserPlaceholders
(
"training.userplaceholders"
));
// Set inputs
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForTraining
.
size
()
;
i
++
)
{
m_TrainModelFilter
->
PushBackInputBundle
(
m_TrainModelFilter
->
PushBackInput
Tensor
Bundle
(
m_InputPlaceholdersForTraining
[
i
],
m_InputPatchesSizeForTraining
[
i
],
m_InputSourcesForTraining
[
i
]);
...
...
@@ -458,14 +458,14 @@ public:
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstLearningData
.
size
()
;
i
++
)
{
m_ValidateModelFilter
->
PushBackInputBundle
(
m_ValidateModelFilter
->
PushBackInput
Tensor
Bundle
(
m_InputPlaceholdersForValidation
[
i
],
m_InputPatchesSizeForValidation
[
i
],
m_InputSourcesForEvaluationAgainstLearningData
[
i
]);
}
m_ValidateModelFilter
->
SetOutputTensors
Names
(
m_TargetTensorsNames
);
m_ValidateModelFilter
->
SetOutputTensors
(
m_TargetTensorsNames
);
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstLearningData
);
m_ValidateModelFilter
->
SetOutput
FOESize
s
(
m_TargetPatchesSize
);
m_ValidateModelFilter
->
SetOutput
ExpressionField
s
(
m_TargetPatchesSize
);
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Learning data)"
);
...
...
@@ -509,10 +509,13 @@ public:
private:
tensorflow
::
SavedModelBundle
m_SavedModel
;
// must be alive during all the execution of the application !
// Filters
TrainModelFilterType
::
Pointer
m_TrainModelFilter
;
ValidateModelFilterType
::
Pointer
m_ValidateModelFilter
;
tensorflow
::
SavedModelBundle
m_SavedModel
;
// must be alive during all the execution of the application !
// Inputs
BundleList
m_Bundles
;
// Patches size
...
...
include/otbTensorflowMultisourceModelFilter.h
View file @
ddf221b3
...
...
@@ -97,7 +97,6 @@ public:
typedef
typename
Superclass
::
DictType
DictType
;
typedef
typename
Superclass
::
StringList
StringList
;
typedef
typename
Superclass
::
SizeListType
SizeListType
;
typedef
typename
Superclass
::
DictListType
DictListType
;
typedef
typename
Superclass
::
TensorListType
TensorListType
;
typedef
std
::
vector
<
float
>
ScaleListType
;
...
...
@@ -139,6 +138,7 @@ private:
SpacingType
m_OutputSpacing
;
// Output image spacing
PointType
m_OutputOrigin
;
// Output image origin
SizeType
m_OutputSize
;
// Output image size
PixelType
m_NullPixel
;
// Pixel filled with zeros
};
// end class
...
...
include/otbTensorflowMultisourceModelFilter.hxx
View file @
ddf221b3
...
...
@@ -284,6 +284,10 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
outputPtr
->
SetSignedSpacing
(
m_OutputSpacing
);
outputPtr
->
SetLargestPossibleRegion
(
largestPossibleRegion
);
// Set null pixel
m_NullPixel
.
SetSize
(
outputPtr
->
GetNumberOfComponentsPerPixel
());
m_NullPixel
.
Fill
(
0
);
}
template
<
class
TInputImage
,
class
TOutputImage
>
...
...
@@ -395,13 +399,12 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
tf
::
RecopyImageRegionToTensorWithCast
<
TInputImage
>
(
inputPtr
,
reqRegion
,
inputTensor
,
0
);
// Input #1 : the tensor of patches (aka the batch)
DictElementType
input
1
=
{
this
->
GetInputPlaceholders
Names
()[
i
],
inputTensor
};
inputs
.
push_back
(
input
1
);
DictElementType
input
=
{
this
->
GetInputPlaceholders
()[
i
],
inputTensor
};
inputs
.
push_back
(
input
);
}
else
{
// Preparing patches (not very optimized ! )
// It would be better to perform the loop inside the TF session using TF operators
// Preparing patches
// Shape of input tensor #i
tensorflow
::
int64
sz_n
=
outputReqRegion
.
GetNumberOfPixels
();
tensorflow
::
int64
sz_y
=
inputPatchSize
[
1
];
...
...
@@ -429,8 +432,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Input #1 : the tensor of patches (aka the batch)
DictElementType
input1
=
{
this
->
GetInputPlaceholdersNames
()[
i
],
inputTensor
};
inputs
.
push_back
(
input1
);
DictElementType
input
=
{
this
->
GetInputPlaceholders
()[
i
],
inputTensor
};
inputs
.
push_back
(
input
);
}
// mode is not full convolutional
}
// next input tensor
...
...
@@ -442,10 +446,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Fill the output buffer with zero value
outputPtr
->
SetBufferedRegion
(
outputReqRegion
);
outputPtr
->
Allocate
();
OutputPixelType
nullpix
;
nullpix
.
SetSize
(
outputPtr
->
GetNumberOfComponentsPerPixel
());
nullpix
.
Fill
(
0
);
outputPtr
->
FillBuffer
(
nullpix
);
outputPtr
->
FillBuffer
(
m_NullPixel
);
// Get output tensors
int
bandOffset
=
0
;
...
...
@@ -453,7 +454,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
{
// The offset (i.e. the starting index of the channel for the output tensor) is updated
// during this call
// TODO: implement a generic strategy enabling
FOE
copy in patch-based mode (see tf::CopyTensorToImageRegion)
// TODO: implement a generic strategy enabling
expression field
copy in patch-based mode (see tf::CopyTensorToImageRegion)
try
{
tf
::
CopyTensorToImageRegion
<
TOutputImage
>
(
outputs
[
i
],
...
...
@@ -461,7 +462,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
catch
(
itk
::
ExceptionObject
&
err
)
{
std
::
stringstream
debugMsg
=
this
->
GenerateDebugReport
(
inputs
,
outputs
);
std
::
stringstream
debugMsg
=
this
->
GenerateDebugReport
(
inputs
);
itkExceptionMacro
(
"Error occured during tensor to image conversion.
\n
"
<<
"Context: "
<<
debugMsg
.
str
()
<<
"Error:"
<<
err
);
...
...
include/otbTensorflowMultisourceModelLearningBase.h
View file @
ddf221b3
...
...
@@ -18,11 +18,6 @@
// Base
#include "otbTensorflowMultisourceModelBase.h"
// Shuffle
#include <random>
#include <algorithm>
#include <iterator>
namespace
otb
{
...
...
@@ -86,10 +81,10 @@ protected:
virtual
void
GenerateData
();
virtual
void
PopulateInputTensors
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
,
const
IndexListType
&
order
=
IndexListType
()
);
virtual
void
PopulateInputTensors
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
,
const
IndexListType
&
order
);
virtual
void
ProcessBatch
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
virtual
void
ProcessBatch
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
)
=
0
;
private:
...
...
include/otbTensorflowMultisourceModelLearningBase.hxx
View file @
ddf221b3
...
...
@@ -19,7 +19,7 @@ namespace otb
template
<
class
TInputImage
>
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
TensorflowMultisourceModelLearningBase
()
:
m_BatchSize
(
100
),
m_NumberOfSamples
(
0
),
m_UseStreaming
(
false
)
m_UseStreaming
(
false
),
m_NumberOfSamples
(
0
)
{
}
...
...
@@ -31,6 +31,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
{
Superclass
::
GenerateOutputInformation
();
// Set an empty output buffered region
ImageType
*
outputPtr
=
this
->
GetOutput
();
RegionType
nullRegion
;
nullRegion
.
GetModifiableSize
().
Fill
(
1
);
...
...
@@ -72,7 +73,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
<<
" rows but patch size Y is "
<<
inputPatchSize
[
1
]
<<
" for input "
<<
i
);
// Get the batch size
const
tensorflow
::
uint64
currNumberOfSamples
=
reqRegion
.
GetSize
(
1
)
/
inputPatchSize
[
1
];
const
IndexValueType
currNumberOfSamples
=
reqRegion
.
GetSize
(
1
)
/
inputPatchSize
[
1
];
// Check the consistency with other inputs
if
(
m_NumberOfSamples
==
0
)
...
...
@@ -95,21 +96,21 @@ TensorflowMultisourceModelLearningBase<TInputImage>
{
Superclass
::
GenerateInputRequestedRegion
();
// For each image, set no image region
// For each image, set the requested region
RegionType
nullRegion
;
for
(
unsigned
int
i
=
0
;
i
<
this
->
GetNumberOfInputs
();
++
i
)
{
RegionType
nullRegion
;
ImageType
*
inputImage
=
static_cast
<
ImageType
*
>
(
Superclass
::
ProcessObject
::
GetInput
(
i
)
);
// If the streaming is enabled, we don't read the full image
if
(
m_UseStreaming
)
{
{
inputImage
->
SetRequestedRegion
(
nullRegion
);
}
}
else
{
{
inputImage
->
SetRequestedRegion
(
inputImage
->
GetLargestPossibleRegion
());
}
}
}
// next image
}
...
...
@@ -131,8 +132,8 @@ TensorflowMultisourceModelLearningBase<TInputImage>
for
(
IndexValueType
batch
=
0
;
batch
<
nBatches
;
batch
++
)
{
//
Create input tensors lis
t
TensorLis
tType
inputs
;
//
Feed dic
t
Dic
tType
inputs
;
// Batch start and size
const
IndexValueType
sampleStart
=
batch
*
m_BatchSize
;
...
...
@@ -143,7 +144,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
}
// Process the batch
ProcessBatch
(
inputs
,
sampleStart
,
batchSize
);
this
->
ProcessBatch
(
inputs
,
sampleStart
,
batchSize
);
progress
.
CompletedPixel
();
}
// Next batch
...
...
@@ -153,7 +154,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
template
<
class
TInputImage
>
void
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
PopulateInputTensors
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
::
PopulateInputTensors
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
,
const
IndexListType
&
order
)
{
const
bool
reorder
=
order
.
size
();
...
...
@@ -176,7 +177,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
tensorflow
::
Tensor
inputTensor
(
this
->
GetInputTensorsDataTypes
()[
i
],
inputTensorShape
);
// Populate the tensor
for
(
tensorflow
::
uint64
elem
=
0
;
elem
<
batchSize
;
elem
++
)
for
(
IndexValueType
elem
=
0
;
elem
<
batchSize
;
elem
++
)
{
const
tensorflow
::
uint64
samplePos
=
sampleStart
+
elem
;
IndexType
start
;
...
...
@@ -199,8 +200,8 @@ TensorflowMultisourceModelLearningBase<TInputImage>
}
// Input #i : the tensor of patches (aka the batch)
DictElementType
input
1
=
{
this
->
GetInputPlaceholders
Names
()[
i
],
inputTensor
};
inputs
.
push_back
(
input
1
);
DictElementType
input
=
{
this
->
GetInputPlaceholders
()[
i
],
inputTensor
};
inputs
.
push_back
(
input
);
}
// next input tensor
}
...
...
include/otbTensorflowMultisourceModelTrain.h
View file @
ddf221b3
...
...
@@ -54,8 +54,9 @@ public:
itkTypeMacro
(
TensorflowMultisourceModelTrain
,
TensorflowMultisourceModelLearningBase
);
/** Superclass typedefs */
typedef
typename
Superclass
::
IndexValueType
IndexValue
Type
;
typedef
typename
Superclass
::
DictType
Dict
Type
;
typedef
typename
Superclass
::
TensorListType
TensorListType
;
typedef
typename
Superclass
::
IndexValueType
IndexValueType
;
typedef
typename
Superclass
::
IndexListType
IndexListType
;
...
...
@@ -63,8 +64,8 @@ protected:
TensorflowMultisourceModelTrain
();
virtual
~
TensorflowMultisourceModelTrain
()
{};
void
GenerateData
();
void
ProcessBatch
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
virtual
void
GenerateData
();
virtual
void
ProcessBatch
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
);
private:
...
...
include/otbTensorflowMultisourceModelTrain.hxx
View file @
ddf221b3
...
...
@@ -45,11 +45,11 @@ TensorflowMultisourceModelTrain<TInputImage>
template
<
class
TInputImage
>
void
TensorflowMultisourceModelTrain
<
TInputImage
>
::
ProcessBatch
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
::
ProcessBatch
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
)
{
// Populate input tensors
PopulateInputTensor
(
inputs
,
sampleStart
,
batchSize
,
m_RandomIndices
);
this
->
PopulateInputTensor
s
(
inputs
,
sampleStart
,
batchSize
,
m_RandomIndices
);
// Run the TF session here
TensorListType
outputs
;
...
...
include/otbTensorflowMultisourceModelValidate.h
View file @
ddf221b3
...
...
@@ -43,10 +43,10 @@ public TensorflowMultisourceModelLearningBase<TInputImage>
public:
/** Standard class typedefs. */
typedef
TensorflowMultisourceModelValidate
Self
;
typedef
TensorflowMultisourceModelBase
<
TInputImage
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
typedef
TensorflowMultisourceModelValidate
Self
;
typedef
TensorflowMultisourceModel
Learning
Base
<
TInputImage
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
/** Method for creation through the object factory. */
itkNewMacro
(
Self
);
...
...
@@ -68,6 +68,7 @@ public:
typedef
typename
Superclass
::
SizeListType
SizeListType
;
typedef
typename
Superclass
::
TensorListType
TensorListType
;
typedef
typename
Superclass
::
IndexValueType
IndexValueType
;
typedef
typename
Superclass
::
IndexListType
IndexListType
;
/* Typedefs for validation */
typedef
unsigned
long
CountValueType
;
...
...
@@ -97,7 +98,7 @@ protected:
void
GenerateOutputInformation
(
void
);
void
GenerateData
();
void
ProcessBatch
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
void
ProcessBatch
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
);
private:
...
...
@@ -110,6 +111,9 @@ private:
ConfMatListType
m_ConfusionMatrices
;
// Confusion matrix
MapOfClassesListType
m_MapsOfClasses
;
// Maps of classes
// Internal
std
::
vector
<
MatMapType
>
m_ConfMatMaps
;
// Accumulators
};
// end class
...
...
include/otbTensorflowMultisourceModelValidate.hxx
View file @
ddf221b3
...
...
@@ -110,24 +110,24 @@ TensorflowMultisourceModelValidate<TInputImage>
// Temporary images for outputs
m_ConfusionMatrices
.
clear
();
m_MapsOfClasses
.
clear
();
std
::
vector
<
MatMapType
>
c
onfMatMaps
;
m_C
onfMatMaps
.
clear
()
;
for
(
auto
const
&
ref
:
m_References
)
{
(
void
)
ref
;
// New confusion matrix
MatMapType
mat
;
c
onfMatMaps
.
push_back
(
mat
);
m_C
onfMatMaps
.
push_back
(
mat
);
}
// Run all the batches
Superclass
::
GenerateData
();
// Compute confusion matrices
for
(
unsigned
int
i
=
0
;
i
<
c
onfMatMaps
.
size
()
;
i
++
)
for
(
unsigned
int
i
=
0
;
i
<
m_C
onfMatMaps
.
size
()
;
i
++
)
{
// Confusion matrix (map) for current target
MatMapType
mat
=
c
onfMatMaps
[
i
];
MatMapType
mat
=
m_C
onfMatMaps
[
i
];
// List all values
MapOfClassesType
values
;
...
...
@@ -159,10 +159,6 @@ TensorflowMultisourceModelValidate<TInputImage>
m_ConfusionMatrices
.
push_back
(
matrix
);
m_MapsOfClasses
.
push_back
(
values
);
}
}
...
...
@@ -171,11 +167,12 @@ TensorflowMultisourceModelValidate<TInputImage>
template
<
class
TInputImage
>
void
TensorflowMultisourceModelValidate
<
TInputImage
>
::
ProcessBatch
(
TensorLis
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
::
ProcessBatch
(
Dic
tType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
)
{
// Populate input tensors
PopulateInputTensor
(
inputs
,
sampleStart
,
batchSize
);
IndexListType
empty
;
this
->
PopulateInputTensors
(
inputs
,
sampleStart
,
batchSize
,
empty
);
// Run the TF session here
TensorListType
outputs
;
...
...
@@ -223,21 +220,21 @@ TensorflowMultisourceModelValidate<TInputImage>
const
int
classIn
=
static_cast
<
LabelValueType
>
(
inIt
.
Get
()[
0
]);
const
int
classRef
=
static_cast
<
LabelValueType
>
(
refIt
.
Get
()[
0
]);
if
(
c
onfMatMaps
[
refIdx
].
count
(
classRef
)
==
0
)
if
(
m_C
onfMatMaps
[
refIdx
].
count
(
classRef
)
==
0
)
{
MapType
newMap
;
newMap
[
classIn
]
=
1
;
c
onfMatMaps
[
refIdx
][
classRef
]
=
newMap
;
m_C
onfMatMaps
[
refIdx
][
classRef
]
=
newMap
;
}
else
{
if
(
c
onfMatMaps
[
refIdx
][
classRef
].
count
(
classIn
)
==
0
)
if
(
m_C
onfMatMaps
[
refIdx
][
classRef
].
count
(
classIn
)
==
0
)
{
c
onfMatMaps
[
refIdx
][
classRef
][
classIn
]
=
1
;
m_C
onfMatMaps
[
refIdx
][
classRef
][
classIn
]
=
1
;
}
else
{
c
onfMatMaps
[
refIdx
][
classRef
][
classIn
]
++
;
m_C
onfMatMaps
[
refIdx
][
classRef
][
classIn
]
++
;
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment