From 88f49451b4ed90c2e4cf4b737d16c0d01cebf171 Mon Sep 17 00:00:00 2001
From: "raffaele.gaetano" <raffaele.gaetano@cirad.fr>
Date: Fri, 4 Oct 2019 17:34:22 +0200
Subject: [PATCH] ENH: Added CLI switch for tanh in the attention layer (Dino).
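
Read the training configuration from the command line: sys.argv[1] sets
n_epochs and sys.argv[2] (0 or 1) sets the with_tanh flag passed to the
BasicAttention layer of the GRU branch. A typical invocation (the epoch
count here is only an example) would be:

    python keras_CNN_RNN.py 1000 1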

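The BasicAttention implementation itself is not touched by this patch; for
reviewers, a minimal sketch of an additive-attention layer honouring such a
with_tanh flag could look as follows (names, weight shapes and the scoring
function are illustrative assumptions, not the repo's actual code):

    import keras
    import keras.backend as K

    class BasicAttention(keras.layers.Layer):
        """Attention over GRU timesteps; optionally squashes scores with tanh."""
        def __init__(self, with_tanh=False, **kwargs):
            self.with_tanh = with_tanh
            super(BasicAttention, self).__init__(**kwargs)

        def build(self, input_shape):
            # One learnable scoring vector over the feature dimension.
            self.W = self.add_weight(name='att_W',
                                     shape=(input_shape[-1], 1),
                                     initializer='glorot_uniform',
                                     trainable=True)
            super(BasicAttention, self).build(input_shape)

        def call(self, x):
            # x: (batch, timesteps, features) -> scores: (batch, timesteps, 1)
            scores = K.dot(x, self.W)
            if self.with_tanh:
                scores = K.tanh(scores)          # bound scores to (-1, 1)
            weights = K.softmax(scores, axis=1)  # normalize over timesteps
            return K.sum(x * weights, axis=1)    # weighted sum of timesteps

        def compute_output_shape(self, input_shape):
            return (input_shape[0], input_shape[-1])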
---
 keras_CNN_RNN.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/keras_CNN_RNN.py b/keras_CNN_RNN.py
index ac601df..94500a1 100644
--- a/keras_CNN_RNN.py
+++ b/keras_CNN_RNN.py
@@ -4,6 +4,7 @@ import keras
 from keras.callbacks import CSVLogger
 import keras.backend as K
 import csv
+import sys
 
 #from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import accuracy_score,f1_score,confusion_matrix,precision_recall_fscore_support
@@ -14,6 +15,8 @@ ts_size = 16
 patch_size = 25
 n_bands = 4
 
+n_epochs = int(sys.argv[1])
+attTanh = int(sys.argv[2]) == 1
 
 # KOUMBIA
 #Load training inputs
@@ -92,7 +95,7 @@ resh = keras.layers.Reshape(input_shape=(n_timestamps*ts_size,),target_shape=(n_
 rnn_out = keras.layers.GRU(256,return_sequences=True,name='gru_base')(resh)
 #rnn_out = keras.layers.GRU(512,name='gru_base')(input_ts)
 rnn_out = keras.layers.Dropout(rate=0.5,name='gru_dropout')(rnn_out)
-rnn_out = BasicAttention(name='gru_attention', with_tanh=False)(rnn_out)
+rnn_out = BasicAttention(name='gru_attention', with_tanh=attTanh)(rnn_out)
 rnn_aux = keras.layers.Dense(n_classes,activation='softmax',name='rnn_dense_layer_'+str(n_classes))(rnn_out)
 
 #CNN branch
-- 
GitLab