Commit e48c4b5b authored by Gaetano Raffaele's avatar Gaetano Raffaele

ENH: parallel classification now possible

parent 89f5003c
......@@ -14,7 +14,7 @@ import csv
from computeFeatures import featureComputation, readConfigFile
from segmentationWorkflow import segmentationWorkflow, generateGTSamples, generateVALSamples
from classificationWorkflow import training, classify
from mtdUtils import checkSRS, getRasterInfo, getFieldNames
from mtdUtils import checkSRS, getRasterInfo, getFieldNames, keepFields
from validationFramework import pixelValidation,surfaceValidation,formatValidationTxt
import time
......@@ -652,6 +652,8 @@ def main(argv):
classifier = 'libsvm'
stat_file = model_fld + '/GT_stats.xml'
shpd = ogr.GetDriverByName('ESRI Shapefile')
if input_runlevel < 7:
for cfield in cfieldlist:
......@@ -664,9 +666,12 @@ def main(argv):
ref_shp = config.get('GENERAL CONFIGURATION', 'validation')
val_mode = int(config.get('GENERAL CONFIGURATION', 'validmode'))
val_list = [val_fld + '/VAL_samples.shp']
val_out_tmp = final_fld + '/VAL_samples_' + cfield + '_tmp.shp'
val_out = final_fld + '/VAL_samples_' + cfield + '.shp'
val_out_check = final_fld + '/VAL_samples_' + cfield + '_check.shp'
classify(val_list, 'p'+cfield, stat_file, model_file, val_out, var_list)
classify(val_list, 'p'+cfield, stat_file, model_file, val_out_tmp, var_list)
keepFields(val_out_tmp,val_out,['Segment_ID','p'+cfield])
shpd.DeleteDataSource(val_out_tmp)
txt_out = final_fld + '/VAL_samples.' + cfield + '.report.txt'
if val_mode == 0:
pixelValidation(ref_shp, val_out, seg_output, txt_out, cfield)
......@@ -684,7 +689,10 @@ def main(argv):
ref_list = []
for cshp in shp_list:
map_out = final_fld + '/MAPS/VECTOR_' + cfield + '/' + os.path.basename(cshp).replace('.shp','_' + cfield + '.shp')
classify([cshp],cfield,stat_file,model_file,map_out,var_list)
map_out_tmp = final_fld + '/MAPS/VECTOR_' + cfield + '/' + os.path.basename(cshp).replace('.shp','_' + cfield + '_tmp.shp')
classify([cshp],cfield,stat_file,model_file,map_out_tmp,var_list)
keepFields(map_out_tmp,map_out,['Segment_ID','p'+cfield])
shpd.DeleteDataSource(map_out_tmp)
map_list.append(map_out)
ref_list.append(cshp.replace('.shp','.tif'))
......
......@@ -782,3 +782,29 @@ def getFieldNames(shp):
schema.append(fdefn.name)
return schema
def keepFields(src_shp,out_shp,except_list):
shpd = ogr.GetDriverByName('ESRI Shapefile')
dst = shpd.CreateDataSource(out_shp)
src_ds = ogr.Open(src_shp,0)
ly = src_ds.GetLayer()
dst_ly = dst.CreateLayer(os.path.splitext(os.path.basename(out_shp))[0],
srs=ly.GetSpatialRef(),
geom_type=ly.GetLayerDefn().GetGeomType())
ldef = ly.GetLayerDefn()
toAdd = []
for i in range(ldef.GetFieldCount()):
if ldef.GetFieldDefn(i).name in except_list:
toAdd.append(i)
dst_ly.CreateField(ldef.GetFieldDefn(i))
for f in ly:
dstf = ogr.Feature(dst_ly.GetLayerDefn())
dstf.SetGeometry(f.GetGeometryRef())
for i in range(len(toAdd)):
dstf.SetField(i,f.GetField(toAdd[i]))
dst_ly.CreateFeature(dstf)
src_ds = None
dst = None
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment