From 88f49451b4ed90c2e4cf4b737d16c0d01cebf171 Mon Sep 17 00:00:00 2001 From: "raffaele.gaetano" <raffaele.gaetano@cirad.fr> Date: Fri, 4 Oct 2019 17:34:22 +0200 Subject: [PATCH] ENH: Added TanH to Attention (Dino). --- keras_CNN_RNN.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_CNN_RNN.py b/keras_CNN_RNN.py index ac601df..94500a1 100644 --- a/keras_CNN_RNN.py +++ b/keras_CNN_RNN.py @@ -4,6 +4,7 @@ import keras from keras.callbacks import CSVLogger import keras.backend as K import csv +import sys #from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score,f1_score,confusion_matrix,precision_recall_fscore_support @@ -14,6 +15,8 @@ ts_size = 16 patch_size = 25 n_bands = 4 +n_epochs = int(sys.argv[1]) +attTanh = int(sys.argv[2]) == 1 # KOUMBIA #Load training inputs @@ -92,7 +95,7 @@ resh = keras.layers.Reshape(input_shape=(n_timestamps*ts_size,),target_shape=(n_ rnn_out = keras.layers.GRU(256,return_sequences=True,name='gru_base')(resh) #rnn_out = keras.layers.GRU(512,name='gru_base')(input_ts) rnn_out = keras.layers.Dropout(rate=0.5,name='gru_dropout')(rnn_out) -rnn_out = BasicAttention(name='gru_attention', with_tanh=False)(rnn_out) +rnn_out = BasicAttention(name='gru_attention', with_tanh=attTanh)(rnn_out) rnn_aux = keras.layers.Dense(n_classes,activation='softmax',name='rnn_dense_layer_'+str(n_classes))(rnn_out) #CNN branch -- GitLab