import numpy as np
import sys
from BaseModels import TwoBranchModel
from sklearn.metrics import f1_score, r2_score
import tensorflow as tf
from sklearn.utils import shuffle
import time
from sklearn.utils.extmath import softmax

# Cap this process at ~45% of GPU memory via a TF1-compat session.
# NOTE(review): the rest of the script uses TF2 eager APIs (GradientTape,
# Keras optimizers) and never uses `sess`; whether this v1 Session actually
# constrains the TF2 runtime's memory allocator is unclear — confirm, or
# consider tf.config.set_logical_device_configuration instead.
gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.45)
sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))

def getBatch(X, i, batch_size):
    """Return the i-th mini-batch (0-indexed) of X along its first axis.

    The last batch may be shorter than batch_size when X.shape[0] is not a
    multiple of batch_size.

    Args:
        X: array-like with a .shape attribute (e.g. numpy array), batched on axis 0.
        i: batch index, starting at 0.
        batch_size: number of samples per full batch.

    Returns:
        The slice X[i*batch_size : min((i+1)*batch_size, X.shape[0])].
    """
    start_id = i * batch_size
    # Clamp the end so the final, possibly partial, batch stays in bounds.
    end_id = min((i + 1) * batch_size, X.shape[0])
    return X[start_id:end_id]

def trainClassifS2(model, x_train_S2_pixel, x_train_S2_obj, y_train, loss_object, optimizer, BATCH_SIZE):
    """Run one training epoch of the two-branch S2 model.

    The model returns three predictions (main/fused, pixel branch, object
    branch); the total loss is the main loss plus 0.5x each auxiliary
    branch loss.

    Args:
        model: TwoBranchModel returning (mainEstim, pixEstim, objEstim).
        x_train_S2_pixel: pixel-level time series, batched on axis 0.
        x_train_S2_obj: object-level features, same number of samples.
        y_train: integer class labels (sparse, to match loss_object).
        loss_object: e.g. tf.keras.losses.SparseCategoricalCrossentropy().
        optimizer: a tf.keras optimizer.
        BATCH_SIZE: mini-batch size.

    Returns:
        Mean per-batch loss over the epoch (a scalar tf.Tensor).
    """
    tot_loss = 0
    n_samples = x_train_S2_pixel.shape[0]
    # Ceiling division: one extra, partial batch when n_samples is not a
    # multiple of BATCH_SIZE. (The original used float division here, which
    # made the final mean `tot_loss / iterations` divide by a non-integer.)
    iterations = n_samples // BATCH_SIZE
    if n_samples % BATCH_SIZE != 0:
        iterations += 1
    print("n iterations %d" % iterations)
    for ibatch in range(iterations):
        batch_x_S2_p = getBatch(x_train_S2_pixel, ibatch, BATCH_SIZE)
        batch_x_S2_obj = getBatch(x_train_S2_obj, ibatch, BATCH_SIZE)
        batch_y = getBatch(y_train, ibatch, BATCH_SIZE)

        with tf.GradientTape() as gen_tape:
            mainEstim, pixEstim, objEstim = model((batch_x_S2_p, batch_x_S2_obj), training=True)
            loss = loss_object(batch_y, mainEstim)
            # Auxiliary losses on each branch head, down-weighted by 0.5.
            loss += .5 * loss_object(batch_y, pixEstim)
            loss += .5 * loss_object(batch_y, objEstim)

        # Gradient computation/application happens outside the tape context
        # so the update ops themselves are not recorded on the tape.
        grad_of_G = gen_tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grad_of_G, model.trainable_variables))
        tot_loss += loss
    return tot_loss / iterations



# ---- Data loading ----------------------------------------------------------
# Train split: pixel-level time series (center pixel [8, 8] of each patch is
# extracted, then squeezed), object-level features, and labels.
ts_train_S2_pixel = np.load(sys.argv[1])
ts_train_S2_pixel = ts_train_S2_pixel[:, 8, 8, :, :]
ts_train_S2_pixel = np.squeeze(ts_train_S2_pixel)
ts_train_S2_obj = np.load(sys.argv[2])

# Labels: column 1 holds the class id; shift from 1-based to 0-based.
label_train = np.load(sys.argv[3])
label_train = label_train[:, 1]
label_train = label_train - 1

# Validation split, same preprocessing as train.
ts_valid_S2_pixel = np.load(sys.argv[4])
ts_valid_S2_pixel = ts_valid_S2_pixel[:, 8, 8, :, :]
ts_valid_S2_pixel = np.squeeze(ts_valid_S2_pixel)
ts_valid_S2_obj = np.load(sys.argv[5])

label_valid = np.load(sys.argv[6])
label_valid = label_valid[:, 1]
label_valid = label_valid - 1

# Test split, same preprocessing as train.
ts_test_S2_pixel = np.load(sys.argv[7])
ts_test_S2_pixel = ts_test_S2_pixel[:, 8, 8, :, :]
ts_test_S2_pixel = np.squeeze(ts_test_S2_pixel)
ts_test_S2_obj = np.load(sys.argv[8])

label_test = np.load(sys.argv[9])
label_test = label_test[:, 1]
label_test = label_test - 1

output_dir_models = sys.argv[10]
split_id = sys.argv[11]

# ---- Model and training setup ----------------------------------------------
n_classes = len(np.unique(label_train))
model = TwoBranchModel(128, "model", n_classes, dropout_rate=0.4)

print("model created")
""" defining loss function and the optimizer to use in the training phase """
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

BATCH_SIZE = 256
n_epochs = 300

# ---- Training loop with model selection on validation F-Measure ------------
best_valid_fMeasure = 0
for e in range(n_epochs):
    # Re-shuffle all three aligned arrays together each epoch.
    ts_train_S2_pixel, ts_train_S2_obj, label_train = shuffle(ts_train_S2_pixel, ts_train_S2_obj, label_train)
    start = time.time()
    trainLoss = trainClassifS2(model, ts_train_S2_pixel, ts_train_S2_obj, label_train, loss_function, optimizer, BATCH_SIZE)
    end = time.time()
    elapsed = end - start
    # model.predict returns (main, pixel, object) heads; evaluate the main one.
    pred = model.predict((ts_valid_S2_pixel, ts_valid_S2_obj))
    pred = pred[0]
    fscore = f1_score(label_valid, np.argmax(pred, axis=1), average="weighted")
    # BUGFIX: the original format string advertised the F-Measure but never
    # passed `fscore`, so the time was printed in its place.
    print("epoch %d with loss %f and F-Measure %f in %f seconds" % (e, trainLoss, fscore, elapsed))
    if fscore > best_valid_fMeasure:
        # New best on validation: checkpoint weights and report test score.
        best_valid_fMeasure = fscore
        model.save_weights(output_dir_models + "/model_" + split_id)
        print("\tBEST current results on validation %f" % fscore)
        pred_test = model.predict((ts_test_S2_pixel, ts_test_S2_obj))
        pred_test = pred_test[0]
        fscore_test = f1_score(label_test, np.argmax(pred_test, axis=1), average="weighted")
        print("\t\tresults on TEST %f" % fscore_test)
    sys.stdout.flush()