Commit 03f39ab2 authored by Ienco Dino's avatar Ienco Dino
Browse files

add restore code

parent fef47b7b
No related merge requests found
Showing with 86 additions and 0 deletions
+86 -0
Restore.py 0 → 100644
import tensorflow as tf
import numpy as np
import sys
import os
from functions import getSL
def checkTest(test_data, test_sl, batchsz, bins): # , classes_test):
alphas_values = None
tot_pred = None
iterations = test_data.shape[0] / batchsz
if test_data.shape[0] % batchsz != 0:
iterations+=1
for ibatch in range(iterations):
batch_limit, batch_x = getBatch(test_sl, test_data, ibatch, batchsz)
batch_mask = np.zeros((batch_limit.shape[0],bins))
for idx, val in enumerate(batch_limit):
batch_mask[idx,0:val] = 1.0
pred_temp = sess.run(testPrediction,feed_dict={
x_data:batch_x,
dropOut:0.,
is_training_ph:False,
seq_length:batch_limit,
mask:batch_mask
})
if tot_pred is None:
tot_pred = pred_temp
else:
tot_pred = np.concatenate((tot_pred, pred_temp),axis=0)
del batch_limit
del batch_x
del batch_mask
return tot_pred
def getBatch(X, Y, i, batch_size):
start_id = i*batch_size
end_id = min( (i+1) * batch_size, X.shape[0])
batch_x = X[start_id:end_id]
batch_y = Y[start_id:end_id]
return batch_x, batch_y
testDataFile = sys.argv[1]
ckpt_path = sys.argv[2]
outputFileName = sys.argv[3]
test_data = np.load(testDataFile)
test_sl = getSL(test_data)
bins = test_data.shape[1]
tf.reset_default_graph()
results_path = "results/"
if not os.path.exists(results_path):
os.makedirs(results_path)
with tf.Session() as sess:
# Restore variables from disk.
model_saver = tf.train.import_meta_graph(ckpt_path+".meta")
model_saver.restore(sess, ckpt_path)
graph = tf.get_default_graph()
x_data = graph.get_tensor_by_name("x_data:0")
mask = graph.get_tensor_by_name("mask:0")
seq_length = graph.get_tensor_by_name("limits:0")
y = graph.get_tensor_by_name("y:0")
is_training_ph = graph.get_tensor_by_name("is_training:0")
dropOut = graph.get_tensor_by_name("drop_rate:0")
alphas_b = graph.get_tensor_by_name("alphas:0")
testPrediction = graph.get_tensor_by_name("pred_env/prediction:0")
print "Model restored. "+ckpt_path
test_prediction = checkTest(test_data, test_sl, 1024, bins)
np.save(results_path+outputFileName, test_prediction)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment