Commit 6814b207 authored by Gaetano Raffaele's avatar Gaetano Raffaele

ENH: Made classification step parallel (using Nproc)

parent 460c11f9
......@@ -4,6 +4,8 @@ import subprocess
import platform
import numpy as np
from mtdUtils import queuedProcess
def getFeaturesFields(shp,flds_pref):
ds = ogr.Open(shp, 0)
......@@ -120,7 +122,7 @@ def training(shp,code,model_fld,params,feat,feat_mode = 'list'):
return flds
def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list'):
def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list',Nproc=1):
# Platform dependent parameters
if platform.system() == 'Linux':
......@@ -137,6 +139,30 @@ def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list'
else:
sys.exit('ERROR: mode ' + feat_mode + ' not valid.')
if platform.system() == 'Linux':
cmd_list = []
for shp in shp_list:
cmd = ['otbcli_VectorClassifier', '-in', shp, '-instat', stat_file, '-model', model_file, '-out', out_file,
'-cfield', code, '-feat'] + flds
cmd_list.append(cmd)
queuedProcess(cmd_list,Nproc,shell=sh)
elif platform.system() == 'Windows':
for shp in shp_list:
import otbApplication
app = otbApplication.Registry.CreateApplication('VectorClassifier')
app.SetParameterString('in', shp)
app.SetParameterString('instat', stat_file)
app.SetParameterString('model', model_file)
app.SetParameterString('out', out_file)
app.SetParameterString('cfield', code)
app.UpdateParameters()
app.SetParameterStringList('feat', flds)
app.UpdateParameters()
app.ExecuteAndWriteOutput()
else:
sys.exit('Platform not supported!')
'''
for shp in shp_list:
#roughFix(shp, flds)
if platform.system() == 'Linux':
......@@ -156,5 +182,6 @@ def classify(shp_list,code,stat_file,model_file,out_file,feat,feat_mode = 'list'
app.ExecuteAndWriteOutput()
else:
sys.exit('Platform not supported!')
'''
return
......@@ -14,28 +14,10 @@ import csv
from computeFeatures import featureComputation, readConfigFile
from segmentationWorkflow import segmentationWorkflow, generateGTSamples, generateVALSamples
from classificationWorkflow import training, classify
from mtdUtils import checkSRS, getRasterInfo, getFieldNames, keepFields
from mtdUtils import checkSRS, getRasterInfo, getFieldNames, keepFields, queuedProcess
from validationFramework import pixelValidation,surfaceValidation,formatValidationTxt
import time
def queuedProcess(cmd_list,N_processes=4,shell=False,delay=0):
cmd_queue = cmd_list
prc_queue = []
for t in range(N_processes):
prc_queue.append(subprocess.Popen(cmd_queue.pop(0), shell=shell))
time.sleep(delay)
while len(prc_queue) > 0:
for i in range(len(prc_queue)):
if prc_queue[i].poll() is not None:
prc_queue.pop(i)
if len(cmd_queue) > 0:
prc_queue.append(subprocess.Popen(cmd_queue.pop(0), shell=shell))
time.sleep(delay)
break
def main(argv):
try:
opts, args = getopt.getopt(argv, '', ['runlevel=', 'single-step'])
......
......@@ -8,6 +8,7 @@ import string
import subprocess
import sys
import uuid
import time
import gdal
import numpy as np
......@@ -808,3 +809,21 @@ def keepFields(src_shp,out_shp,except_list):
src_ds = None
dst = None
def queuedProcess(cmd_list,N_processes=4,shell=False,delay=0):
cmd_queue = cmd_list
prc_queue = []
for t in range(N_processes):
prc_queue.append(subprocess.Popen(cmd_queue.pop(0), shell=shell))
time.sleep(delay)
while len(prc_queue) > 0:
for i in range(len(prc_queue)):
if prc_queue[i].poll() is not None:
prc_queue.pop(i)
if len(cmd_queue) > 0:
prc_queue.append(subprocess.Popen(cmd_queue.pop(0), shell=shell))
time.sleep(delay)
break
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment