"""Train a two-branch (pixel-level + object-level) Sentinel-2 time-series classifier.

Command-line arguments (all .npy paths unless noted):
    1  train pixel series   (N, 17, 17, T, B) -> center pixel [:, 8, 8, :, :] is used
    2  train object series
    3  train labels         (column 1 used; labels are 1-based and shifted to 0-based)
    4  valid pixel series
    5  valid object series
    6  valid labels
    7  test pixel series
    8  test object series
    9  test labels
    10 output directory for model weights
    11 split identifier (string, appended to the saved weight file name)

For each epoch the model is trained on shuffled data; whenever the weighted
F1 on the validation set improves, weights are saved and the test F1 printed.
"""
import numpy as np
import sys
from BaseModels import TwoBranchModel
from sklearn.metrics import f1_score, r2_score
import tensorflow as tf
from sklearn.utils import shuffle
import time
from sklearn.utils.extmath import softmax

# Cap the TF1-style session GPU allocation so this process can share the GPU
# with other jobs (45% of device memory).
gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.45)
sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))


def getBatch(X, i, batch_size):
    """Return the i-th mini-batch of X along axis 0.

    The last batch is truncated to the array length, so it may be smaller
    than batch_size.
    """
    start_id = i * batch_size
    end_id = min((i + 1) * batch_size, X.shape[0])
    return X[start_id:end_id]


def trainClassifS2(model, x_train_S2_pixel, x_train_S2_obj, y_train,
                   loss_object, optimizer, BATCH_SIZE):
    """Run one training epoch over the (pixel, object) branches.

    Parameters
    ----------
    model : TwoBranchModel returning (main, pixel, object) logit heads.
    x_train_S2_pixel, x_train_S2_obj : per-sample input arrays, same length.
    y_train : integer class labels (sparse, 0-based).
    loss_object : e.g. SparseCategoricalCrossentropy.
    optimizer : tf.keras optimizer.
    BATCH_SIZE : mini-batch size.

    Returns the mean loss over the epoch (a scalar tensor).
    """
    tot_loss = 0
    # Number of batches, ceil(N / BATCH_SIZE); kept as float so the
    # returned mean uses true division.
    iterations = x_train_S2_pixel.shape[0] / BATCH_SIZE
    if x_train_S2_pixel.shape[0] % BATCH_SIZE != 0:
        iterations += 1
    print("n iterations %d" % iterations)
    for ibatch in range(int(iterations)):
        batch_x_S2_p = getBatch(x_train_S2_pixel, ibatch, BATCH_SIZE)
        batch_x_S2_obj = getBatch(x_train_S2_obj, ibatch, BATCH_SIZE)
        batch_y = getBatch(y_train, ibatch, BATCH_SIZE)
        with tf.GradientTape() as gen_tape:
            mainEstim, pixEstim, objEstim = model(
                (batch_x_S2_p, batch_x_S2_obj), training=True)
            # Main head plus the two auxiliary (per-branch) heads, the
            # auxiliaries down-weighted by 0.5 as a regularizer.
            loss = loss_object(batch_y, mainEstim)
            loss += .5 * loss_object(batch_y, pixEstim)
            loss += .5 * loss_object(batch_y, objEstim)
        grads = gen_tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        tot_loss += loss
    return tot_loss / iterations


# ---------------------------------------------------------------------------
# Data loading: pixel branch uses only the central (8, 8) pixel of each patch;
# labels are taken from column 1 and shifted from 1-based to 0-based.
# ---------------------------------------------------------------------------
ts_train_S2_pixel = np.load(sys.argv[1])
ts_train_S2_pixel = ts_train_S2_pixel[:, 8, 8, :, :]
ts_train_S2_pixel = np.squeeze(ts_train_S2_pixel)
ts_train_S2_obj = np.load(sys.argv[2])
label_train = np.load(sys.argv[3])
label_train = label_train[:, 1]
label_train = label_train - 1

ts_valid_S2_pixel = np.load(sys.argv[4])
ts_valid_S2_pixel = ts_valid_S2_pixel[:, 8, 8, :, :]
ts_valid_S2_pixel = np.squeeze(ts_valid_S2_pixel)
ts_valid_S2_obj = np.load(sys.argv[5])
label_valid = np.load(sys.argv[6])
label_valid = label_valid[:, 1]
label_valid = label_valid - 1

ts_test_S2_pixel = np.load(sys.argv[7])
ts_test_S2_pixel = ts_test_S2_pixel[:, 8, 8, :, :]
ts_test_S2_pixel = np.squeeze(ts_test_S2_pixel)
ts_test_S2_obj = np.load(sys.argv[8])
label_test = np.load(sys.argv[9])
label_test = label_test[:, 1]
label_test = label_test - 1

output_dir_models = sys.argv[10]
split_id = sys.argv[11]

n_classes = len(np.unique(label_train))
model = TwoBranchModel(128, "model", n_classes, dropout_rate=0.4)
print("model created")

# Sparse labels (integer class ids), so no to_categorical conversion needed.
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

BATCH_SIZE = 256
n_epochs = 300
best_valid_fMeasure = 0

for e in range(n_epochs):
    # Reshuffle the three aligned training arrays together each epoch.
    ts_train_S2_pixel, ts_train_S2_obj, label_train = shuffle(
        ts_train_S2_pixel, ts_train_S2_obj, label_train)
    start = time.time()
    trainLoss = trainClassifS2(model, ts_train_S2_pixel, ts_train_S2_obj,
                               label_train, loss_function, optimizer, BATCH_SIZE)
    elapsed = time.time() - start

    # Model-select on weighted F1 of the main head over the validation set.
    pred = model.predict((ts_valid_S2_pixel, ts_valid_S2_obj))
    pred = pred[0]  # main-head predictions; auxiliary heads are discarded
    fscore = f1_score(label_valid, np.argmax(pred, axis=1), average="weighted")
    # NOTE: original message omitted the F-measure it claimed to print.
    print("epoch %d with loss %f and F-Measure %f in %f seconds"
          % (e, trainLoss, fscore, elapsed))

    if fscore > best_valid_fMeasure:
        best_valid_fMeasure = fscore
        model.save_weights(output_dir_models + "/model_" + split_id)
        print("\tBEST current results on validation %f" % fscore)
        # Report test performance only when validation improves.
        pred_test = model.predict((ts_test_S2_pixel, ts_test_S2_obj))
        pred_test = pred_test[0]
        fscore_test = f1_score(label_test, np.argmax(pred_test, axis=1),
                               average="weighted")
        print("\t\tresults on TEST %f" % fscore_test)
    sys.stdout.flush()