Commit bf38621e authored by SPeillet's avatar SPeillet
Browse files

DOC: add documentation

Showing with 64 additions and 2 deletions
+64 -2
......@@ -348,7 +348,31 @@ def retrieveClassHierarchy(shp,code_list):
return h_dict,ds_dict.values()
def Htraining(shp,code_list,model_fld,params,feat,feat_mode = 'list'):
'''
Hierarchical classification
---------------------------
Input
-----
shp: str
path of the training shapefile
code_list: list
list of the fields names for the different classification levels
model_fld: str
path to the folder where model files will be saved
params: list
list of parameter model to use with otbVectorTraining application
feat: list or str
features to use for classification
feat_mode: list/prefix
selection mode of the features
----------------------------------------------
Output
------
stat_file: str
path to a stat_file to use to unskew model
h_model_fld: str
path to the folder where model files are
'''
# Platform dependent parameters
if platform.system() == 'Linux':
sh = False
......@@ -373,6 +397,7 @@ def Htraining(shp,code_list,model_fld,params,feat,feat_mode = 'list'):
else:
sys.exit('ERROR: mode ' + feat_mode + ' not valid.')
# Statistics unskew
if platform.system() == 'Linux':
cmd = ['otbcli_ComputeVectorFeaturesStatistics','-io.vd',shp,'-io.stats',stat_file,'-feat'] + flds
subprocess.call(cmd,shell=sh)
......@@ -387,7 +412,9 @@ def Htraining(shp,code_list,model_fld,params,feat,feat_mode = 'list'):
else:
sys.exit("Platform not supported!")
# Compute the class hierarchy as a dictionary
h_dict, ds_list = retrieveClassHierarchy(shp,code_list)
# Create models for each level
with open(h_model_fld + '/h-model.csv',mode='wb') as h_model_file:
writer = csv.writer(h_model_file)
level = 'ROOT'
......@@ -408,7 +435,31 @@ def Htraining(shp,code_list,model_fld,params,feat,feat_mode = 'list'):
return stat_file, h_model_fld
def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'list'):
'''
Hierarchical classification
---------------------------
Input
-----
shp_list: list
list of shapefile path to classify
stat_file: str
path to a stat_file to use to unskew model
h_model_fld: str
path to the folder where model files are
feat: list or str
features to use for classification
out_fld: str
path to the output folder
out_ext: str
output suffix
feat_mode: list/prefix
selection mode of the features
----------------------------------------------
Output
------
out_file_list : list
list of the output shapefile paths
'''
if feat_mode == 'prefix':
flds = getFeaturesFields(shp_list[0], feat)
elif feat_mode == 'list':
......@@ -438,6 +489,8 @@ def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'l
if os.path.exists(in_shp):
out_shp = in_shp.replace('.shp','_ROOT.shp') if row[0] == 'ROOT' else None
split_shp = out_shp if row[0] == 'ROOT' else in_shp
# Training a model on a single class raises an error at classification time;
# in that case, assign the class value directly instead of classifying
with open(row[2]) as model:
lines = model.readlines()
to_classify = int(lines[1].split(' ')[0]) != 1
......@@ -447,12 +500,16 @@ def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'l
addField(in_shp,'p'+row[3],int(lines[1].split(' ')[1]))
if out_shp is not None:
toDelete.append(out_shp)
#if there is a more detailed level than the current one, split the current
#classification by classes to use as bases for next level
if row[4] == 'True':
ds_dict = splitShapefileByClasses(split_shp,'p'+row[3])
[toProcess.insert(0,x) for x in ds_dict.values()]
toDelete.extend(ds_dict.values())
#last level of hierarchy classes, to merge
elif row[4] == 'False':
toMerge.append(in_shp)
# Merge all resulting shapefiles, or rename if there is only one
if len(toMerge) > 1 :
out_file = out_fld + '/' + os.path.basename(shp).replace('.shp', out_ext + '.shp')
mergeShapefiles(toMerge,out_file)
......@@ -462,6 +519,7 @@ def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'l
os.rename(toMerge[0],out_file)
out_file_list.append(out_file)
drv = ogr.GetDriverByName('ESRI Shapefile')
# Drop tmp files
for fn in toDelete:
drv.DeleteDataSource(fn)
......
......@@ -824,6 +824,7 @@ def main(argv):
warnings.warn('Error: Model file ' + model_file + ' not found. Skipping.')
continue
# Validation step
if ch_mode == 0 or ch_mode == 2:
ref_shp = config.get('GENERAL CONFIGURATION', 'validation')
val_mode = int(config.get('GENERAL CONFIGURATION', 'validmode'))
......@@ -840,6 +841,7 @@ def main(argv):
classes,cm,acc,kappa,prf = surfaceValidation(ref_shp, val_out, val_out_check, cfield)
formatValidationTxt(classes, cm, acc, kappa, prf, txt_out)
# Classification and map production steps
shp_list = glob.glob(test_fld + '/segmentation_*.shp')
if ch_mode > 0 or (ch_mode < 0 and len(shp_list) > 0):
if not os.path.exists(final_fld + '/MAPS'):
......@@ -880,6 +882,7 @@ def main(argv):
if len(glob.glob(os.path.join(h_model_fld,'*.model'))) == 0 :
sys.exit('Error: There is no model file in ' + h_model_fld + ' folder. Skipping.')
# Validation step
if ch_mode == 0 or ch_mode == 2:
ref_shp = config.get('GENERAL CONFIGURATION', 'validation')
val_mode = int(config.get('GENERAL CONFIGURATION', 'validmode'))
......@@ -897,6 +900,7 @@ def main(argv):
classes,cm,acc,kappa,prf = surfaceValidation(ref_shp, val_out, val_out_check, cfield)
formatValidationTxt(classes, cm, acc, kappa, prf, txt_out)
# Classification and map production steps
shp_list = glob.glob(test_fld + '/segmentation_*.shp')
if ch_mode > 0 or (ch_mode < 0 and len(shp_list) > 0):
if not os.path.exists(final_fld + '/MAPS'):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment