Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Cresson Remi
otbtf
Commits
d7ca6021
Commit
d7ca6021
authored
Sep 12, 2018
by
remi cresson
Browse files
ENH: add function to export sess+graph in a SavedModel
parent
1e8bf930
Changes
1
Hide whitespace changes
Inline
Side-by-side
python/tricks.py
View file @
d7ca6021
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import
sys
import
os
import
numpy
as
np
import
math
import
time
import
otbApplication
import
tensorflow
as
tf
import
shutil
def
flatten_nparray
(
np_arr
):
""" Returns a 1D numpy array retulting from the flatten of the input
...
...
@@ -35,10 +55,13 @@ def print_tensor_info(name, tensor):
print
(
name
+
" : "
+
str
(
tensor
.
shape
)
+
" (dtype="
+
str
(
tensor
.
dtype
)
+
")"
)
def
read_samples
(
fn
):
def
read_samples
(
fn
,
single
=
False
):
""" Read an image of patches and return a 4D numpy array
TODO: Add an optional argument for the y-patchsize
Args:
fn: file name
single: a boolean telling if there is only 1 image in the batch.
In this case, the image can be rectangular (not squared)
"""
# Get input image size
...
...
@@ -63,12 +86,15 @@ def read_samples(fn):
print
(
"Quick stats: min="
+
str
(
np
.
amin
(
outimg
))
+
", max="
+
str
(
np
.
amax
(
outimg
))
)
# reshape
if
(
single
):
return
np
.
copy
(
outimg
.
reshape
((
1
,
size_y
,
size_x
,
nbands
)))
n_samples
=
int
(
size_y
/
size_x
)
outimg
=
outimg
.
reshape
((
n_samples
,
size_x
,
size_x
,
nbands
))
print
(
"Returned numpy array shape: "
+
str
(
outimg
.
shape
))
return
np
.
copy
(
outimg
)
def
getBatch
(
X
,
Y
,
i
,
batch_size
):
start_id
=
i
*
batch_size
end_id
=
min
(
(
i
+
1
)
*
batch_size
,
X
.
shape
[
0
])
...
...
@@ -76,3 +102,59 @@ def getBatch(X, Y, i, batch_size):
batch_y
=
Y
[
start_id
:
end_id
]
return
batch_x
,
batch_y
def
CreateSavedModel
(
sess
,
inputs
,
outputs
,
directory
):
"""
Create a SavedModel
Args:
sess: the session
inputs: the list of input names
outputs: the list of output names
directory: the output path for the SavedModel
"""
directory
+=
"/SavedModel"
print
(
"Create a SavedModel in "
+
directory
)
# Delete the directory if it already exists
if
os
.
path
.
exists
(
directory
):
shutil
.
rmtree
(
directory
)
# Get graph
graph
=
tf
.
get_default_graph
()
# Get inputs
input_dict
=
{
i
:
graph
.
get_tensor_by_name
(
i
)
for
i
in
inputs
}
output_dict
=
{
o
:
graph
.
get_tensor_by_name
(
o
)
for
o
in
outputs
}
# Build the SavedModel
builder
=
tf
.
saved_model
.
builder
.
SavedModelBuilder
(
directory
)
signature_def_map
=
{
"model"
:
tf
.
saved_model
.
signature_def_utils
.
predict_signature_def
(
input_dict
,
output_dict
)
}
builder
.
add_meta_graph_and_variables
(
sess
,[
tf
.
saved_model
.
tag_constants
.
TRAINING
],
signature_def_map
)
builder
.
add_meta_graph
([
tf
.
saved_model
.
tag_constants
.
SERVING
])
builder
.
save
()
def
CheckpointToSavedModel
(
ckpt_path
,
inputs
,
outputs
,
savedmodel_path
):
"""
Read a Checkpoint and build a SavedModel
Args:
ckpt_path: path to the checkpoint file (without the ".meta" extension)
inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
savedmodel_path: path to the SavedModel
"""
tf
.
reset_default_graph
()
with
tf
.
Session
()
as
sess
:
# Restore variables from disk.
model_saver
=
tf
.
train
.
import_meta_graph
(
ckpt_path
+
".meta"
)
model_saver
.
restore
(
sess
,
ckpt_path
)
# Create a SavedModel
CreateSavedModel
(
sess
,
inputs
,
outputs
,
savedmodel_path
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment