From d3a83260c8f512aa43acd1d46abde0cca4b7d6ab Mon Sep 17 00:00:00 2001 From: mlang <marc.lang@teledetection.fr> Date: Tue, 5 Jun 2018 14:28:18 +0200 Subject: [PATCH] in add_to_predict_shp, change prediction field type to float --- ymraster/classification.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/ymraster/classification.py b/ymraster/classification.py index 0317611..dd27cbb 100755 --- a/ymraster/classification.py +++ b/ymraster/classification.py @@ -389,6 +389,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN', #get params rst = Raster(raster) + nodata = rst.nodata_value n_band = rst.count del rst init = True @@ -400,7 +401,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN', #get stats with rasterstats module my_stats = zonal_stats(vector,raster, stats = stats, band = idx_band + 1, geojson_out = True, - nodata = None) + nodata = nodata) #convert it to an X an Y array #----------------------------- @@ -432,6 +433,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN', t = sp.where(count != 1) for elem in t[0]: print 'Warning : Label {} appears {} times'.format(labels[elem], count[elem]) + return X, Y, FID @@ -462,19 +464,27 @@ def add_predict_to_shp(vector, Y_predict, FID, **kwargs): field_names = [layerDefinition.GetFieldDefn(i).GetName() for i in range(layerDefinition.GetFieldCount())] # Add a new field - new_field = ogr.FieldDefn(field_pred, ogr.OFTInteger) + new_field = ogr.FieldDefn(field_pred, ogr.OFTReal) layer.CreateField(new_field) - + count_error = 0 for feat in layer : fid = feat.GetField(field_id) t = sp.where(FID == fid)[0] pred = Y_predict[t] + if isinstance(pred, sp.ndarray): + if pred.shape[0] == 0: + count_error += 1 + continue + else: + pred = pred[0] try : feat.SetField(field_pred, int(pred)) except TypeError : print 'Warning type error : ', pred layer.SetFeature(feat) source.Destroy() + print('{} prediction could not have been assign to' + + ' a shape polygon'.format(count_error)) def extract_sample(X,Y,FID): """ -- GitLab