diff --git a/Restore.py b/Restore.py new file mode 100644 index 0000000000000000000000000000000000000000..b399217eb4ebc2b399cbe678b51960388e83bd78 --- /dev/null +++ b/Restore.py @@ -0,0 +1,86 @@ +import tensorflow as tf +import numpy as np +import sys +import os +from functions import getSL + +def checkTest(test_data, test_sl, batchsz, bins): # , classes_test): + alphas_values = None + tot_pred = None + iterations = test_data.shape[0] / batchsz + if test_data.shape[0] % batchsz != 0: + iterations+=1 + + for ibatch in range(iterations): + batch_limit, batch_x = getBatch(test_sl, test_data, ibatch, batchsz) + batch_mask = np.zeros((batch_limit.shape[0],bins)) + for idx, val in enumerate(batch_limit): + batch_mask[idx,0:val] = 1.0 + + pred_temp = sess.run(testPrediction,feed_dict={ + x_data:batch_x, + dropOut:0., + is_training_ph:False, + seq_length:batch_limit, + mask:batch_mask + }) + + if tot_pred is None: + tot_pred = pred_temp + else: + tot_pred = np.concatenate((tot_pred, pred_temp),axis=0) + + del batch_limit + del batch_x + del batch_mask + + return tot_pred + + + +def getBatch(X, Y, i, batch_size): + start_id = i*batch_size + end_id = min( (i+1) * batch_size, X.shape[0]) + batch_x = X[start_id:end_id] + batch_y = Y[start_id:end_id] + return batch_x, batch_y + + +testDataFile = sys.argv[1] +ckpt_path = sys.argv[2] +outputFileName = sys.argv[3] + +test_data = np.load(testDataFile) +test_sl = getSL(test_data) + +bins = test_data.shape[1] +tf.reset_default_graph() + +results_path = "results/" + +if not os.path.exists(results_path): + os.makedirs(results_path) + + +with tf.Session() as sess: + # Restore variables from disk. + model_saver = tf.train.import_meta_graph(ckpt_path+".meta") + model_saver.restore(sess, ckpt_path) + + graph = tf.get_default_graph() + + x_data = graph.get_tensor_by_name("x_data:0") + mask = graph.get_tensor_by_name("mask:0") + seq_length = graph.get_tensor_by_name("limits:0") + y = graph.get_tensor_by_name("y:0") + + is_training_ph = graph.get_tensor_by_name("is_training:0") + dropOut = graph.get_tensor_by_name("drop_rate:0") + alphas_b = graph.get_tensor_by_name("alphas:0") + testPrediction = graph.get_tensor_by_name("pred_env/prediction:0") + + print "Model restored. "+ckpt_path + + test_prediction = checkTest(test_data, test_sl, 1024, bins) + + np.save(results_path+outputFileName, test_prediction)