Commit d7ca6021 authored by remi cresson's avatar remi cresson
Browse files

ENH: add function to export sess+graph in a SavedModel

parent 1e8bf930
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import sys import sys
import os import os
import numpy as np import numpy as np
import math import math
import time import time
import otbApplication import otbApplication
import tensorflow as tf
import shutil
def flatten_nparray(np_arr): def flatten_nparray(np_arr):
""" Returns a 1D numpy array retulting from the flatten of the input """ Returns a 1D numpy array retulting from the flatten of the input
...@@ -35,10 +55,13 @@ def print_tensor_info(name, tensor): ...@@ -35,10 +55,13 @@ def print_tensor_info(name, tensor):
print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")") print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")")
def read_samples(fn): def read_samples(fn, single=False):
""" Read an image of patches and return a 4D numpy array """ Read an image of patches and return a 4D numpy array
TODO: Add an optional argument for the y-patchsize
Args: Args:
fn: file name fn: file name
single: a boolean telling if there is only 1 image in the batch.
In this case, the image can be rectangular (not squared)
""" """
# Get input image size # Get input image size
...@@ -63,12 +86,15 @@ def read_samples(fn): ...@@ -63,12 +86,15 @@ def read_samples(fn):
print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) ) print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )
# reshape # reshape
if (single):
return np.copy(outimg.reshape((1, size_y, size_x, nbands)))
n_samples = int(size_y / size_x) n_samples = int(size_y / size_x)
outimg = outimg.reshape((n_samples, size_x, size_x, nbands)) outimg = outimg.reshape((n_samples, size_x, size_x, nbands))
print("Returned numpy array shape: "+str(outimg.shape)) print("Returned numpy array shape: "+str(outimg.shape))
return np.copy(outimg) return np.copy(outimg)
def getBatch(X, Y, i, batch_size): def getBatch(X, Y, i, batch_size):
start_id = i*batch_size start_id = i*batch_size
end_id = min( (i+1) * batch_size, X.shape[0]) end_id = min( (i+1) * batch_size, X.shape[0])
...@@ -76,3 +102,59 @@ def getBatch(X, Y, i, batch_size): ...@@ -76,3 +102,59 @@ def getBatch(X, Y, i, batch_size):
batch_y = Y[start_id:end_id] batch_y = Y[start_id:end_id]
return batch_x, batch_y return batch_x, batch_y
def CreateSavedModel(sess, inputs, outputs, directory):
"""
Create a SavedModel
Args:
sess: the session
inputs: the list of input names
outputs: the list of output names
directory: the output path for the SavedModel
"""
directory += "/SavedModel"
print("Create a SavedModel in " + directory)
# Delete the directory if it already exists
if os.path.exists(directory):
shutil.rmtree(directory)
# Get graph
graph = tf.get_default_graph()
# Get inputs
input_dict = { i : graph.get_tensor_by_name(i) for i in inputs }
output_dict = { o : graph.get_tensor_by_name(o) for o in outputs }
# Build the SavedModel
builder = tf.saved_model.builder.SavedModelBuilder(directory)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
input_dict,
output_dict)
}
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path):
"""
Read a Checkpoint and build a SavedModel
Args:
ckpt_path: path to the checkpoint file (without the ".meta" extension)
inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
savedmodel_path: path to the SavedModel
"""
tf.reset_default_graph()
with tf.Session() as sess:
# Restore variables from disk.
model_saver = tf.train.import_meta_graph(ckpt_path+".meta")
model_saver.restore(sess, ckpt_path)
# Create a SavedModel
CreateSavedModel(sess, inputs, outputs, savedmodel_path)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment