import sys
import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
from sklearn.metrics import accuracy_score


#Split a flat array into consecutive chunks of the given size
def split(arr, size):
    arrs = []
    while len(arr) > size:
        pice = arr[:size]
        arrs.append(pice)
        arr = arr[size:]
    arrs.append(arr)
    return arrs

#Format values matrix from (n,22*13) shape to (n,22,13)
def getRNNFormat(X):
    new_X = []
    for row in X:
        new_X.append( np.split(row, 22) )
    return np.array(new_X)

#One-hot encode the labels matrix: (n,1) shape -> (n,nclasses)
def getRNNFormatLabel(Y):
    #np.unique already returns the distinct values in sorted order
    vals = np.unique(np.array(Y))
    hash_val = {}
    for el in vals:
        hash_val[el] = len(hash_val.keys())
    new_Y = []
    for el in Y:
        t = np.zeros(len(vals))
        t[hash_val[el]] = 1.0
        new_Y.append(t)
    return np.array(new_Y)

#Get the i-th batch of the values set X and the labels set Y
def getBatch2(X, Y, i, batch_size):
    start_id = i * batch_size
    end_id = min((i+1) * batch_size, X.shape[0])
    batch_x = X[start_id:end_id]
    batch_y = Y[start_id:end_id]
    return batch_x, batch_y

def B_Rnn(x, nlayer):
    #WEIGHT & BIAS DEF for the final fully connected layer
    outw = tf.Variable(tf.random_normal([nunits*2, nclasses]))
    outb = tf.Variable(tf.random_normal([nclasses]))
    #Reverse the input along the time axis (axis 1) so the backward LSTM
    #reads the sequence in reverse order
    x_b = tf.reverse(x, [1])
    #Unstack (batch,timesteps,ninput) into a list of `timesteps` tensors of shape (batch,ninput)
    x = tf.unstack(x, timesteps, 1)
    x_bw = tf.unstack(x_b, timesteps, 1)
    x_fw = x
    #NETWORK DEF
    #MORE THAN ONE LAYER: stack of LSTMCells, nunits hidden units each, per direction
    if nlayer > 1:
        cells_fw = []
        cells_bw = []
        for _ in range(nlayer):
            with tf.variable_scope('forward'):
                cell_fw_temp = rnn.LSTMCell(nunits, forget_bias=1)
                cells_fw.append(cell_fw_temp)
            with tf.variable_scope('backward'):
                cell_bw_temp = rnn.LSTMCell(nunits, forget_bias=1)
                cells_bw.append(cell_bw_temp)
        cell_fw = tf.contrib.rnn.MultiRNNCell(cells_fw)
        cell_bw = tf.contrib.rnn.MultiRNNCell(cells_bw)
    #SINGLE LAYER: one LSTMCell per direction, nunits hidden units each
    else:
        with tf.variable_scope('forward'):
            cell_fw = rnn.LSTMCell(nunits, forget_bias=1)
        with tf.variable_scope('backward'):
            cell_bw = rnn.LSTMCell(nunits, forget_bias=1)
    outputs_fw, _ = rnn.static_rnn(cell_fw, x_fw, dtype="float32", scope='forward')
    outputs_bw, _ = rnn.static_rnn(cell_bw, x_bw, dtype="float32", scope='backward')
    #Concatenate the last forward and backward outputs
    output = tf.concat([outputs_fw[-1], outputs_bw[-1]], 1)
    #CLASSIFICATION: fully connected layer, (batchsz,nunits*2) -> (batchsz,nclasses)
    prediction = tf.matmul(output, outw) + outb
    return prediction
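#Worked example (illustrative only, not executed by the script): B_Rnn maps a
#(batchsz,22,13) input to (batchsz,nclasses) logits; each direction emits a
#(batchsz,nunits) last output, so the concatenated feature fed to the dense
#layer has shape (batchsz,nunits*2). A minimal sketch of the label encoding,
#assuming the hypothetical labels [2, 5, 2]:
#  getRNNFormatLabel(np.array([2, 5, 2]))
#  -> array([[1., 0.],
#            [0., 1.],
#            [1., 0.]])
#(np.unique sorts the distinct values, so label 2 maps to index 0 and 5 to 1.)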
def train_RNN(test_x, test_y, train_x, train_y, hm_epochs, itr, nlayer):
    tf.reset_default_graph()
    #PLACEHOLDER DEF
    x = tf.placeholder("float", [None, timesteps, ninput], name="x")
    y = tf.placeholder("float", [None, nclasses], name="y")
    learning_rate = tf.placeholder(tf.float32, shape=[])
    sess = tf.InteractiveSession()
    #Build the bidirectional RNN with the requested number of layers
    prediction = B_Rnn(x, nlayer)
    #LOSS-OPTIMIZER-ACCURACY DEF
    tensor1d = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction)
    cost = tf.reduce_mean(tensor1d)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float64))
    tf.global_variables_initializer().run()
    #START SESSION
    train_x = getRNNFormat(train_x)       #(n,22*13) -> (n,22,13)
    train_y = getRNNFormatLabel(train_y)  #(n,1) -> (n,nclasses)
    iterations = len(train_x) // batchsz
    if len(train_x) % batchsz != 0:
        iterations += 1
    for e in range(hm_epochs):
        lossi = 0
        accS = 0
        for ibatch in range(iterations):
            #BATCH_X BATCH_Y: i-th batches of train_x and train_y
            batch_x, batch_y = getBatch2(train_x, train_y, ibatch, batchsz)
            acc, _, loss = sess.run([accuracy, optimizer, cost], feed_dict={x: batch_x, y: batch_y, learning_rate: 0.001})
            lossi += loss
            accS += acc
        #Average loss and accuracy over the batches of this epoch
        print "Train loss:", lossi/iterations, "| accuracy:", accS/iterations
    test_x = getRNNFormat(test_x)
    test_y = getRNNFormatLabel(test_y)
    tot_pred = []
    iterations = len(test_x) // batchsz
    if len(test_x) % batchsz != 0:
        iterations += 1
    for ibatch in range(iterations):
        #BATCH_X BATCH_Y: i-th batches of test_x and test_y
        batch_x, batch_y = getBatch2(test_x, test_y, ibatch, batchsz)
        #PRED_TEMP: (n,nclasses) logits for the current batch
        pred_temp = sess.run(prediction, feed_dict={x: batch_x})
        #TOT_PRED: accumulate the argmax of each test prediction, evaluated by batch
        for el in pred_temp:
            tot_pred.append(np.argmax(el))
    #GT: argmax of each test ground-truth row: test_y (n,nclasses) -> gt (n,1)
    gt = []
    for el in test_y:
        gt.append(np.argmax(el))
    tot_pred = tot_pred[0:len(gt)]
    print "Accuracy ", accuracy_score(gt, tot_pred)
    #SAVE GROUND TRUTH gt AND PREDICTION tot_pred
    var_totpred = './dataset/N%d/B_LSTM%dl_truthpred_%dp%du%db/totpred%d.npy' % (norm, nlayer, p_split, nunits, batchsz, itr)
    var_gt = './dataset/N%d/B_LSTM%dl_truthpred_%dp%du%db/gt%d.npy' % (norm, nlayer, p_split, nunits, batchsz, itr)
    np.save(var_totpred, tot_pred)
    np.save(var_gt, gt)

#Load the dataset and count the distinct values in the label column
def getClasses(norm):
    aux = './dataset/N%d/ds.npy' % norm
    ds = np.load(aux)
    #Label column
    aux = ds[:, 286]
    aux = aux.astype(int)
    classes = np.argwhere(np.bincount(aux)).shape
    return classes[0]

print "Blstm"
timesteps = 22
ninput = 13
p_split = 70
n_split = 10
batchsz = int(sys.argv[1])
nunits = int(sys.argv[2])
hm_epochs = int(sys.argv[3])
nlayer = int(sys.argv[4])
norm = int(sys.argv[5])
nclasses = getClasses(norm)
print "Split percentage:\t", p_split
print "Number of splits:\t", n_split
print "nclasses:", nclasses
print "batch size:", batchsz
print "n_units:", nunits
print "epochs:", hm_epochs
print "layers:", nlayer
directory = './dataset/N%d/B_LSTM%dl_truthpred_%dp%du%db' % (norm, nlayer, p_split, nunits, batchsz)
if not os.path.exists(directory):
    os.makedirs(directory)
for i in range(n_split):
    print "iter:", i
    var_train_x = './dataset/N%d/train_x%d_%d.npy' % (norm, i, p_split)
    var_train_y = './dataset/N%d/train_y%d_%d.npy' % (norm, i, p_split)
    var_test_x = './dataset/N%d/test_x%d_%d.npy' % (norm, i, p_split)
    var_test_y = './dataset/N%d/test_y%d_%d.npy' % (norm, i, p_split)
    train_x = np.load(var_train_x)
    train_y = np.load(var_train_y)
    test_x = np.load(var_test_x)
    test_y = np.load(var_test_y)
    #Start training
    train_RNN(test_x, test_y, train_x, train_y, hm_epochs, i, nlayer)
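#Usage sketch (hypothetical file name and argument values): the positional
#arguments are batch size, hidden units, epochs, number of layers, and the
#normalization id, e.g.
#  python blstm.py 32 512 200 1 2
#This run would expect ./dataset/N2/ds.npy plus the precomputed split files
#(train_x0_70.npy, train_y0_70.npy, test_x0_70.npy, ...) named as in the
#format strings above.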