Commit eec971ed authored by Gaetano Raffaele's avatar Gaetano Raffaele

ENH: Made classification step parallel (using Nproc)

parent 647dad88
......@@ -122,7 +122,7 @@ def training(shp,code,model_fld,params,feat,feat_mode = 'list'):
return flds
def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list',Nproc=1):
def classify(shp_list,code,stat_file,model_file,out_fld,out_ext,feat,feat_mode = 'list',Nproc=1):
# Platform dependent parameters
if platform.system() == 'Linux':
......@@ -139,15 +139,45 @@ def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list'
else:
sys.exit('ERROR: mode ' + feat_mode + ' not valid.')
'''
for shp in shp_list:
#roughFix(shp, flds)
if platform.system() == 'Linux':
cmd = ['otbcli_VectorClassifier','-in',shp,'-instat',stat_file,'-model',model_file,'-out',out_file,'-cfield',code,'-feat'] + flds
subprocess.call(cmd,shell=sh)
elif platform.system() == 'Windows':
import otbApplication
app = otbApplication.Registry.CreateApplication('VectorClassifier')
app.SetParameterString('in',shp)
app.SetParameterString('instat', stat_file)
app.SetParameterString('model', model_file)
app.SetParameterString('out', out_file)
app.SetParameterString('cfield', code)
app.UpdateParameters()
app.SetParameterStringList('feat',flds)
app.UpdateParameters()
app.ExecuteAndWriteOutput()
else:
sys.exit('Platform not supported!')
return
'''
if platform.system() == 'Linux':
cmd_list = []
out_file_list = []
for shp in shp_list:
out_file = out_fld + '/' + os.path.basename(shp).replace('.shp', out_ext + '.shp')
cmd = ['otbcli_VectorClassifier', '-in', shp, '-instat', stat_file, '-model', model_file, '-out', out_file,
'-cfield', code, '-feat'] + flds
cmd_list.append(cmd)
out_file_list.append(out_file)
queuedProcess(cmd_list,Nproc,shell=sh)
return out_file_list
elif platform.system() == 'Windows':
out_file_list = []
for shp in shp_list:
out_file = out_fld + '/' + os.path.basename(shp).replace('.shp', out_ext + '.shp')
import otbApplication
app = otbApplication.Registry.CreateApplication('VectorClassifier')
app.SetParameterString('in', shp)
......@@ -159,29 +189,8 @@ def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list'
app.SetParameterStringList('feat', flds)
app.UpdateParameters()
app.ExecuteAndWriteOutput()
out_file_list.append(out_file)
return out_file_list
else:
sys.exit('Platform not supported!')
'''
for shp in shp_list:
#roughFix(shp, flds)
if platform.system() == 'Linux':
cmd = ['otbcli_VectorClassifier','-in',shp,'-instat',stat_file,'-model',model_file,'-out',out_file,'-cfield',code,'-feat'] + flds
subprocess.call(cmd,shell=sh)
elif platform.system() == 'Windows':
import otbApplication
app = otbApplication.Registry.CreateApplication('VectorClassifier')
app.SetParameterString('in',shp)
app.SetParameterString('instat', stat_file)
app.SetParameterString('model', model_file)
app.SetParameterString('out', out_file)
app.SetParameterString('cfield', code)
app.UpdateParameters()
app.SetParameterStringList('feat',flds)
app.UpdateParameters()
app.ExecuteAndWriteOutput()
else:
sys.exit('Platform not supported!')
'''
return
......@@ -648,12 +648,11 @@ def main(argv):
ref_shp = config.get('GENERAL CONFIGURATION', 'validation')
val_mode = int(config.get('GENERAL CONFIGURATION', 'validmode'))
val_list = [val_fld + '/VAL_samples.shp']
val_out_tmp = final_fld + '/VAL_samples_' + cfield + '_tmp.shp'
val_out = final_fld + '/VAL_samples_' + cfield + '.shp'
val_out_check = final_fld + '/VAL_samples_' + cfield + '_check.shp'
classify(val_list, 'p'+cfield, stat_file, model_file, val_out_tmp, var_list)
keepFields(val_out_tmp,val_out,['Segment_ID',cfield,'p'+cfield])
shpd.DeleteDataSource(val_out_tmp)
val_out_tmp_list = classify(val_list, 'p'+cfield, stat_file, model_file, final_fld, '_' + cfield + '_tmp', var_list)
keepFields(val_out_tmp_list[0],val_out,['Segment_ID',cfield,'p'+cfield])
shpd.DeleteDataSource(val_out_tmp_list[0])
txt_out = final_fld + '/VAL_samples.' + cfield + '.report.txt'
if val_mode == 0:
pixelValidation(ref_shp, val_out, seg_output, txt_out, cfield)
......@@ -669,23 +668,28 @@ def main(argv):
shp_list = glob.glob(test_fld + '/segmentation_*.shp')
map_list = []
ref_list = []
for cshp in shp_list:
map_out = final_fld + '/MAPS/VECTOR_' + cfield + '/' + os.path.basename(cshp).replace('.shp','_' + cfield + '.shp')
map_out_tmp = final_fld + '/MAPS/VECTOR_' + cfield + '/' + os.path.basename(cshp).replace('.shp','_' + cfield + '_tmp.shp')
classify([cshp],'p'+cfield,stat_file,model_file,map_out_tmp,var_list,Nproc=N_proc)
keepFields(map_out_tmp,map_out,['Segment_ID','p'+cfield])
shpd.DeleteDataSource(map_out_tmp)
ref_list.append(cshp.replace('.shp', '.tif'))
map_tmp_list = classify(shp_list,'p'+cfield,stat_file,model_file,final_fld + '/MAPS/VECTOR_' + cfield + '/','_' + cfield + '_tmp',var_list,Nproc=N_proc)
for cshp in map_tmp_list:
map_out = cshp.replace('_tmp.shp','.shp')
keepFields(cshp,map_out,['Segment_ID','p'+cfield])
shpd.DeleteDataSource(cshp)
map_list.append(map_out)
ref_list.append(cshp.replace('.shp','.tif'))
if rasout == 'VRT':
if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield):
os.mkdir(final_fld + '/MAPS/RASTER_' + cfield)
ras_list = []
cmd_list = []
for map,ref in zip(map_list,ref_list):
ras_list.append(final_fld + '/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif'))
cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]]
subprocess.call(cmd,shell=sh)
cmd_list.append(cmd)
queuedProcess(cmd_list,N_processes=N_proc,shell=sh)
cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list
subprocess.call(cmd, shell=sh)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment