Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Lozac'h Loic
otbtf
Commits
5a1d2322
Commit
5a1d2322
authored
Sep 29, 2018
by
Cresson Remi
Browse files
Merge branch 'develop'
parents
1e8bf930
d2195d91
Changes
11
Hide whitespace changes
Inline
Side-by-side
app/otbTensorflowModelServe.cxx
View file @
5a1d2322
...
...
@@ -30,8 +30,7 @@
#include "otbTensorflowSource.h"
// Streaming
#include "otbImageRegionSquareTileSplitter.h"
#include "itkStreamingImageFilter.h"
#include "otbTensorflowStreamerFilter.h"
namespace
otb
{
...
...
@@ -58,7 +57,7 @@ public:
/** Typedef for streaming */
typedef
otb
::
ImageRegionSquareTileSplitter
<
FloatVectorImageType
::
ImageDimension
>
TileSplitterType
;
typedef
itk
::
StreamingImage
Filter
<
FloatVectorImageType
,
FloatVectorImageType
>
StreamingFilterType
;
typedef
otb
::
TensorflowStreamer
Filter
<
FloatVectorImageType
,
FloatVectorImageType
>
StreamingFilterType
;
/** Typedefs for images */
typedef
FloatVectorImageType
::
SizeType
SizeType
;
...
...
@@ -198,9 +197,12 @@ public:
AddParameter
(
ParameterType_Bool
,
"optim.disabletiling"
,
"Disable tiling"
);
MandatoryOff
(
"optim.disabletiling"
);
SetParameterDescription
(
"optim.disabletiling"
,
"Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it"
);
AddParameter
(
ParameterType_Int
,
"optim.tilesize"
,
"Tile width used to stream the filter output"
);
SetMinimumParameterIntValue
(
"optim.tilesize"
,
1
);
SetDefaultParameterInt
(
"optim.tilesize"
,
16
);
AddParameter
(
ParameterType_Int
,
"optim.tilesizex"
,
"Tile width used to stream the filter output"
);
SetMinimumParameterIntValue
(
"optim.tilesizex"
,
1
);
SetDefaultParameterInt
(
"optim.tilesizex"
,
16
);
AddParameter
(
ParameterType_Int
,
"optim.tilesizey"
,
"Tile height used to stream the filter output"
);
SetMinimumParameterIntValue
(
"optim.tilesizey"
,
1
);
SetDefaultParameterInt
(
"optim.tilesizey"
,
16
);
// Output image
AddParameter
(
ParameterType_OutputImage
,
"out"
,
"output image"
);
...
...
@@ -292,22 +294,28 @@ public:
if
(
GetParameterInt
(
"optim.disabletiling"
)
!=
1
)
{
// Get the tile size
const
unsigned
int
tileSize
=
GetParameterInt
(
"optim.tilesize"
);
otbAppLogINFO
(
"Force tiling with squared tiles of "
<<
tileSize
)
SizeType
tileSize
;
tileSize
[
0
]
=
GetParameterInt
(
"optim.tilesizex"
);
tileSize
[
1
]
=
GetParameterInt
(
"optim.tilesizey"
);
// Check that the tile size is aligned to the field of expression
for
(
unsigned
int
i
=
0
;
i
<
FloatVectorImageType
::
ImageDimension
;
i
++
)
if
(
tileSize
[
i
]
%
foe
[
i
]
!=
0
)
{
SizeType
::
SizeValueType
newSize
=
1
+
std
::
floor
(
tileSize
[
i
]
/
foe
[
i
]);
newSize
*=
foe
[
i
];
// Update the TensorFlow filter output information to get the output image size
m_TFFilter
->
UpdateOutputInformation
();
otbAppLogWARNING
(
"Aligning the tiling to the output expression field "
<<
"for better performances (dim "
<<
i
<<
"). New value set to "
<<
newSize
)
// Splitting using square tiles
TileSplitterType
::
Pointer
splitter
=
TileSplitterType
::
New
();
splitter
->
SetTileSizeAlignment
(
tileSize
);
unsigned
int
nbDesiredTiles
=
itk
::
Math
::
Ceil
<
unsigned
int
>
(
double
(
m_TFFilter
->
GetOutput
()
->
GetLargestPossibleRegion
().
GetNumberOfPixels
()
)
/
(
tileSize
*
tileSize
)
);
tileSize
[
i
]
=
newSize
;
}
otbAppLogINFO
(
"Force tiling with squared tiles of "
<<
tileSize
)
//
Use an itk::StreamingImageFilter to f
orce the computation tile by tile
//
F
orce the computation tile by tile
m_StreamFilter
=
StreamingFilterType
::
New
();
m_StreamFilter
->
SetRegionSplitter
(
splitter
);
m_StreamFilter
->
SetNumberOfStreamDivisions
(
nbDesiredTiles
);
m_StreamFilter
->
SetOutputGridSize
(
tileSize
);
m_StreamFilter
->
SetInput
(
m_TFFilter
->
GetOutput
());
SetParameterOutputImage
(
"out"
,
m_StreamFilter
->
GetOutput
());
...
...
app/otbTensorflowModelTrain.cxx
View file @
5a1d2322
...
...
@@ -438,6 +438,7 @@ public:
}
// Setup the validation filter
const
bool
do_validation
=
HasUserValue
(
"validation.mode"
);
if
(
GetParameterInt
(
"validation.mode"
)
==
1
)
// class
{
otbAppLogINFO
(
"Set validation mode to classification validation"
);
...
...
@@ -467,50 +468,53 @@ public:
AddProcess
(
m_TrainModelFilter
,
"Training epoch #"
+
std
::
to_string
(
epoch
));
m_TrainModelFilter
->
Update
();
// Validate the model
if
(
epoch
%
GetParameterInt
(
"validation.step"
)
==
0
)
if
(
do_validation
)
{
// Validate the model
if
(
epoch
%
GetParameterInt
(
"validation.step"
)
==
0
)
{
// 1. Evaluate the metrics against the learning data
// 1. Evaluate the metrics against the learning data
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstLearningData
.
size
()
;
i
++
)
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstLearningData
.
size
()
;
i
++
)
{
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstLearningData
[
i
]);
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstLearningData
[
i
]);
}
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstLearningData
);
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstLearningData
);
// As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"training.usestreaming"
));
// As we use the learning data here, it's rational to use the same option as streaming during training
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"training.usestreaming"
));
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Learning data)"
);
m_ValidateModelFilter
->
Update
();
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Learning data)"
);
m_ValidateModelFilter
->
Update
();
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
{
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
}
// 2. Evaluate the metrics against the validation data
// 2. Evaluate the metrics against the validation data
// Here we just change the input sources and references
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstValidationData
.
size
()
;
i
++
)
// Here we just change the input sources and references
for
(
unsigned
int
i
=
0
;
i
<
m_InputSourcesForEvaluationAgainstValidationData
.
size
()
;
i
++
)
{
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstValidationData
[
i
]);
m_ValidateModelFilter
->
SetInput
(
i
,
m_InputSourcesForEvaluationAgainstValidationData
[
i
]);
}
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstValidationData
);
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"validation.usestreaming"
));
m_ValidateModelFilter
->
SetInputReferences
(
m_InputTargetsForEvaluationAgainstValidationData
);
m_ValidateModelFilter
->
SetUseStreaming
(
GetParameterInt
(
"validation.usestreaming"
));
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Validation data)"
);
m_ValidateModelFilter
->
Update
();
// Update
AddProcess
(
m_ValidateModelFilter
,
"Evaluate model (Validation data)"
);
m_ValidateModelFilter
->
Update
();
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
for
(
unsigned
int
i
=
0
;
i
<
m_TargetTensorsNames
.
size
()
;
i
++
)
{
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
otbAppLogINFO
(
"Metrics for target
\"
"
<<
m_TargetTensorsNames
[
i
]
<<
"
\"
:"
);
PrintClassificationMetrics
(
m_ValidateModelFilter
->
GetConfusionMatrix
(
i
),
m_ValidateModelFilter
->
GetMapOfClasses
(
i
));
}
}
// Step is OK to perform validation
}
// Do the validation against the validation data
}
// Next epoch
...
...
app/otbTrainClassifierFromDeepFeatures.cxx
View file @
5a1d2322
...
...
@@ -93,6 +93,7 @@ private:
ShareParameter
(
"optim"
,
"tfmodel.optim"
,
"Processing time optimization"
,
"This group of parameters allows optimization of processing time"
);
// Train shared parameters
ShareParameter
(
"ram"
,
"train.ram"
,
"Available RAM (Mb)"
,
"Available RAM (Mb)"
);
ShareParameter
(
"vd"
,
"train.io.vd"
,
"Vector data for training"
,
"Input vector data for training"
);
ShareParameter
(
"valid"
,
"train.io.valid"
,
"Vector data for validation"
,
"Input vector data for validation"
);
ShareParameter
(
"out"
,
"train.io.out"
,
"Output classification model"
,
"Output classification model"
);
...
...
include/otbTensorflowMultisourceModelFilter.hxx
View file @
5a1d2322
...
...
@@ -368,9 +368,6 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
RegionType
outputAlignedReqRegion
(
outputReqRegion
);
EnlargeToAlignedRegion
(
outputAlignedReqRegion
);
// Add a progress reporter
itk
::
ProgressReporter
progress
(
this
,
0
,
outputReqRegion
.
GetNumberOfPixels
());
const
unsigned
int
nInputs
=
this
->
GetNumberOfInputs
();
// Create input tensors list
...
...
include/otbTensorflowMultisourceModelTrain.hxx
View file @
5a1d2322
...
...
@@ -55,6 +55,12 @@ TensorflowMultisourceModelTrain<TInputImage>
TensorListType
outputs
;
this
->
RunSession
(
inputs
,
outputs
);
// Display outputs tensors
for
(
auto
&
o
:
outputs
)
{
tf
::
PrintTensorInfos
(
o
);
}
}
...
...
include/otbTensorflowStreamerFilter.h
0 → 100644
View file @
5a1d2322
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowStreamerFilter_h
#define otbTensorflowStreamerFilter_h
// Image2image
#include "itkImageToImageFilter.h"
namespace
otb
{
/**
* \class TensorflowStreamerFilter
* \brief This filter generates an output image with an internal
* explicit streaming mechanism.
*
* \ingroup OTBTensorflow
*/
template
<
class
TInputImage
,
class
TOutputImage
>
class
ITK_EXPORT
TensorflowStreamerFilter
:
public
itk
::
ImageToImageFilter
<
TInputImage
,
TOutputImage
>
{
public:
/** Standard class typedefs. */
typedef
TensorflowStreamerFilter
Self
;
typedef
itk
::
ImageToImageFilter
<
TInputImage
,
TOutputImage
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
/** Method for creation through the object factory. */
itkNewMacro
(
Self
);
/** Run-time type information (and related methods). */
itkTypeMacro
(
TensorflowStreamerFilter
,
itk
::
ImageToImageFilter
);
/** Images typedefs */
typedef
typename
Superclass
::
InputImageType
ImageType
;
typedef
typename
ImageType
::
IndexType
IndexType
;
typedef
typename
ImageType
::
IndexValueType
IndexValueType
;
typedef
typename
ImageType
::
SizeType
SizeType
;
typedef
typename
Superclass
::
InputImageRegionType
RegionType
;
typedef
TOutputImage
OutputImageType
;
itkSetMacro
(
OutputGridSize
,
SizeType
);
itkGetMacro
(
OutputGridSize
,
SizeType
);
protected:
TensorflowStreamerFilter
();
virtual
~
TensorflowStreamerFilter
()
{};
virtual
void
UpdateOutputData
(
itk
::
DataObject
*
output
){(
void
)
output
;
this
->
GenerateData
();}
virtual
void
GenerateData
();
private:
TensorflowStreamerFilter
(
const
Self
&
);
//purposely not implemented
void
operator
=
(
const
Self
&
);
//purposely not implemented
SizeType
m_OutputGridSize
;
// Output grid size
};
// end class
}
// end namespace otb
#include "otbTensorflowStreamerFilter.hxx"
#endif
include/otbTensorflowStreamerFilter.hxx
0 → 100644
View file @
5a1d2322
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowStreamerFilter_txx
#define otbTensorflowStreamerFilter_txx
#include "otbTensorflowStreamerFilter.h"
#include "itkImageAlgorithm.h"
namespace
otb
{
template
<
class
TInputImage
,
class
TOutputImage
>
TensorflowStreamerFilter
<
TInputImage
,
TOutputImage
>
::
TensorflowStreamerFilter
()
{
m_OutputGridSize
.
Fill
(
1
);
}
/**
* Compute the output image
*/
template
<
class
TInputImage
,
class
TOutputImage
>
void
TensorflowStreamerFilter
<
TInputImage
,
TOutputImage
>
::
GenerateData
()
{
// Output pointer and requested region
OutputImageType
*
outputPtr
=
this
->
GetOutput
();
const
RegionType
outputReqRegion
=
outputPtr
->
GetRequestedRegion
();
outputPtr
->
SetBufferedRegion
(
outputReqRegion
);
outputPtr
->
Allocate
();
// Compute the aligned region
RegionType
region
;
for
(
unsigned
int
dim
=
0
;
dim
<
OutputImageType
::
ImageDimension
;
++
dim
)
{
// Get corners
IndexValueType
lower
=
outputReqRegion
.
GetIndex
(
dim
);
IndexValueType
upper
=
lower
+
outputReqRegion
.
GetSize
(
dim
);
// Compute deltas between corners and the grid
const
IndexValueType
deltaLo
=
lower
%
m_OutputGridSize
[
dim
];
const
IndexValueType
deltaUp
=
upper
%
m_OutputGridSize
[
dim
];
// Move corners to aligned positions
lower
-=
deltaLo
;
if
(
deltaUp
>
0
)
{
upper
+=
m_OutputGridSize
[
dim
]
-
deltaUp
;
}
// Update region
region
.
SetIndex
(
dim
,
lower
);
region
.
SetSize
(
dim
,
upper
-
lower
);
}
// Compute the number of subregions to process
const
unsigned
int
nbTilesX
=
region
.
GetSize
(
0
)
/
m_OutputGridSize
[
0
];
const
unsigned
int
nbTilesY
=
region
.
GetSize
(
1
)
/
m_OutputGridSize
[
1
];
// Progress
itk
::
ProgressReporter
progress
(
this
,
0
,
nbTilesX
*
nbTilesY
);
// For each tile, propagate the input region and recopy the output
ImageType
*
inputImage
=
static_cast
<
ImageType
*
>
(
Superclass
::
ProcessObject
::
GetInput
(
0
)
);
unsigned
int
tx
,
ty
;
RegionType
subRegion
;
subRegion
.
SetSize
(
m_OutputGridSize
);
for
(
ty
=
0
;
ty
<
nbTilesY
;
ty
++
)
{
subRegion
.
SetIndex
(
1
,
ty
*
m_OutputGridSize
[
1
]
+
region
.
GetIndex
(
1
));
for
(
tx
=
0
;
tx
<
nbTilesX
;
tx
++
)
{
// Update the input subregion
subRegion
.
SetIndex
(
0
,
tx
*
m_OutputGridSize
[
0
]
+
region
.
GetIndex
(
0
));
// The actual region to copy
RegionType
cpyRegion
(
subRegion
);
cpyRegion
.
Crop
(
outputReqRegion
);
// Propagate region
inputImage
->
SetRequestedRegion
(
cpyRegion
);
inputImage
->
PropagateRequestedRegion
();
inputImage
->
UpdateOutputData
();
// Copy the subregion to output
itk
::
ImageAlgorithm
::
Copy
(
inputImage
,
outputPtr
,
cpyRegion
,
cpyRegion
);
progress
.
CompletedPixel
();
}
}
}
}
// end namespace otb
#endif
python/ckpt2savedmodel.py
0 → 100644
View file @
5a1d2322
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
from
tricks
import
*
# Logging
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
# Parser
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--ckpt"
,
help
=
"checkpoint file prefix"
,
required
=
True
)
parser
.
add_argument
(
"--inputs"
,
help
=
"input placeholder names"
,
required
=
True
,
nargs
=
'+'
)
parser
.
add_argument
(
"--outputs"
,
help
=
"output placeholder names"
,
required
=
True
,
nargs
=
'+'
)
parser
.
add_argument
(
"--model"
,
help
=
"output SavedModel"
,
required
=
True
)
params
=
parser
.
parse_args
()
if
__name__
==
"__main__"
:
CheckpointToSavedModel
(
params
.
ckpt
,
params
.
inputs
,
params
.
outputs
,
params
.
model
)
quit
()
python/create_model_ienco-m3_patchbased.py
View file @
5a1d2322
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson, Dino Ienco (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import
sys
import
os
import
numpy
as
np
...
...
@@ -13,35 +31,6 @@ from sklearn.ensemble import RandomForestClassifier
from
sklearn.utils
import
shuffle
from
sklearn.metrics
import
confusion_matrix
from
tricks
import
*
def
export_model
(
sess
,
export_dir
,
x_cnn_placeholder
,
x_rnn_placeholder
,
is_training_placeholder
,
testPrediction
):
""" export a SavedModel
"""
# Update the export dir
model_dir
=
export_dir
+
"/saved_model/"
if
os
.
path
.
exists
(
model_dir
):
shutil
.
rmtree
(
model_dir
)
print
(
"Export model in "
+
model_dir
)
# Add a builder (for LoadSavedModel)
builder
=
tf
.
saved_model
.
builder
.
SavedModelBuilder
(
model_dir
)
signature_def_map
=
{
"model"
:
tf
.
saved_model
.
signature_def_utils
.
predict_signature_def
(
inputs
=
{
"x_cnn"
:
x_cnn_placeholder
,
"x_rnn"
:
x_rnn_placeholder
,
"is_training"
:
is_training_placeholder
},
outputs
=
{
"prediction"
:
testPrediction
})
}
builder
.
add_meta_graph_and_variables
(
sess
,[
tf
.
saved_model
.
tag_constants
.
TRAINING
],
signature_def_map
)
builder
.
add_meta_graph
([
tf
.
saved_model
.
tag_constants
.
SERVING
])
builder
.
save
()
def
checkTest
(
ts_data
,
vhsr_data
,
batchsz
,
label_test
):
tot_pred
=
[]
...
...
@@ -60,11 +49,7 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test):
is_training_ph
:
True
,
dropout
:
0.0
,
x_cnn
:
batch_cnn_x
})
del
batch_rnn_x
del
batch_cnn_x
del
batch_y
for
el
in
pred_temp
:
tot_pred
.
append
(
el
)
...
...
@@ -241,8 +226,8 @@ n_channels = 4
nclasses
=
8
# check number of arguments
if
len
(
sys
.
argv
)
!=
7
:
print
(
"Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>"
)
if
len
(
sys
.
argv
)
!=
8
:
print
(
"Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>
<export_dir>
"
)
sys
.
exit
(
1
)
ts_train
=
read_samples
(
sys
.
argv
[
1
])
...
...
@@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6])
label_test
=
np
.
int32
(
label_test
)
print_histo
(
label_test
,
"label_test"
)
export_dir
=
read_samples
(
sys
.
argv
[
7
])
x_rnn
=
tf
.
placeholder
(
tf
.
float32
,[
None
,
1
,
1
,
n_dims
*
n_timestamps
],
name
=
"x_rnn"
)
x_cnn
=
tf
.
placeholder
(
tf
.
float32
,[
None
,
patch_window
,
patch_window
,
n_channels
],
name
=
"x_cnn"
)
y
=
tf
.
placeholder
(
tf
.
int32
,[
None
,
1
,
1
,
1
],
name
=
"y"
)
...
...
@@ -324,19 +311,12 @@ for e in range(hm_epochs):
lossi
+=
loss
accS
+=
acc
del
batch_rnn_x
del
batch_cnn_x
del
batch_y
print
"Epoch:"
,
e
,
"Train loss:"
,
lossi
/
iterations
,
"| accuracy:"
,
accS
/
iterations
c_loss
=
lossi
/
iterations
if
c_loss
<
best_loss
:
save_path
=
saver
.
save
(
sess
,
"models/model"
)
print
(
"Model saved in path: %s"
%
save_path
)
best_loss
=
c_loss
export_model
(
sess
,
"/tmp/m3_export"
,
x_cnn
,
x_rnn
,
is_training_ph
,
testPrediction
)
CreateSavedModel
(
sess
,
[
"x_cnn:0"
,
"x_rnn:0"
,
"is_training:0"
],
[
"prediction:0"
],
export_dir
)
test_acc
=
checkTest
(
ts_test
,
vhsr_test
,
1024
,
label_test
)
\ No newline at end of file
python/create_model_maggiori17_fullyconv.py
View file @
5a1d2322
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -89,12 +107,11 @@ def main(unused_argv):
# check number of arguments
if
len
(
sys
.
argv
)
!=
4
:
print
(
"Usage : <patches> <labels> <
output_model
_dir>"
)
print
(
"Usage : <patches> <labels> <
export
_dir>"
)
sys
.
exit
(
1
)
# Export dir
log_dir
=
sys
.
argv
[
3
]
+
'/model_checkpoints/'
export_dir
=
sys
.
argv
[
3
]
+
'/model_export/'
export_dir
=
sys
.
argv
[
3
]
print
(
"loading dataset"
)
...
...
@@ -272,13 +289,9 @@ def main(unused_argv):
if
step
%
10
==
0
:
# Print status to stdout.
print
(
'Step %d: loss = %.2f (%.3f sec)'
%
(
step
,
loss_value
,
duration
))
#print('Step %d: (%.3f sec)' % (step, duration))
# Save a checkpoint and evaluate the model periodically.
if
(
curr_epoch
+
1
)
%
1
==
0
:
checkpoint_file
=
os
.
path
.
join
(
log_dir
,
'model.ckpt'
)
saver
.
save
(
sess
,
checkpoint_file
,
global_step
=
step
)
# Evaluate against the training set.
print
(
'Training Data Eval:'
)
do_eval2
(
sess
,
...
...
@@ -301,16 +314,7 @@ def main(unused_argv):
batch_size
)
# Let's export a SavedModel
shutil
.
rmtree
(
export_dir
)
builder
=
tf
.
saved_model
.
builder
.
SavedModelBuilder
(
export_dir
)
signature_def_map
=
{
"model"
:
tf
.
saved_model
.
signature_def_utils
.
predict_signature_def
(
inputs
=
{
"x1"
:
xs_placeholder
},
outputs
=
{
"prediction"
:
testPrediction
})
}
builder
.
add_meta_graph_and_variables
(
sess
,
[
tf
.
saved_model
.
tag_constants
.
TRAINING
],
signature_def_map
)
builder
.
add_meta_graph
([
tf
.
saved_model
.
tag_constants
.
SERVING
])
builder
.
save
()
CreateSavedModel
(
sess
,
[
"x1:0"
],
[
"prediction:0"
],
export_dir
)
quit
()
...
...