Commit 58e508bd authored by remi cresson's avatar remi cresson

DOC: add copyright header + a bit of refactoring

parent ad741859
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson, Dino Ienco (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import sys
import os
import numpy as np
......@@ -13,35 +31,6 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix
from tricks import *
def export_model(sess,
export_dir,
x_cnn_placeholder,
x_rnn_placeholder,
is_training_placeholder,
testPrediction):
""" export a SavedModel
"""
# Update the export dir
model_dir = export_dir + "/saved_model/"
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
print("Export model in " + model_dir)
# Add a builder (for LoadSavedModel)
builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
inputs = {"x_cnn" : x_cnn_placeholder,
"x_rnn" : x_rnn_placeholder,
"is_training" : is_training_placeholder},
outputs = {"prediction" : testPrediction})
}
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
def checkTest(ts_data, vhsr_data, batchsz, label_test):
tot_pred = []
......@@ -60,11 +49,7 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test):
is_training_ph:True,
dropout:0.0,
x_cnn:batch_cnn_x})
del batch_rnn_x
del batch_cnn_x
del batch_y
for el in pred_temp:
tot_pred.append( el )
......@@ -241,8 +226,8 @@ n_channels = 4
nclasses = 8
# check number of arguments
if len(sys.argv) != 7:
print("Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>")
if len(sys.argv) != 8:
print("Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid> <export_dir>")
sys.exit(1)
ts_train = read_samples(sys.argv[1])
......@@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6])
label_test = np.int32(label_test)
print_histo(label_test, "label_test")
export_dir = read_samples(sys.argv[7])
x_rnn = tf.placeholder(tf.float32,[None, 1, 1, n_dims*n_timestamps],name="x_rnn")
x_cnn = tf.placeholder(tf.float32,[None, patch_window, patch_window, n_channels],name="x_cnn")
y = tf.placeholder(tf.int32,[None, 1, 1, 1],name="y")
......@@ -324,19 +311,12 @@ for e in range(hm_epochs):
lossi+=loss
accS+=acc
del batch_rnn_x
del batch_cnn_x
del batch_y
print "Epoch:",e,"Train loss:",lossi/iterations,"| accuracy:",accS/iterations
c_loss = lossi/iterations
if c_loss < best_loss:
save_path = saver.save(sess, "models/model")
print("Model saved in path: %s" % save_path)
best_loss = c_loss
export_model(sess, "/tmp/m3_export", x_cnn, x_rnn, is_training_ph, testPrediction)
CreateSavedModel(sess, ["x_cnn:0","x_rnn:0","is_training:0"], ["prediction:0"], export_dir)
test_acc = checkTest(ts_test, vhsr_test, 1024, label_test)
\ No newline at end of file
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -89,12 +107,11 @@ def main(unused_argv):
# check number of arguments
if len(sys.argv) != 4:
print("Usage : <patches> <labels> <output_model_dir>")
print("Usage : <patches> <labels> <export_dir>")
sys.exit(1)
# Export dir
log_dir = sys.argv[3] + '/model_checkpoints/'
export_dir = sys.argv[3] + '/model_export/'
export_dir = sys.argv[3]
print("loading dataset")
......@@ -272,13 +289,9 @@ def main(unused_argv):
if step % 10 == 0:
# Print status to stdout.
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
#print('Step %d: (%.3f sec)' % (step, duration))
# Save a checkpoint and evaluate the model periodically.
if (curr_epoch + 1) % 1 == 0:
checkpoint_file = os.path.join(log_dir, 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=step)
# Evaluate against the training set.
print('Training Data Eval:')
do_eval2(sess,
......@@ -301,16 +314,7 @@ def main(unused_argv):
batch_size)
# Let's export a SavedModel
shutil.rmtree(export_dir)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
inputs= {"x1": xs_placeholder},
outputs= {"prediction": testPrediction})
}
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
CreateSavedModel(sess, ["x1:0"], ["prediction:0"], export_dir)
quit()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment