otbtf issue #27

Issue created Jun 10, 2022 by Cresson Remi (@remi.cresson, Owner)

Notes on how to create quantized models

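Using the test_graph_transform branch: freeze the SavedModel into a GraphDef, optimize and quantize it with the Graph Transform Tool, convert it back into a SavedModel, then run inference on a test image: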

python /opt/otbtf/lib/python3.8/site-packages/tensorflow/python/tools/freeze_graph.py  --input_saved_model_dir sr4rs_sentinel2_bands4328_france2020_savedmodel --output_node_names "output_64" --output_graph sr4rs_graphdef.pb
/opt/otbtf/bin/graph_transforms/transform_graph --in_graph=sr4rs_graphdef.pb --out_graph=sr4rs_opt.pb --inputs="lr_input" --outputs="output_64" --transforms='add_default_attributes remove_nodes(op=CheckNumerics) fold_constants(ignore_errors=true) fold_batch_norms fold_old_batch_norms quantize_nodes(input_min=-1,input_max=1) strip_unused_nodes sort_by_execution_order'
rm -rf new_sm
python gd2sm.py
python sr4rs/code/sr.py --input image.tif --savedmodel new_sm/ --output hr.tif
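For intuition on quantize_nodes(input_min=-1,input_max=1): as I read the Graph Transform Tool documentation, input_min and input_max give the float range expected on quantized placeholder inputs (here lr_input), and quantized values are stored as 8-bit codes spanning a float range. The sketch below shows generic linear (affine) 8-bit quantization over [-1, 1]. It is an illustration only, not TensorFlow's exact kernel or rounding mode, and quantize/dequantize are hypothetical helpers.

```python
import math


def quantize(x: float, lo: float = -1.0, hi: float = 1.0, bits: int = 8) -> int:
    """Map a float in [lo, hi] to an unsigned integer code in [0, 2**bits - 1]."""
    levels = (1 << bits) - 1          # 255 for 8 bits
    x = min(max(x, lo), hi)           # in this sketch, values outside [lo, hi] are clipped
    # Round half up to the nearest code
    return math.floor((x - lo) * levels / (hi - lo) + 0.5)


def dequantize(q: int, lo: float = -1.0, hi: float = 1.0, bits: int = 8) -> float:
    """Map an integer code back to the float value it represents."""
    levels = (1 << bits) - 1
    return lo + q * (hi - lo) / levels


if __name__ == "__main__":
    for x in (-1.0, -0.3, 0.0, 0.5, 1.0):
        q = quantize(x)
        print(f"{x:+.2f} -> {q:3d} -> {dequantize(q):+.4f}")
    # The round-trip error is at most half a step: (1 - (-1)) / 255 / 2 = 1/255
```

With 255 steps over [-1, 1], each code covers about 0.0078, so if the network input is normalized differently, the input_min/input_max values would presumably need to follow.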

With the following gd2sm.py, which wraps the optimized GraphDef into a new SavedModel:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# SavedModelBuilder is a TF1 API: on TensorFlow 2.x, use the compat.v1 module
if int(tf.__version__.split('.')[0]) >= 2:
    del tf
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
# else: TensorFlow 1.x, nothing to do

export_dir = 'new_sm'      # must not exist yet (hence "rm -rf new_sm")
graph_pb = 'sr4rs_opt.pb'  # output of transform_graph

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# Read the optimized (quantized) GraphDef
with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # Import with an empty prefix so that tensor names stay unchanged
    tf.import_graph_def(graph_def, name="")
    g = sess.graph
    values = g.get_tensor_by_name("lr_input:0")
    predictions = g.get_tensor_by_name("output_64:0")

    # Expose input and output through the default "serving_default" signature
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"lr_input": values}, {"output_64": predictions})

    # The frozen graph has no variables left: only the graph and signature are saved
    builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], signature_def_map=sigs)

builder.save()
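The commands and gd2sm.py could also be chained from a single Python script. The sketch below is hypothetical (build_commands, run_pipeline and the constant names are mine, not part of otbtf); it reuses the paths, node names and transform list from the notes and, by default, only prints the commands. The export directory is deleted before gd2sm.py runs because SavedModelBuilder refuses to write into an existing, non-empty directory, which is what the "rm -rf new_sm" step is for.

```python
import shlex
import shutil
import subprocess
import sys
from typing import List

# Paths, node names and transforms copied from the notes above (otbtf image layout)
OTBTF = "/opt/otbtf"
FREEZE_GRAPH = f"{OTBTF}/lib/python3.8/site-packages/tensorflow/python/tools/freeze_graph.py"
TRANSFORM_GRAPH = f"{OTBTF}/bin/graph_transforms/transform_graph"
EXPORT_DIR = "new_sm"          # must match export_dir in gd2sm.py
FROZEN_PB = "sr4rs_graphdef.pb"
OPTIMIZED_PB = "sr4rs_opt.pb"  # must match graph_pb in gd2sm.py

TRANSFORMS = [
    "add_default_attributes",
    "remove_nodes(op=CheckNumerics)",
    "fold_constants(ignore_errors=true)",
    "fold_batch_norms",
    "fold_old_batch_norms",
    "quantize_nodes(input_min=-1,input_max=1)",
    "strip_unused_nodes",
    "sort_by_execution_order",
]


def build_commands(savedmodel: str, inp: str = "lr_input",
                   out: str = "output_64") -> List[List[str]]:
    """Return the conversion steps as argv lists (hypothetical helper)."""
    return [
        # 1. Freeze the SavedModel into a single GraphDef
        [sys.executable, FREEZE_GRAPH,
         "--input_saved_model_dir", savedmodel,
         "--output_node_names", out,
         "--output_graph", FROZEN_PB],
        # 2. Optimize and quantize the GraphDef with the Graph Transform Tool
        [TRANSFORM_GRAPH,
         f"--in_graph={FROZEN_PB}", f"--out_graph={OPTIMIZED_PB}",
         f"--inputs={inp}", f"--outputs={out}",
         "--transforms=" + " ".join(TRANSFORMS)],
        # 3. Wrap the optimized GraphDef into a new SavedModel
        [sys.executable, "gd2sm.py"],
    ]


def run_pipeline(savedmodel: str, dry_run: bool = True) -> None:
    """Print each step and, unless dry_run is set, execute it."""
    for argv in build_commands(savedmodel):
        if argv[-1] == "gd2sm.py" and not dry_run:
            # SavedModelBuilder refuses a non-empty export directory
            shutil.rmtree(EXPORT_DIR, ignore_errors=True)
        print(shlex.join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)  # stop at the first failing step


if __name__ == "__main__":
    run_pipeline("sr4rs_sentinel2_bands4328_france2020_savedmodel")
```

Keeping dry_run=True as the default makes it cheap to check the paths inside the container before launching the real conversion; the inference check with sr4rs/code/sr.py stays a separate manual step.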