From 58e508bdfffb8121f24868422548162f8d3621bd Mon Sep 17 00:00:00 2001 From: remi cresson Date: Wed, 12 Sep 2018 15:18:33 +0200 Subject: [PATCH] DOC: add copyright header + a bit of refactoring --- python/create_model_ienco-m3_patchbased.py | 68 ++++++++------------- python/create_model_maggiori17_fullyconv.py | 38 ++++++------ 2 files changed, 45 insertions(+), 61 deletions(-) diff --git a/python/create_model_ienco-m3_patchbased.py b/python/create_model_ienco-m3_patchbased.py index 0f1b8c4..be1bac4 100644 --- a/python/create_model_ienco-m3_patchbased.py +++ b/python/create_model_ienco-m3_patchbased.py @@ -1,3 +1,21 @@ +# -*- coding: utf-8 -*- +#========================================================================== +# +# Copyright Remi Cresson, Dino Ienco (IRSTEA) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#==========================================================================*/ import sys import os import numpy as np @@ -13,35 +31,6 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.utils import shuffle from sklearn.metrics import confusion_matrix from tricks import * - -def export_model(sess, - export_dir, - x_cnn_placeholder, - x_rnn_placeholder, - is_training_placeholder, - testPrediction): - """ export a SavedModel - """ - - # Update the export dir - model_dir = export_dir + "/saved_model/" - if os.path.exists(model_dir): - shutil.rmtree(model_dir) - - print("Export model in " + model_dir) - - # Add a builder (for LoadSavedModel) - builder = tf.saved_model.builder.SavedModelBuilder(model_dir) - signature_def_map= { - "model": tf.saved_model.signature_def_utils.predict_signature_def( - inputs = {"x_cnn" : x_cnn_placeholder, - "x_rnn" : x_rnn_placeholder, - "is_training" : is_training_placeholder}, - outputs = {"prediction" : testPrediction}) - } - builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map) - builder.add_meta_graph([tf.saved_model.tag_constants.SERVING]) - builder.save() def checkTest(ts_data, vhsr_data, batchsz, label_test): tot_pred = [] @@ -60,11 +49,7 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test): is_training_ph:True, dropout:0.0, x_cnn:batch_cnn_x}) - - del batch_rnn_x - del batch_cnn_x - del batch_y - + for el in pred_temp: tot_pred.append( el ) @@ -241,8 +226,8 @@ n_channels = 4 nclasses = 8 # check number of arguments -if len(sys.argv) != 7: - print("Usage : ") +if len(sys.argv) != 8: + print("Usage : ") sys.exit(1) ts_train = read_samples(sys.argv[1]) @@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6]) label_test = np.int32(label_test) print_histo(label_test, "label_test") +export_dir = read_samples(sys.argv[7]) + x_rnn = tf.placeholder(tf.float32,[None, 1, 1, n_dims*n_timestamps],name="x_rnn") x_cnn = tf.placeholder(tf.float32,[None, patch_window, patch_window, n_channels],name="x_cnn") y = tf.placeholder(tf.int32,[None, 1, 1, 1],name="y") @@ -324,19 +311,12 @@ for e in range(hm_epochs): lossi+=loss accS+=acc - del batch_rnn_x - del batch_cnn_x - del batch_y - - print "Epoch:",e,"Train loss:",lossi/iterations,"| accuracy:",accS/iterations c_loss = lossi/iterations if c_loss < best_loss: - save_path = saver.save(sess, "models/model") - print("Model saved in path: %s" % save_path) best_loss = c_loss - export_model(sess, "/tmp/m3_export", x_cnn, x_rnn, is_training_ph, testPrediction) + CreateSavedModel(sess, ["x_cnn:0","x_rnn:0","is_training:0"], ["prediction:0"], export_dir) test_acc = checkTest(ts_test, vhsr_test, 1024, label_test) \ No newline at end of file diff --git a/python/create_model_maggiori17_fullyconv.py b/python/create_model_maggiori17_fullyconv.py index efd1dd5..b4ba770 100644 --- a/python/create_model_maggiori17_fullyconv.py +++ b/python/create_model_maggiori17_fullyconv.py @@ -1,3 +1,21 @@ +# -*- coding: utf-8 -*- +#========================================================================== +# +# Copyright Remi Cresson (IRSTEA) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#==========================================================================*/ from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -89,12 +107,11 @@ def main(unused_argv): # check number of arguments if len(sys.argv) != 4: - print("Usage : ") + print("Usage : ") sys.exit(1) # Export dir - log_dir = sys.argv[3] + '/model_checkpoints/' - export_dir = sys.argv[3] + '/model_export/' + export_dir = sys.argv[3] print("loading dataset") @@ -272,13 +289,9 @@ def main(unused_argv): if step % 10 == 0: # Print status to stdout. print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) - #print('Step %d: (%.3f sec)' % (step, duration)) - # Save a checkpoint and evaluate the model periodically. if (curr_epoch + 1) % 1 == 0: - checkpoint_file = os.path.join(log_dir, 'model.ckpt') - saver.save(sess, checkpoint_file, global_step=step) # Evaluate against the training set. print('Training Data Eval:') do_eval2(sess, @@ -301,16 +314,7 @@ def main(unused_argv): batch_size) # Let's export a SavedModel - shutil.rmtree(export_dir) - builder = tf.saved_model.builder.SavedModelBuilder(export_dir) - signature_def_map= { - "model": tf.saved_model.signature_def_utils.predict_signature_def( - inputs= {"x1": xs_placeholder}, - outputs= {"prediction": testPrediction}) - } - builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map) - builder.add_meta_graph([tf.saved_model.tag_constants.SERVING]) - builder.save() + CreateSavedModel(sess, ["x1:0"], ["prediction:0"], export_dir) quit() -- GitLab