Commit d3a83260 authored by mlang's avatar mlang
Browse files

in add_to_predict_shp, change prediction field type to float

parent 73e1faad
No related merge requests found
Showing with 13 additions and 3 deletions
+13 -3
...@@ -389,6 +389,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN', ...@@ -389,6 +389,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN',
#get params #get params
rst = Raster(raster) rst = Raster(raster)
nodata = rst.nodata_value
n_band = rst.count n_band = rst.count
del rst del rst
init = True init = True
...@@ -400,7 +401,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN', ...@@ -400,7 +401,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN',
#get stats with rasterstats module #get stats with rasterstats module
my_stats = zonal_stats(vector,raster, stats = stats, my_stats = zonal_stats(vector,raster, stats = stats,
band = idx_band + 1, geojson_out = True, band = idx_band + 1, geojson_out = True,
nodata = None) nodata = nodata)
#convert it to an X an Y array #convert it to an X an Y array
#----------------------------- #-----------------------------
...@@ -432,6 +433,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN', ...@@ -432,6 +433,7 @@ def get_stats_from_shape(vector, raster, stats, nodata = None, field_id = 'DN',
t = sp.where(count != 1) t = sp.where(count != 1)
for elem in t[0]: for elem in t[0]:
print 'Warning : Label {} appears {} times'.format(labels[elem], count[elem]) print 'Warning : Label {} appears {} times'.format(labels[elem], count[elem])
return X, Y, FID return X, Y, FID
...@@ -462,19 +464,27 @@ def add_predict_to_shp(vector, Y_predict, FID, **kwargs): ...@@ -462,19 +464,27 @@ def add_predict_to_shp(vector, Y_predict, FID, **kwargs):
field_names = [layerDefinition.GetFieldDefn(i).GetName() for i in range(layerDefinition.GetFieldCount())] field_names = [layerDefinition.GetFieldDefn(i).GetName() for i in range(layerDefinition.GetFieldCount())]
# Add a new field # Add a new field
new_field = ogr.FieldDefn(field_pred, ogr.OFTInteger) new_field = ogr.FieldDefn(field_pred, ogr.OFTReal)
layer.CreateField(new_field) layer.CreateField(new_field)
count_error = 0
for feat in layer : for feat in layer :
fid = feat.GetField(field_id) fid = feat.GetField(field_id)
t = sp.where(FID == fid)[0] t = sp.where(FID == fid)[0]
pred = Y_predict[t] pred = Y_predict[t]
if isinstance(pred, sp.ndarray):
if pred.shape[0] == 0:
count_error += 1
continue
else:
pred = pred[0]
try : try :
feat.SetField(field_pred, int(pred)) feat.SetField(field_pred, int(pred))
except TypeError : except TypeError :
print 'Warning type error : ', pred print 'Warning type error : ', pred
layer.SetFeature(feat) layer.SetFeature(feat)
source.Destroy() source.Destroy()
print('{} prediction could not have been assign to' +
' a shape polygon'.format(count_error))
def extract_sample(X,Y,FID): def extract_sample(X,Y,FID):
""" """
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment