diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py index 462d133a6989e4f2eb5bd46b22c1c0179e636a79..caf5b2fe097996fd0f2b813f202f3d5fc59ca178 100644 --- a/python/ckpt2savedmodel.py +++ b/python/ckpt2savedmodel.py @@ -21,10 +21,7 @@ from __future__ import division from __future__ import print_function import argparse -from tricks import * - -# Logging -tf.logging.set_verbosity(tf.logging.INFO) +from tricks import CheckpointToSavedModel # Parser parser = argparse.ArgumentParser() @@ -32,10 +29,12 @@ parser.add_argument("--ckpt", help="checkpoint file prefix", required=True) parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+') parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+') parser.add_argument("--model", help="output SavedModel", required=True) +parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') +parser.set_defaults(clear_devices=False) params = parser.parse_args() if __name__ == "__main__": - CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model) + CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model, params.clear_devices) quit() diff --git a/python/tricks.py b/python/tricks.py index 0b1ae614c7947fec9f99b08f78d8674c92ee35c9..0e2f90a9b9370702697790dce7d4a20a6bf59340 100644 --- a/python/tricks.py +++ b/python/tricks.py @@ -135,7 +135,7 @@ def CreateSavedModel(sess, inputs, outputs, directory): builder.add_meta_graph([tf.saved_model.tag_constants.SERVING]) builder.save() -def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path): +def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): """ Read a Checkpoint and build a SavedModel @@ -149,7 +149,7 @@ def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path): with tf.Session() as sess: # Restore variables from disk. - model_saver = tf.train.import_meta_graph(ckpt_path+".meta") + model_saver = tf.train.import_meta_graph(ckpt_path+".meta", clear_devices=clear_devices) model_saver.restore(sess, ckpt_path) # Create a SavedModel