Commit d7ca6021 authored by remi cresson's avatar remi cresson

ENH: add function to export sess+graph in a SavedModel

parent 1e8bf930
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import sys
import os
import numpy as np
import math
import time
import otbApplication
import tensorflow as tf
import shutil
def flatten_nparray(np_arr):
""" Returns a 1D numpy array retulting from the flatten of the input
......@@ -35,10 +55,13 @@ def print_tensor_info(name, tensor):
print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")")
def read_samples(fn):
def read_samples(fn, single=False):
""" Read an image of patches and return a 4D numpy array
TODO: Add an optional argument for the y-patchsize
Args:
fn: file name
single: a boolean telling if there is only 1 image in the batch.
In this case, the image can be rectangular (not squared)
"""
# Get input image size
......@@ -63,12 +86,15 @@ def read_samples(fn):
print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )
# reshape
if (single):
return np.copy(outimg.reshape((1, size_y, size_x, nbands)))
n_samples = int(size_y / size_x)
outimg = outimg.reshape((n_samples, size_x, size_x, nbands))
print("Returned numpy array shape: "+str(outimg.shape))
return np.copy(outimg)
def getBatch(X, Y, i, batch_size):
start_id = i*batch_size
end_id = min( (i+1) * batch_size, X.shape[0])
......@@ -76,3 +102,59 @@ def getBatch(X, Y, i, batch_size):
batch_y = Y[start_id:end_id]
return batch_x, batch_y
def CreateSavedModel(sess, inputs, outputs, directory):
"""
Create a SavedModel
Args:
sess: the session
inputs: the list of input names
outputs: the list of output names
directory: the output path for the SavedModel
"""
directory += "/SavedModel"
print("Create a SavedModel in " + directory)
# Delete the directory if it already exists
if os.path.exists(directory):
shutil.rmtree(directory)
# Get graph
graph = tf.get_default_graph()
# Get inputs
input_dict = { i : graph.get_tensor_by_name(i) for i in inputs }
output_dict = { o : graph.get_tensor_by_name(o) for o in outputs }
# Build the SavedModel
builder = tf.saved_model.builder.SavedModelBuilder(directory)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
input_dict,
output_dict)
}
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path):
"""
Read a Checkpoint and build a SavedModel
Args:
ckpt_path: path to the checkpoint file (without the ".meta" extension)
inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
savedmodel_path: path to the SavedModel
"""
tf.reset_default_graph()
with tf.Session() as sess:
# Restore variables from disk.
model_saver = tf.train.import_meta_graph(ckpt_path+".meta")
model_saver.restore(sess, ckpt_path)
# Create a SavedModel
CreateSavedModel(sess, inputs, outputs, savedmodel_path)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment