diff --git a/main.py b/main.py
index 9a4bc42931f8acebd42a4f49015887d8b9fbc73e..dd7b3cb76c3b78c8cef36b987655f1b040e2a81e 100644
--- a/main.py
+++ b/main.py
@@ -7,7 +7,8 @@ from sklearn.utils import shuffle
 import time
 from sklearn.utils.extmath import softmax
 
-
+gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.45)
+sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
 
 def getBatch(X, i, batch_size):
     start_id = i*batch_size
@@ -16,24 +17,27 @@ def getBatch(X, i, batch_size):
     batch_x = X[start_id:end_id]
     return batch_x
 
-#@tf.function
-#trainClassifS1(model, discr, ts_train_S2, featTHRS, label_train, loss_function, optimizer2, optimizerDI, BATCH_SIZE, e, n_classes)
-#trainClassifS2(model, ts_train_S2_pixel, ts_train_S2_obj, label_train, loss_function, optimizer, BATCH_SIZE)
-
 def trainClassifS2(model, x_train_S2_pixel, x_train_S2_obj, y_train, loss_object, optimizer, BATCH_SIZE):
+    loss_object2 = tf.keras.losses.CategoricalCrossentropy()
     loss_iteration = 0
+    tot_loss = 0
     iterations = x_train_S2_pixel.shape[0] / BATCH_SIZE
     if x_train_S2_pixel.shape[0] % BATCH_SIZE != 0:
         iterations += 1
-
+    print("n iterations %d" % iterations)
     for ibatch in range(int(iterations)):
         batch_x_S2_p = getBatch(x_train_S2_pixel, ibatch, BATCH_SIZE)
         batch_x_S2_obj = getBatch(x_train_S2_obj, ibatch, BATCH_SIZE)
         batch_y = getBatch(y_train, ibatch, BATCH_SIZE)
 
         with tf.GradientTape() as gen_tape:
-            mainEstim  = model((batch_x_S2_p, batch_x_S2_obj), training=True)
+            mainEstim, pixEstim, objEstim  = model((batch_x_S2_p, batch_x_S2_obj), training=True)
             loss = loss_object(batch_y, mainEstim)
+            #loss += .5 * loss_object(batch_y, pixEstim)
+            #loss += .5 * loss_object(batch_y, objEstim)
+            loss += .5 * loss_object(batch_y, pixEstim)
+            loss += .5 * loss_object(batch_y, objEstim)
+
             grad_of_G = gen_tape.gradient(loss, model.trainable_variables)
             optimizer.apply_gradients(zip(grad_of_G, model.trainable_variables))
             tot_loss+=loss
@@ -42,29 +46,52 @@ def trainClassifS2(model, x_train_S2_pixel, x_train_S2_obj, y_train, loss_object
 
 
 ts_train_S2_pixel = np.load(sys.argv[1])
+ts_train_S2_pixel = ts_train_S2_pixel[:,8,8,:,:]
+ts_train_S2_pixel = np.squeeze(ts_train_S2_pixel)
 ts_train_S2_obj = np.load(sys.argv[2])
 
 
-print(ts_train_S2_pixel)
-print(ts_train_S2_obj)
-exit()
+#print(ts_train_S2_pixel)
+#print(ts_train_S2_obj)
+#exit()
 
 label_train = np.load(sys.argv[3])
+label_train = label_train[:,1]
 label_train = label_train-1
 
+#print(label_train.shape)
+#print(np.amin(label_train))
+
 ts_valid_S2_pixel = np.load(sys.argv[4])
+ts_valid_S2_pixel = ts_valid_S2_pixel[:,8,8,:,:]
+ts_valid_S2_pixel = np.squeeze(ts_valid_S2_pixel)
 ts_valid_S2_obj = np.load(sys.argv[5])
 
 label_valid = np.load(sys.argv[6])
+label_valid = label_valid[:,1]
 label_valid = label_valid-1
 
-output_dir_models = sys.argv[7]
-split_id = sys.argv[8]
+
+ts_test_S2_pixel = np.load(sys.argv[7])
+ts_test_S2_pixel = ts_test_S2_pixel[:,8,8,:,:]
+ts_test_S2_pixel = np.squeeze(ts_test_S2_pixel)
+ts_test_S2_obj = np.load(sys.argv[8])
+
+label_test = np.load(sys.argv[9])
+label_test = label_test[:,1]
+label_test = label_test-1
+
+
+
+output_dir_models = sys.argv[10]
+split_id = sys.argv[11]
+
+#exit()
 
 n_classes = len(np.unique(label_train))
-model = TwoBranchModel(128, "model", n_classes, dropout_rate=0.2)
+model = TwoBranchModel(128, "model", n_classes, dropout_rate=0.4)
 
-label_train = tf.keras.utils.to_categorical(label_train)
+#label_train = tf.keras.utils.to_categorical(label_train)
 
 print("model created")
 #DI = Discr()
@@ -74,8 +101,7 @@ loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
 optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
 
 BATCH_SIZE = 256
-n_epochs = 1000
-
+n_epochs = 300
 
 best_valid_fMeasure = 0
 for e in range(n_epochs):
@@ -85,10 +111,16 @@ for e in range(n_epochs):
     end = time.time()
     elapsed = end - start
     pred = model.predict((ts_valid_S2_pixel, ts_valid_S2_obj))
+    pred = pred[0]
     fscore = f1_score(label_valid, np.argmax(pred,axis=1), average="weighted")
+    print("epoch %d with loss %f and F-Measure in %f seconds" % (e, trainLoss, elapsed))
     if fscore > best_valid_fMeasure:
-    	best_valid_fMeasure = fscore
-    	model.save_weights(output_dir_models+"/model_"+split_id)
-    print("epoch %d with loss %f and F-Measure on validation %f in %f seconds" % (e, trainLoss, fscore, elapsed))
-    print(f1_score(label_valid, np.argmax(pred,axis=1), average=None))
+        best_valid_fMeasure = fscore
+        model.save_weights(output_dir_models+"/model_"+split_id)
+        print("\tBEST current results on validation %f" % fscore)
+        pred_test = model.predict((ts_test_S2_pixel, ts_test_S2_obj))
+        pred_test = pred_test[0]
+        fscore_test = f1_score(label_test, np.argmax(pred_test,axis=1), average="weighted")
+        print("\t\tresults on TEST %f" % fscore_test)
+    #print(f1_score(label_valid, np.argmax(pred,axis=1), average=None))
     sys.stdout.flush()