Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
10
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Open sidebar
Lozac'h Loic
otbtf
Commits
108aab60
Commit
108aab60
authored
Sep 01, 2018
by
Cresson Remi
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ENH: perform validation every Nth epochs
parent
4d9c113a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
83 deletions
+92
-83
app/otbTensorflowModelServe.cxx
app/otbTensorflowModelServe.cxx
+27
-25
app/otbTensorflowModelTrain.cxx
app/otbTensorflowModelTrain.cxx
+65
-58
No files found.
app/otbTensorflowModelServe.cxx
View file @
108aab60
...
...
@@ -141,23 +141,21 @@ public:
// Documentation
SetName
(
"TensorflowModelServe"
);
SetDescription
(
"Multisource deep learning classifier using Tensorflow. Change "
"the "
+
tf
::
ENV_VAR_NAME_NSOURCES
+
" environment variable to set the number of "
"sources."
);
SetDocLongDescription
(
"The application run a Tensorflow model over multiple data sources. "
"The number of input sources can be changed at runtime by setting the "
"system environment variable "
+
tf
::
ENV_VAR_NAME_NSOURCES
+
". "
"For each source, you have to set (1) the tensor placeholder name, as named in "
"the tensorflow model, (2) the patch size and (3) the image(s) source. "
"The output is a multiband image, stacking all outputs "
"tensors together: you have to specify the names of the output tensors, as "
"named in the tensorflow model (typically, an operator's output). The output "
"tensors values will be stacked in the same order as they appear in the "
"
\"
model.output
\"
parameter (you can use a space separator between names). "
"Last but not least, consider using extended filename to bypass the automatic "
"memory footprint calculator of the otb application engine, and set a good "
"splitting strategy (I would recommend using small square tiles) or use the "
"finetuning parameter group to impose your squared tiles sizes"
);
SetDescription
(
"Multisource deep learning classifier using TensorFlow. Change the "
+
tf
::
ENV_VAR_NAME_NSOURCES
+
" environment variable to set the number of sources."
);
SetDocLongDescription
(
"The application run a TensorFlow model over multiple data sources. "
"The number of input sources can be changed at runtime by setting the system "
"environment variable "
+
tf
::
ENV_VAR_NAME_NSOURCES
+
". For each source, you have to "
"set (1) the placeholder name, as named in the TensorFlow model, (2) the receptive "
"field and (3) the image(s) source. The output is a multiband image, stacking all "
"outputs tensors together: you have to specify (1) the names of the output tensors, as "
"named in the TensorFlow model (typically, an operator's output) and (2) the expression "
"field of each output tensor. The output tensors values will be stacked in the same "
"order as they appear in the
\"
model.output
\"
parameter (you can use a space separator "
"between names). You might consider to use extended filename to bypass the automatic "
"memory footprint calculator of the otb application engine, and set a good splitting "
"strategy (Square tiles is good for convolutional networks) or use the
\"
optim
\"
"
"parameter group to impose your squared tiles sizes"
);
SetDocAuthors
(
"Remi Cresson"
);
// Input/output images
...
...
@@ -167,17 +165,21 @@ public:
// Input model
AddParameter
(
ParameterType_Group
,
"model"
,
"model parameters"
);
AddParameter
(
ParameterType_Directory
,
"model.dir"
,
"Tensor
f
low model_save directory"
);
AddParameter
(
ParameterType_Directory
,
"model.dir"
,
"Tensor
F
low model_save directory"
);
MandatoryOn
(
"model.dir"
);
SetParameterDescription
(
"model.dir"
,
"The model directory should contains the model Google Protobuf (.pb) and variables"
);
AddParameter
(
ParameterType_StringList
,
"model.userplaceholders"
,
"Additional single-valued placeholders. Supported types: int, float, bool."
);
MandatoryOff
(
"model.userplaceholders"
);
SetParameterDescription
(
"model.userplaceholders"
,
"Syntax to use is
\"
placeholder_1=value_1 ... placeholder_N=value_N
\"
"
);
AddParameter
(
ParameterType_Bool
,
"model.fullyconv"
,
"Fully convolutional"
);
MandatoryOff
(
"model.fullyconv"
);
// Output tensors parameters
AddParameter
(
ParameterType_Group
,
"output"
,
"Output tensors parameters"
);
AddParameter
(
ParameterType_Float
,
"output.spcscale"
,
"The output spacing scale"
);
AddParameter
(
ParameterType_Float
,
"output.spcscale"
,
"The output spacing scale
, related to the first input
"
);
SetDefaultParameterFloat
(
"output.spcscale"
,
1.0
);
SetParameterDescription
(
"output.spcscale"
,
"The output image size/scale and spacing*scale where size and spacing corresponds to the first input"
);
AddParameter
(
ParameterType_StringList
,
"output.names"
,
"Names of the output tensors"
);
MandatoryOn
(
"output.names"
);
...
...
@@ -195,6 +197,7 @@ public:
AddParameter
(
ParameterType_Group
,
"optim"
,
"This group of parameters allows optimization of processing time"
);
AddParameter
(
ParameterType_Bool
,
"optim.disabletiling"
,
"Disable tiling"
);
MandatoryOff
(
"optim.disabletiling"
);
SetParameterDescription
(
"optim.disabletiling"
,
"Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it"
);
AddParameter
(
ParameterType_Int
,
"optim.tilesize"
,
"Tile width used to stream the filter output"
);
SetMinimumParameterIntValue
(
"optim.tilesize"
,
1
);
SetDefaultParameterInt
(
"optim.tilesize"
,
16
);
...
...
@@ -230,8 +233,8 @@ public:
bundle
.
m_PatchSize
[
1
]
=
GetParameterInt
(
bundle
.
m_KeyPszY
);
otbAppLogINFO
(
"Source info :"
);
otbAppLogINFO
(
"
Field of view
: "
<<
bundle
.
m_PatchSize
);
otbAppLogINFO
(
"Placeholder
: "
<<
bundle
.
m_Placeholder
);
otbAppLogINFO
(
"
Receptive field
: "
<<
bundle
.
m_PatchSize
);
otbAppLogINFO
(
"Placeholder
name
: "
<<
bundle
.
m_Placeholder
);
}
}
...
...
@@ -273,7 +276,7 @@ public:
// Fully convolutional mode on/off
if
(
GetParameterInt
(
"model.fullyconv"
)
==
1
)
{
otbAppLogINFO
(
"The
t
ensor
f
low model is used in fully convolutional mode"
);
otbAppLogINFO
(
"The
T
ensor
F
low model is used in fully convolutional mode"
);
m_TFFilter
->
SetFullyConvolutional
(
true
);
}
...
...
@@ -292,7 +295,7 @@ public:
const
unsigned
int
tileSize
=
GetParameterInt
(
"optim.tilesize"
);
otbAppLogINFO
(
"Force tiling with squared tiles of "
<<
tileSize
)
// Update the T
F filter
to get the output image size
// Update the T
ensorFlow filter output information
to get the output image size
m_TFFilter
->
UpdateOutputInformation
();
// Splitting using square tiles
...
...
@@ -301,7 +304,7 @@ public:
unsigned
int
nbDesiredTiles
=
itk
::
Math
::
Ceil
<
unsigned
int
>
(
double
(
m_TFFilter
->
GetOutput
()
->
GetLargestPossibleRegion
().
GetNumberOfPixels
()
)
/
(
tileSize
*
tileSize
)
);
// Use an itk::StreamingImageFilter to force the computation
on
tile
s
// Use an itk::StreamingImageFilter to force the computation
tile by
tile
m_StreamFilter
=
StreamingFilterType
::
New
();
m_StreamFilter
->
SetRegionSplitter
(
splitter
);
m_StreamFilter
->
SetNumberOfStreamDivisions
(
nbDesiredTiles
);
...
...
@@ -313,7 +316,6 @@ public:
{
otbAppLogINFO
(
"Tiling disabled"
);
SetParameterOutputImage
(
"out"
,
m_TFFilter
->
GetOutput
());
}
}
...
...
app/otbTensorflowModelTrain.cxx
View file @
108aab60
...
...
@@ -191,7 +191,7 @@ public:
SetDefaultParameterInt
(
"training.batchsize"
,
100
);
AddParameter
(
ParameterType_Int
,
"training.epochs"
,
"Number of epochs"
);
SetMinimumParameterIntValue
(
"training.epochs"
,
1
);
SetDefaultParameterInt
(
"training.epochs"
,
10
);
SetDefaultParameterInt
(
"training.epochs"
,
10
0
);
AddParameter
(
ParameterType_StringList
,
"training.userplaceholders"
,
"Additional single-valued placeholders for training. Supported types: int, float, bool."
);
MandatoryOff
(
"training.userplaceholders"
);
...
...
@@ -205,6 +205,9 @@ public:
// Metrics
AddParameter
(
ParameterType_Group
,
"validation"
,
"Validation parameters"
);
MandatoryOff
(
"validation"
);
AddParameter
(
ParameterType_Int
,
"validation.step"
,
"Perform the validation every Nth epochs"
);
SetMinimumParameterIntValue
(
"validation.step"
,
1
);
SetDefaultParameterInt
(
"validation.step"
,
10
);
AddParameter
(
ParameterType_Choice
,
"validation.mode"
,
"Metrics to compute"
);
AddChoice
(
"validation.mode.none"
,
"No validation step"
);
AddChoice
(
"validation.mode.class"
,
"Classification metrics"
);
...
...
@@ -415,7 +418,7 @@ public:
// Prepare inputs
PrepareInputs
();
// Setup filter
// Setup
training
filter
m_TrainModelFilter
=
TrainModelFilterType
::
New
();
m_TrainModelFilter
->
SetGraph
(
m_SavedModel
.
meta_graph_def
.
graph_def
());
m_TrainModelFilter
->
SetSession
(
m_SavedModel
.
session
.
get
());
...
...
@@ -434,21 +437,6 @@ public:
m_InputSourcesForTraining
[
i
]);
}
// Train the model
for
(
int
epoch
=
0
;
epoch
<
GetParameterInt
(
"training.epochs"
)
;
epoch
++
)
{
AddProcess
(
m_TrainModelFilter
,
"Training epoch #"
+
std
::
to_string
(
epoch
+
1
));
m_TrainModelFilter
->
Update
();
}
// Check if we have to save variables to somewhere
if
(
HasValue
(
"model.saveto"
))
{
const
std
::
string
path
=
GetParameterAsString
(
"model.saveto"
);
otbAppLogINFO
(
"Saving model to "
+
path
);
tf
::
SaveModel
(
path
,
m_SavedModel
);
}
// Setup the validation filter
if
(
GetParameterInt
(
"validation.mode"
)
==
1
)
// class
{
...
...
@@ -459,60 +447,79 @@ public:
m_ValidateModelFilter
->
SetSession
(
m_SavedModel
.
session
.
get
());
m_ValidateModelFilter
->
SetBatchSize
(
GetParameterInt
(
"training.batchsize"
));
m_ValidateModelFilter
->
SetUserPlaceholders
(
GetUserPlaceholders
(
"validation.userplaceholders"
));
m_ValidateModelFilter
->
SetInputPlaceholders
(
m_InputPlaceholdersForValidation
);
m_ValidateModelFilter
->
SetInputReceptiveFields
(
m_InputPatchesSizeForValidation
);
m_ValidateModelFilter
->
SetOutputTensors
(
m_TargetTensorsNames
);
m_ValidateModelFilter
->
SetOutputExpressionFields
(
m_TargetPatchesSize
);
}
else
if
(
GetParameterInt
(
"validation.mode"
)
==
2
)
// rmse)
{
otbAppLogINFO
(
"Set validation mode to classification RMSE evaluation"
);
otbAppLogFATAL
(
"Not implemented yet !"
);
// XD
//
AS we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"training.usestreaming"
));
//
TODO
}
// 1. Evaluate the metrics against the learning data
// Epoch
for
(
int
epoch
=
1
;
epoch
<=
GetParameterInt
(
"training.epochs"
)
;
epoch
++
)
{
// Train the model
AddProcess
(
m_TrainModelFilter
,
"Training epoch #"
+
std
::
to_string
(
epoch
));
m_TrainModelFilter
->
Update
();
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstLearningData
.
size
()
;
i
++
)
// Validate the model
if
(
epoch
%
GetParameterInt
(
"validation.step"
)
==
0
)
{
m_ValidateModelFilter
->
PushBackInputTensorBundle
(
m_InputPlaceholdersForValidation
[
i
],
m_InputPatchesSizeForValidation
[
i
],
m_InputSourcesForEvaluationAgainstLearningData
[
i
]);
}
m_ValidateModelFilter
->
SetOutputTensors
(
m_TargetTensorsNames
);
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstLearningData
);
m_ValidateModelFilter
->
SetOutputExpressionFields
(
m_TargetPatchesSize
);
// 1. Evaluate the metrics against the learning data
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Learning data)"
);
m_ValidateModelFilter
->
Update
();
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstLearningData
.
size
()
;
i
++
)
{
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstLearningData
[
i
]);
}
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstLearningData
);
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
{
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
}
// As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"training.usestreaming"
));
// 2. Evaluate the metrics against the validation data
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Learning data)"
);
m_ValidateModelFilter
->
Update
();
// Here we just change the input sources and references
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstValidationData
.
size
()
;
i
++
)
{
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstValidationData
[
i
]);
}
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstValidationData
);
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"validation.usestreaming"
));
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
{
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
}
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Validation data)"
);
m_ValidateModelFilter
->
Update
();
// 2. Evaluate the metrics against the validation data
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
{
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
}
// Here we just change the input sources and references
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstValidationData
.
size
()
;
i
++
)
{
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstValidationData
[
i
]);
}
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstValidationData
);
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"validation.usestreaming"
));
}
else
if
(
GetParameterInt
(
"validation.mode"
)
==
2
)
// rmse)
{
otbAppLogINFO
(
"Set validation mode to classification RMSE evaluation"
);
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Validation data)"
);
m_ValidateModelFilter
->
Update
();
// TODO
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
{
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
}
}
// Step is OK to perform validation
}
// Next epoch
// Check if we have to save variables to somewhere
if
(
HasValue
(
"model.saveto"
))
{
const
std
::
string
path
=
GetParameterAsString
(
"model.saveto"
);
otbAppLogINFO
(
"Saving model to "
+
path
);
tf
::
SaveModel
(
path
,
m_SavedModel
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment