From 6ede0bc555d2b1a3fe8985580acba300028ad83f Mon Sep 17 00:00:00 2001
From: mlang <marc.lang@teledetection.fr>
Date: Wed, 12 Dec 2018 12:13:10 +0100
Subject: [PATCH] Add a second recipe

update recipe
---
 .gitignore                                  |   4 +
 scripts/classification_from_segmentation.py | 135 +++++++++++++++++++++-
 2 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index c6c5a74..b1a12f3 100755
--- a/.gitignore
+++ b/.gitignore
@@ -60,3 +60,7 @@ data/l8_20130707\.tif\.aux\.xml
 data/l8_20130831\.tif\.aux\.xml
 
 qgis-auth\.db
+
+\.spyderproject
+
+\.spyderworkspace
diff --git a/scripts/classification_from_segmentation.py b/scripts/classification_from_segmentation.py
index a5b7ea4..ba8d170 100644
--- a/scripts/classification_from_segmentation.py
+++ b/scripts/classification_from_segmentation.py
@@ -20,6 +20,140 @@ except:
     print 'pandas not imported'
 
 
+
+def my_recipe2(**kwargs):
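+    """Train and evaluate a classifier on per-segment statistics loaded
+    from .npy files; return a dict of predictions and quality metrics."""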
+
+    stats = kwargs.get('stats', ['mean', 'std'])
+    field_id = kwargs.get('field_id', 'DN')
+    field_class = kwargs.get('field_class', 'class')
+    train_size = kwargs.get('train_size', 0.5)
+    nodata = kwargs.get('nodata', None)
+    target_names = kwargs.get('target_names')
+    debug = kwargs.get('debug', False)
+    field_pred = kwargs.get('field_pred', 'pred_class')
+
+    n_jobs = kwargs.get('n_jobs', 1)
+    save = kwargs.get('save', False)
+    X_train = kwargs.get('X_train')
+    Y_train = kwargs.get('Y_train')
+    X_test = kwargs.get('X_test')
+    Y_test = kwargs.get('Y_test')
+    X = Y = FID = None
+    FID_train = None
+    FID_test = None
+    i = 0
+    base = kwargs.get('base')
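+    # Load the per-segment feature matrix X, labels Y and segment ids (FID)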
+    X = np.load(base + '_stats_allsegment_X{}.npy'.format(i))
+    Y = np.load(base + '_stats_allsegment_Y{}.npy'.format(i))
+    FID = np.load(base + '_stats_allsegment_FID{}.npy'.format(i))
+
+    #X_train = np.load(base + '_X_train.npy')
+    #Y_train = np.load(base + '_Y_train.npy')
+    #X_test = np.load(base + '_X_test.npy')
+    #Y_test = np.load(base + '_Y_test.npy')
+
+    dict_mat = {}
+    dict_mat['nb_of_shp'] = 1
+    dict_mat['gml'] = False
+
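+    # Either select the SVM by cross-validation on a fresh train/test split,
+    # or reuse the split and estimator passed through kwargs.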
+    if kwargs.get('cross_validation'):
+        if save:
+            base = kwargs.get('base')
+            base += '_cross_validation.csv'
+        else:
+            base = None
+        X_samp = kwargs.get('X_samp')
+        Y_samp = kwargs.get('Y_samp')
+        FID_samp = kwargs.get('FID_samp')
+        temp_result = train_test_split(X_samp, Y_samp, FID_samp,
+                                       train_size=train_size)
+        (X_train, X_test, Y_train, Y_test,
+             FID_train, FID_test) = temp_result
+        scaler = StandardScaler().fit(X_train)
+        cross_validation_dict = kwargs.get('cross_validation_dict')
+        clf = cla.select_best_svm_classifier_sample(scaler.transform(X_train),
+                                                    Y_train,
+                                                    output_result_name=base,
+                                                    **cross_validation_dict)
+        estimator = clf.best_estimator_
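+        # Save the train/test split so it can be reloaded later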
+        if save:
+            base = kwargs.get('base')
+
+            X_train_name = base + '_X_train.npy'
+            X_test_name = base + '_X_test.npy'
+            Y_train_name = base + '_Y_train.npy'
+            Y_test_name = base + '_Y_test.npy'
+            FID_train_name = base + '_FID_train.npy'
+            FID_test_name = base + '_FID_test.npy'
+
+            np.save(X_train_name, X_train)
+            np.save(X_test_name, X_test)
+            np.save(Y_train_name, Y_train)
+            np.save(Y_test_name, Y_test)
+            np.save(FID_train_name, FID_train)
+            np.save(FID_test_name, FID_test)
+    else:
+        X_train = kwargs.get('X_train')
+        Y_train = kwargs.get('Y_train')
+        X_test = kwargs.get('X_test')
+        Y_test = kwargs.get('Y_test')
+        scaler = StandardScaler().fit(X_train)
+        estimator = kwargs.get('estimator', None)
+
+    dict_mat['X'], dict_mat['Y'], dict_mat['FID'] = scaler.transform(X), Y, FID
+    dict_mat['X_train'], dict_mat['X_test'] = (scaler.transform(X_train),
+                                               scaler.transform(X_test))
+    dict_mat['Y_train'], dict_mat['Y_test'] = Y_train, Y_test
+
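+    # Fit the classifier on the scaled training set, then run it on the test set and on all segments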
+    temp_result = cla.classifier(scaler.transform(X_train), Y_train,
+                                 scaler.transform(X_test),
+                                 Y_test, scaler.transform(X), dict_mat,
+                                 estimator=estimator, n_jobs=n_jobs)
+    Y_predict, classif, dict_mat, classif_parameters = temp_result
+
+    print 'Control : classifier done'
+    print classif_parameters
+
+    # Get some quality indices
+    cm, report, accuracy = cla.pred_error_metrics(Y_predict, Y_test,
+                                                  target_names=target_names)
+
+    kappa = cohen_kappa_score(Y_predict, Y_test)
+
+    dict_output = {'cm': cm,
+                   'kappa': kappa,
+                   'accuracy': accuracy,
+                   'report': report,
+                   'Y_predict': Y_predict,
+                   'Y_true': Y_test,
+                   'classif': classif,
+                   'FID': FID,
+                   'estimator': str(classif_parameters)}
+
+    # dict_output.update(dict_mat)
+    dict_output = dict(dict_output, **dict_mat)
+
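+    # Optionally write the estimator description and the full result dictionary to disk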
+    if save:
+        base = kwargs.get('base')
+        try:
+            with open(base + '_classifier.txt', 'w') as fichier:
+                # Keep a text record of the selected estimator and its parameters
+                fichier.write(dict_output['estimator'])
+        except IOError:
+            print 'classifier not saved'
+        save_output(dict_output, base + '_classif')
+
+    return dict_output
+
+
 def my_recipe(segmentation, roi_file, raster,
               **kwargs):
     """
@@ -591,7 +725,6 @@ def classification_from_A_to_Z(segmentation, roi_file, raster,
                                    dict_mat['FID_{}'.format(i+1)],
                                    field_pred=field_pred, field_id=field_id)
         print 'Control : vector updated'
-
     # Get some quality indices
     cm, report, accuracy = cla.pred_error_metrics(Y_predict, Y_test,
                                                   target_names=target_names)
-- 
GitLab