From d3a83260c8f512aa43acd1d46abde0cca4b7d6ab Mon Sep 17 00:00:00 2001
From: mlang <marc.lang@teledetection.fr>
Date: Tue, 5 Jun 2018 14:28:18 +0200
Subject: [PATCH] in add_to_predict_shp, change prediction field type to float

---
 ymraster/classification.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/ymraster/classification.py b/ymraster/classification.py
index 0317611..dd27cbb 100755
--- a/ymraster/classification.py
+++ b/ymraster/classification.py
@@ -389,6 +389,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN',
 
     #get params
     rst = Raster(raster)
+    nodata = rst.nodata_value
     n_band = rst.count
     del rst
     init = True
@@ -400,7 +401,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN',
         #get stats with rasterstats module
         my_stats  = zonal_stats(vector,raster, stats = stats,
                                 band = idx_band + 1, geojson_out = True,
-                                nodata = None)
+                                nodata = nodata)
 
         #convert it to an X an Y array
         #-----------------------------
@@ -432,6 +433,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN',
         t = sp.where(count != 1)
         for elem in t[0]:
             print 'Warning : Label {} appears {} times'.format(labels[elem], count[elem])
+
     return X, Y, FID
 
 
@@ -462,19 +464,27 @@ def add_predict_to_shp(vector, Y_predict, FID, **kwargs):
     field_names = [layerDefinition.GetFieldDefn(i).GetName() for i in range(layerDefinition.GetFieldCount())]
 
     # Add a new field
-    new_field = ogr.FieldDefn(field_pred, ogr.OFTInteger)
+    new_field = ogr.FieldDefn(field_pred, ogr.OFTReal)
     layer.CreateField(new_field)
-
+    count_error = 0
     for feat in layer :
         fid = feat.GetField(field_id)
         t = sp.where(FID == fid)[0]
         pred = Y_predict[t]
+        if isinstance(pred, sp.ndarray):
+            if pred.shape[0] == 0:
+                count_error += 1
+                continue
+            else:
+                pred = pred[0]
         try :
             feat.SetField(field_pred, int(pred))
         except TypeError :
             print 'Warning type error : ', pred
         layer.SetFeature(feat)
     source.Destroy()
+    print('{} prediction could not have been assign to' +
+          ' a shape polygon'.format(count_error))
 
 def extract_sample(X,Y,FID):
     """
-- 
GitLab