diff --git a/keras_CNN_RNN.py b/keras_CNN_RNN.py
index ac601dffbedeed99cf9b334c5b44b256603d7d0d..94500a1c5b66a36abc2bc1c9b08c2484f6b6748b 100644
--- a/keras_CNN_RNN.py
+++ b/keras_CNN_RNN.py
@@ -4,6 +4,7 @@ import keras
 from keras.callbacks import CSVLogger
 import keras.backend as K
 import csv
+import sys
 
 #from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import accuracy_score,f1_score,confusion_matrix,precision_recall_fscore_support
@@ -14,6 +15,8 @@ ts_size = 16
 patch_size = 25
 n_bands = 4
 
+n_epochs = int(sys.argv[1])
+attTanh = int(sys.argv[2]) == 1
 
 # KOUMBIA
 #Load training inputs
@@ -92,7 +95,7 @@ resh = keras.layers.Reshape(input_shape=(n_timestamps*ts_size,),target_shape=(n_
 rnn_out = keras.layers.GRU(256,return_sequences=True,name='gru_base')(resh)
 #rnn_out = keras.layers.GRU(512,name='gru_base')(input_ts)
 rnn_out = keras.layers.Dropout(rate=0.5,name='gru_dropout')(rnn_out)
-rnn_out = BasicAttention(name='gru_attention', with_tanh=False)(rnn_out)
+rnn_out = BasicAttention(name='gru_attention', with_tanh=attTanh)(rnn_out)
 rnn_aux = keras.layers.Dense(n_classes,activation='softmax',name='rnn_dense_layer_'+str(n_classes))(rnn_out)
 
 #CNN branch