From 31d828a043dc21ffb0775a783a723e6c29966c39 Mon Sep 17 00:00:00 2001
From: SPeillet <peillet.seb@protonmail.com>
Date: Tue, 3 Sep 2019 13:38:31 +0200
Subject: [PATCH] ENH: add M3Fusion samples classification

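deepClassify() now accepts a list of scaler files and dispatches on its
length: one scaler keeps the standard dense-model path, while two scalers
trigger the M3Fusion path, where per-pixel RNN (time series) and CNN
(patch) features are generated with Moringa2DL, scaled per branch, and
the per-pixel predictions are aggregated per Segment_ID by majority vote
(mean of per-pixel confidences when compute_confidence is set).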
---
 classificationWorkflow.py | 96 +++++++++++++++++++++++++++++++--------
 1 file changed, 78 insertions(+), 18 deletions(-)

diff --git a/classificationWorkflow.py b/classificationWorkflow.py
index 8e41e1a..5672ab1 100644
--- a/classificationWorkflow.py
+++ b/classificationWorkflow.py
@@ -339,15 +339,23 @@ def deepTraining(data,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
 
     return model_file, scaler_files, csv_classes
 
-def deepClassify(shp_list,code,scaler,model_file, csv_classes,out_fld,out_ext,feat,feat_mode = 'list',Nproc=1,compute_confidence=False):
+def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,feat,feat_mode = 'list',Nproc=1,compute_confidence=False,feat_folder='.',seg_src='.'):
     import geopandas
     import tensorflow as tf
     from sklearn.preprocessing import StandardScaler
+    import TFmodels
     import joblib
     import csv
     
-    scaler = joblib.load(scaler)
-    model = tf.keras.models.load_model(model_file)
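+    # One scaler file selects the standard dense model; two select the
+    # M3Fusion model (one scaler per branch: CNN first, then RNN).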
+    if len(scalers) == 1:
+        method = 'standard'
+        scaler = joblib.load(scalers[0])
+        model = tf.keras.models.load_model(model_file)
+    elif len(scalers) == 2:
+        method = 'M3Fusion'
+        scaler_rnn = joblib.load(scalers[1])
+        scaler_cnn = joblib.load(scalers[0])
+    else:
+        raise ValueError('deepClassify expects 1 scaler (standard) or 2 scalers (M3Fusion), got %d' % len(scalers))
+
     dict_classes = {}
     with open(csv_classes,'r') as file:
         rdr = csv.reader(file)
@@ -355,21 +363,73 @@ def deepClassify(shp_list,code,scaler,model_file, csv_classes,out_fld,out_ext,fe
             dict_classes[int(row[0])]=int(row[1])
 
     out_file_list = []
-    for shp in shp_list:
-        out_file = os.path.join(out_fld, os.path.basename(shp).replace('.shp', out_ext + '.shp'))
-        ds = geopandas.read_file(shp)
-        feats = ds.filter(feat,axis=1)
-
-        cfeats = scaler.transform(feats)
-        predict = model.predict(cfeats)
-        class_predict = [dict_classes[np.where(p==np.max(p))[0][0]] for p in predict]
-        out = ds.filter(items=['Segment_ID',code[1:],'geometry'])
-        out.insert(out.shape[1]-1,code,class_predict)
-        if compute_confidence:
-            confmap = [np.max(p) for p in predict]
-            out.insert(out.shape[1]-1,'confidence',confmap)
-        out.to_file(out_file)
-        out_file_list.append(out_file)
+    if method == 'standard':
+        for shp in shp_list:
+            out_file = os.path.join(out_fld, os.path.basename(shp).replace('.shp', out_ext + '.shp'))
+            ds = geopandas.read_file(shp)
+            feats = ds.filter(feat,axis=1)
+            cfeats = scaler.transform(feats)
+            predict = model.predict(cfeats)
+            class_predict = [dict_classes[int(np.argmax(p))] for p in predict]
+            out = ds.filter(items=['Segment_ID',code[1:],'geometry'])
+            out.insert(out.shape[1]-1,code,class_predict)
+            if compute_confidence:
+                confmap = [np.max(p) for p in predict]
+                out.insert(out.shape[1]-1,'confidence',confmap)
+            out.to_file(out_file)
+            out_file_list.append(out_file)
+    elif method == 'M3Fusion':
+        import Moringa2DL
+        ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
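+        # Patch geometry is hard-coded and must match the training configuration.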
+        patch_size = 25
+        n_bands = 4
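+        # Rebuild the M3Fusion architecture, then restore the trained weights
+        # (model_file holds weights only, hence build() + load_weights()).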
+        model = TFmodels.M3FusionModel(len(dict_classes),n_timestamps,ts_size,patch_size,n_bands)
+        opt = tf.keras.optimizers.Adam(learning_rate=0.0002)
+        model.build([(1, n_timestamps, ts_size), (1, patch_size, patch_size, n_bands)])
+        model.load_weights(model_file)
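+        # Three loss weights: the fused output plus the two per-branch
+        # auxiliary outputs; only predict() is used below.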
+        model.compile(optimizer=opt,
+                      loss="sparse_categorical_crossentropy",
+                      loss_weights=[1, 0.3, 0.3],
+                      metrics=['accuracy'])
+        for shp in shp_list:
+            out_file = os.path.join(out_fld, os.path.basename(shp).replace('.shp', out_ext + '.shp'))
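+            # Moringa2DL.generateAllForClassify writes the per-pixel RNN and CNN
+            # feature arrays plus the matching Segment_ID array as .npy files.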
+            rnn, cnn, id_geom = Moringa2DL.generateAllForClassify(feat_folder,shp,seg_src,feat,cfield='Segment_ID',ofld=os.path.dirname(shp))
+            rnn_feats = np.load(rnn)
+            rnn_feats = scaler_rnn.transform(rnn_feats)
+            rnn_feats = np.reshape(rnn_feats, (rnn_feats.shape[0], n_timestamps, -1))
+            cnn_feats = np.load(cnn)
+            cnn_feats = scaler_cnn.transform(cnn_feats.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
+            predict = model.predict([rnn_feats,cnn_feats])
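+            # The model returns one array per output; predict[0] is the fused head.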
+            class_predict = np.array([dict_classes[int(np.argmax(p))] for p in predict[0]])
+            id_join = np.load(id_geom)
+            stack = np.column_stack([id_join, class_predict])
+            # Majority vote: collect the per-pixel predictions of each segment
+            # and keep the most frequent class.
+            predict_dict = {}
+            for i in np.unique(id_join):
+                predict_dict[int(i)] = []
+            for i, j in stack:
+                predict_dict[int(i)].append(int(j))
+            classif = []
+            for i, j in predict_dict.items():
+                classif.append(np.bincount(j).argmax())
+            ds = geopandas.read_file(shp)
+            out = ds.filter(items=['Segment_ID','geometry'])
+            # Sort so rows align with classif, which is ordered by Segment_ID.
+            out = out.sort_values(by='Segment_ID')
+            out.insert(out.shape[1]-1,code,classif)
+            if compute_confidence:
+                # Per-segment confidence is the mean of the per-pixel maxima.
+                confmap = np.array([np.max(p) for p in predict[0]])
+                conf_stack = np.column_stack([id_join, confmap])
+                cf_dict = {}
+                for i in np.unique(id_join):
+                    cf_dict[int(i)] = []
+                for i, j in conf_stack:
+                    cf_dict[int(i)].append(j)
+                confidence = []
+                for i, j in cf_dict.items():
+                    confidence.append(np.mean(j))
+                out.insert(out.shape[1]-1,'confidence',confidence)
+            out.to_file(out_file)
+            out_file_list.append(out_file)
     return out_file_list
 
 def training(shp,code,model_fld,params,feat,feat_mode = 'list'):
-- 
GitLab