From 999d3c693ea69a351838eafb168df3f9f2375bd6 Mon Sep 17 00:00:00 2001 From: remi Date: Wed, 6 Mar 2019 08:36:39 +0000 Subject: [PATCH] ENH: add the clear_devices option, useful to wipe device assignment in graphs --- python/ckpt2savedmodel.py | 9 ++++----- python/tricks.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py index 462d133..caf5b2f 100644 --- a/python/ckpt2savedmodel.py +++ b/python/ckpt2savedmodel.py @@ -21,10 +21,7 @@ from __future__ import division from __future__ import print_function import argparse -from tricks import * - -# Logging -tf.logging.set_verbosity(tf.logging.INFO) +from tricks import CheckpointToSavedModel # Parser parser = argparse.ArgumentParser() @@ -32,10 +29,12 @@ parser.add_argument("--ckpt", help="checkpoint file prefix", required=True) parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+') parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+') parser.add_argument("--model", help="output SavedModel", required=True) +parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') +parser.set_defaults(clear_devices=False) params = parser.parse_args() if __name__ == "__main__": - CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model) + CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model, params.clear_devices) quit() diff --git a/python/tricks.py b/python/tricks.py index 0b1ae61..0e2f90a 100644 --- a/python/tricks.py +++ b/python/tricks.py @@ -135,7 +135,7 @@ def CreateSavedModel(sess, inputs, outputs, directory): builder.add_meta_graph([tf.saved_model.tag_constants.SERVING]) builder.save() -def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path): +def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): """ Read a Checkpoint and build a SavedModel @@ -149,7 +149,7 @@ def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path): with tf.Session() as sess: # Restore variables from disk. - model_saver = tf.train.import_meta_graph(ckpt_path+".meta") + model_saver = tf.train.import_meta_graph(ckpt_path+".meta", clear_devices=clear_devices) model_saver.restore(sess, ckpt_path) # Create a SavedModel -- GitLab