Commit 999d3c69 authored by Cresson Remi's avatar Cresson Remi

ENH: add the clear_devices option, useful to wipe device assignment in graphs

parent de2c2b11
...@@ -21,10 +21,7 @@ from __future__ import division ...@@ -21,10 +21,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
from tricks import * from tricks import CheckpointToSavedModel
# Logging
tf.logging.set_verbosity(tf.logging.INFO)
# Parser # Parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -32,10 +29,12 @@ parser.add_argument("--ckpt", help="checkpoint file prefix", required=True) ...@@ -32,10 +29,12 @@ parser.add_argument("--ckpt", help="checkpoint file prefix", required=True)
parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+') parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+')
parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+') parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+')
parser.add_argument("--model", help="output SavedModel", required=True) parser.add_argument("--model", help="output SavedModel", required=True)
parser.add_argument('--clear_devices', dest='clear_devices', action='store_true')
parser.set_defaults(clear_devices=False)
params = parser.parse_args() params = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model) CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model, params.clear_devices)
quit() quit()
...@@ -135,7 +135,7 @@ def CreateSavedModel(sess, inputs, outputs, directory): ...@@ -135,7 +135,7 @@ def CreateSavedModel(sess, inputs, outputs, directory):
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING]) builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save() builder.save()
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path): def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
""" """
Read a Checkpoint and build a SavedModel Read a Checkpoint and build a SavedModel
...@@ -149,7 +149,7 @@ def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path): ...@@ -149,7 +149,7 @@ def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path):
with tf.Session() as sess: with tf.Session() as sess:
# Restore variables from disk. # Restore variables from disk.
model_saver = tf.train.import_meta_graph(ckpt_path+".meta") model_saver = tf.train.import_meta_graph(ckpt_path+".meta", clear_devices=clear_devices)
model_saver.restore(sess, ckpt_path) model_saver.restore(sess, ckpt_path)
# Create a SavedModel # Create a SavedModel
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment