diff --git a/classificationWorkflow.py b/classificationWorkflow.py index f7983a61905e90a232f899fd02f409da1f31fd24..461558cdb87fe15ea6cfdda86836ff6ae45579d7 100644 --- a/classificationWorkflow.py +++ b/classificationWorkflow.py @@ -66,7 +66,6 @@ def baseTrainingCmd(shp,stat_file,code,flds,params,model_file,confmat_file): sh = True else: sys.exit("Platform not supported!") - if type(shp) == str: shp = [shp] @@ -110,6 +109,7 @@ def baseClassifyCmd(shp,stat_file,model_file,code,flds,out_file,compute_confiden cmd += ['-out', out_file] if compute_confidence: cmd += ['-confmap', 1] + subprocess.call(cmd, shell=sh) elif platform.system() == 'Windows': import otbApplication app = otbApplication.Registry.CreateApplication('VectorClassifier') @@ -270,6 +270,40 @@ def classify(shp_list,code,stat_file,model_file,out_fld,out_ext,feat,feat_mode = else: sys.exit('Platform not supported!') +def addField(filein, nameField, valueField, valueType=None, + driver_name="ESRI Shapefile", fWidth=None): + + + driver = ogr.GetDriverByName(driver_name) + source = driver.Open(filein, 1) + layer = source.GetLayer() + layer_name = layer.GetName() + layer_defn = layer.GetLayerDefn() + field_names = [layer_defn.GetFieldDefn(i).GetName() for i in range(layer_defn.GetFieldCount())] + if not valueType: + try : + int(valueField) + new_field1 = ogr.FieldDefn(nameField, ogr.OFTInteger) + except : + new_field1 = ogr.FieldDefn(nameField, ogr.OFTString) + elif valueType == str: + new_field1 = ogr.FieldDefn(nameField, ogr.OFTString) + sqlite_type = 'varchar' + elif valueType == int: + new_field1 = ogr.FieldDefn(nameField, ogr.OFTInteger) + sqlite_type = 'int' + elif valueType == float: + new_field1 = ogr.FieldDefn(nameField, ogr.OFTReal) + sqlite_type = 'float' + if fWidth: + new_field1.SetWidth(fWidth) + + layer.CreateField(new_field1) + for feat in layer: + layer.SetFeature(feat) + feat.SetField(nameField, valueField) + layer.SetFeature(feat) + def splitShapefileByClasses(shp,code): ds = ogr.Open(shp) ly = 
ds.GetLayer(0) @@ -385,6 +419,7 @@ def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'l if not os.path.exists(h_model_fld): sys.exit('Folder ' + h_model_fld + ' not exists!') + out_file_list = [] for shp in shp_list: with open(h_model_fld + '/h-model.csv', mode='rb') as h_model_file: toProcess = [] @@ -394,14 +429,22 @@ def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'l for row in rdr: if row[0] == 'ROOT': in_shp = shp - elif len(toProcess) > 0 and toProcess[-1].endswith('_p' + row[0] + '_' + row[1] + '.shp'): - in_shp = toProcess.pop() + elif len(toProcess) > 0: + for i,process in enumerate(toProcess): + if process.endswith('_p' + row[0] + '_' + row[1] + '.shp') : + in_shp = toProcess.pop(i) else: continue if os.path.exists(in_shp): out_shp = in_shp.replace('.shp','_ROOT.shp') if row[0] == 'ROOT' else None split_shp = out_shp if row[0] == 'ROOT' else in_shp - baseClassifyCmd(in_shp,stat_file,row[2],'p'+row[3],flds,out_shp) + with open(row[2]) as model: + lines = model.readlines() + to_classify = int(lines[1].split(' ')[0]) != 1 + if to_classify : + baseClassifyCmd(in_shp,stat_file,row[2],'p'+row[3],flds,out_shp) + else : + addField(in_shp,'p'+row[3],int(lines[1].split(' ')[1])) if out_shp is not None: toDelete.append(out_shp) if row[4] == 'True': @@ -410,11 +453,16 @@ def Hclassify(shp_list,stat_file,h_model_fld,feat,out_fld,out_ext,feat_mode = 'l toDelete.extend(ds_dict.values()) elif row[4] == 'False': toMerge.append(in_shp) - - out_file = out_fld + '/' + os.path.basename(shp).replace('.shp', out_ext + '.shp') - mergeShapefiles(toMerge,out_file) + if len(toMerge) > 1 : + out_file = out_fld + '/' + os.path.basename(shp).replace('.shp', out_ext + '.shp') + mergeShapefiles(toMerge,out_file) + out_file_list.append(out_file) + elif len(toMerge) == 1 : + out_file = out_fld + '/' + os.path.basename(shp).replace('.shp', out_ext + '.shp') + os.rename(toMerge[0],out_file) + out_file_list.append(out_file) drv = 
ogr.GetDriverByName('ESRI Shapefile') for fn in toDelete: - drv.DeleteDataSource(fn) + drv.DeleteDataSource(fn) - return \ No newline at end of file + return out_file_list \ No newline at end of file diff --git a/launchChain.py b/launchChain.py index 3c1add5f1ddc9f6fa59228378924ff421321bf56..bf6da9cafa5a88a5b69d97f827c9bb616606ef80 100644 --- a/launchChain.py +++ b/launchChain.py @@ -670,7 +670,7 @@ def main(argv): rfimpnodesize = int(config.get('TRAINING CONFIGURATION', 'rfimpnodesize')) rfimpmaxfeat = int(config.get('TRAINING CONFIGURATION', 'rfimpmaxfeat')) rfimpnruns = int(config.get('TRAINING CONFIGURATION', 'rfimpnruns')) - + hierarchical_classif = config.getboolean('TRAINING CONFIGURATION','hierarchicalclassif') model_fld = output_fld + '/MODEL_' + setup_name if not os.path.exists(model_fld): @@ -684,16 +684,65 @@ def main(argv): var_list.append(row[0]) if input_runlevel < 6: + if hierarchical_classif is False: + if ch_mode < 0: + + nFolds = -ch_mode + dgt = int(floor(log10(nFolds)) + 1) + + train_folds, test_folds = genKFolds(reference, cfieldlist[-1], nFolds) + kfold_train_samples, kfold_test_samples = kFoldRefToSamples(samples_fld + '/GT_samples.shp', val_fld + '/VAL_samples.shp', train_folds, test_folds) + + for cfield in cfieldlist: + + accuracies = [] + kappas = [] + fscores = {} + + for i in range(nFolds): - if ch_mode < 0: + mfld = model_fld + '/fold_' + str(i+1).zfill(dgt) - nFolds = -ch_mode - dgt = int(floor(log10(nFolds)) + 1) + if not os.path.exists(mfld): + os.mkdir(mfld) - train_folds, test_folds = genKFolds(reference, cfieldlist[-1], nFolds) - kfold_train_samples, kfold_test_samples = kFoldRefToSamples(samples_fld + '/GT_samples.shp', val_fld + '/VAL_samples.shp', train_folds, test_folds) + sfn, mfn = training(kfold_train_samples[i], cfield, mfld, params, var_list) + val_list = [kfold_test_samples[i]] + val_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '.shp') + val_out_check = mfld + '/' + 
os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_check.shp') + val_out_tmp_list = classify(val_list, 'p' + cfield, sfn, mfn, mfld, + '_' + cfield + '_tmp', var_list) + keepFields(val_out_tmp_list[0], val_out, ['Segment_ID', cfield, 'p' + cfield, 'confidence']) + shpd.DeleteDataSource(val_out_tmp_list[0]) + txt_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_report.txt') + classes, cm, acc, kappa, prf = surfaceValidation(test_folds[i], val_out, val_out_check, cfield) + formatValidationTxt(classes, cm, acc, kappa, prf, txt_out) + + accuracies.append(acc) + kappas.append(kappa) + for c,fs in zip(classes,prf[2]): + if c not in fscores.keys(): + fscores[c] = [] + fscores[c].append(fs) + + kFoldReport(fscores,accuracies,kappas,model_fld + '/kFold_report_' + cfield + '.txt') for cfield in cfieldlist: + training(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list) + + if rfimp == True: + for cfield in cfieldlist: + getVariableImportance(samples_fld + '/GT_samples.shp',var_list,cfield,samples_fld + '/var_importance_' + cfield + '.csv', nbtrees=rfimpntrees, nodesize=rfimpnodesize, mtry=rfimpmaxfeat, nruns=rfimpnruns, field_names=samples_fld + '/field_names.csv') + else : + from classificationWorkflow import Htraining, Hclassify + if ch_mode < 0: + + nFolds = -ch_mode + dgt = int(floor(log10(nFolds)) + 1) + train_folds, test_folds = genKFolds(reference, cfieldlist[-1], nFolds) + kfold_train_samples, kfold_test_samples = kFoldRefToSamples(samples_fld + '/GT_samples.shp', + val_fld + '/VAL_samples.shp', train_folds, + test_folds) accuracies = [] kappas = [] @@ -701,38 +750,47 @@ def main(argv): for i in range(nFolds): - mfld = model_fld + '/fold_' + str(i+1).zfill(dgt) + mfld = model_fld + '/fold_' + str(i + 1).zfill(dgt) if not os.path.exists(mfld): os.mkdir(mfld) - - sfn, mfn = training(kfold_train_samples[i], cfield, mfld, params, var_list) + sfn, mfn = Htraining(kfold_train_samples[i], 
cfieldlist, mfld, params, var_list) val_list = [kfold_test_samples[i]] - val_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '.shp') - val_out_check = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_check.shp') - val_out_tmp_list = classify(val_list, 'p' + cfield, sfn, mfn, mfld, - '_' + cfield + '_tmp', var_list) - keepFields(val_out_tmp_list[0], val_out, ['Segment_ID', cfield, 'p' + cfield, 'confidence']) - shpd.DeleteDataSource(val_out_tmp_list[0]) - txt_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_report.txt') - classes, cm, acc, kappa, prf = surfaceValidation(test_folds[i], val_out, val_out_check, cfield) - formatValidationTxt(classes, cm, acc, kappa, prf, txt_out) - - accuracies.append(acc) - kappas.append(kappa) - for c,fs in zip(classes,prf[2]): - if c not in fscores.keys(): - fscores[c] = [] - fscores[c].append(fs) - - kFoldReport(fscores,accuracies,kappas,model_fld + '/kFold_report_' + cfield + '.txt') - - for cfield in cfieldlist: - training(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list) - - if rfimp == True: - for cfield in cfieldlist: - getVariableImportance(samples_fld + '/GT_samples.shp',var_list,cfield,samples_fld + '/var_importance_' + cfield + '.csv', nbtrees=rfimpntrees, nodesize=rfimpnodesize, mtry=rfimpmaxfeat, nruns=rfimpnruns, field_names=samples_fld + '/field_names.csv') + val_out = mfld + '/' + os.path.basename('_'.join(cfieldlist)) + val_out_check = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp', + '_' + '_'.join(cfieldlist) + '_check.shp') + val_out_tmp_list = Hclassify(val_list, sfn, mfn, var_list, + mfld,'_h_classif') + for cfield in cfieldlist: + val_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp', + '_' + cfield +'.shp') + val_out_check = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp', + '_' + cfield +'_check.shp') + + 
keepFields(val_out_tmp_list[0], val_out, ['Segment_ID', cfield, 'p' + cfield, 'confidence']) + txt_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp', + '_' + cfield + '_report.txt') + + classes, cm, acc, kappa, prf = surfaceValidation(test_folds[i], val_out, val_out_check, cfield) + formatValidationTxt(classes, cm, acc, kappa, prf, txt_out) + + accuracies.append(acc) + kappas.append(kappa) + for c, fs in zip(classes, prf[2]): + if c not in fscores.keys(): + fscores[c] = [] + fscores[c].append(fs) + + kFoldReport(fscores, accuracies, kappas, model_fld + '/kFold_report_' + cfield + '.txt') + + Htraining(samples_fld + '/GT_samples.shp', cfieldlist, model_fld, params, var_list) + + if rfimp == True: + for cfield in cfieldlist: + getVariableImportance(samples_fld + '/GT_samples.shp', var_list, cfield, + samples_fld + '/var_importance_' + cfield + '.csv', nbtrees=rfimpntrees, + nodesize=rfimpnodesize, mtry=rfimpmaxfeat, nruns=rfimpnruns, + field_names=samples_fld + '/field_names.csv') if single_step: sys.exit("Single step mode. Exiting.") @@ -759,63 +817,120 @@ def main(argv): shpd = ogr.GetDriverByName('ESRI Shapefile') if input_runlevel < 7: - - for cfield in cfieldlist: - model_file = model_fld + '/' + classifier + '_' + cfield + '.model' - if not os.path.exists(model_file): - warnings.warn('Error: Model file ' + model_file + ' not found. Skipping.') - continue + if hierarchical_classif is False : + for cfield in cfieldlist: + model_file = model_fld + '/' + classifier + '_' + cfield + '.model' + if not os.path.exists(model_file): + warnings.warn('Error: Model file ' + model_file + ' not found. 
Skipping.') + continue + + if ch_mode == 0 or ch_mode == 2: + ref_shp = config.get('GENERAL CONFIGURATION', 'validation') + val_mode = int(config.get('GENERAL CONFIGURATION', 'validmode')) + val_list = [val_fld + '/VAL_samples.shp'] + val_out = final_fld + '/VAL_samples_' + cfield + '.shp' + val_out_check = final_fld + '/VAL_samples_' + cfield + '_check.shp' + val_out_tmp_list = classify(val_list, 'p'+cfield, stat_file, model_file, final_fld, '_' + cfield + '_tmp', var_list, compute_confidence=comp_conf) + keepFields(val_out_tmp_list[0],val_out,['Segment_ID',cfield,'p'+cfield,'confidence']) + shpd.DeleteDataSource(val_out_tmp_list[0]) + txt_out = final_fld + '/VAL_samples.' + cfield + '.report.txt' + if val_mode == 0: + pixelValidation(ref_shp, val_out, seg_output, txt_out, cfield) + elif val_mode == 1: + classes,cm,acc,kappa,prf = surfaceValidation(ref_shp, val_out, val_out_check, cfield) + formatValidationTxt(classes, cm, acc, kappa, prf, txt_out) + + shp_list = glob.glob(test_fld + '/segmentation_*.shp') + if ch_mode > 0 or (ch_mode < 0 and len(shp_list) > 0): + if not os.path.exists(final_fld + '/MAPS'): + os.mkdir(final_fld + '/MAPS') + if not os.path.exists(final_fld + '/MAPS/VECTOR_' + cfield): + os.mkdir(final_fld + '/MAPS/VECTOR_' + cfield) + map_list = [] + ref_list = [] + + for cshp in shp_list: + ref_list.append(cshp.replace('.shp', '.tif')) + + map_tmp_list = classify(shp_list,'p'+cfield,stat_file,model_file,final_fld + '/MAPS/VECTOR_' + cfield + '/','_' + cfield + '_tmp',var_list,Nproc=N_proc,compute_confidence=comp_conf) + + for cshp in map_tmp_list: + map_out = cshp.replace('_tmp.shp','.shp') + keepFields(cshp,map_out,['Segment_ID','p'+cfield,'confidence']) + shpd.DeleteDataSource(cshp) + map_list.append(map_out) + + if rasout == 'VRT': + if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield): + os.mkdir(final_fld + '/MAPS/RASTER_' + cfield) + ras_list = [] + cmd_list = [] + for map,ref in zip(map_list,ref_list): + ras_list.append(final_fld + 
'/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif')) + cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]] + cmd_list.append(cmd) + queuedProcess(cmd_list,N_processes=N_proc,shell=sh) + + cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list + subprocess.call(cmd, shell=sh) + + else : + from classificationWorkflow import Hclassify + h_model_fld = os.path.join(model_fld,'H-MODEL_' + '_'.join(cfieldlist)) + if len(glob.glob(os.path.join(h_model_fld,'*.model'))) == 0 : + sys.exit('Error: There is no model file in ' + h_model_fld + ' folder. Skipping.') if ch_mode == 0 or ch_mode == 2: ref_shp = config.get('GENERAL CONFIGURATION', 'validation') val_mode = int(config.get('GENERAL CONFIGURATION', 'validmode')) val_list = [val_fld + '/VAL_samples.shp'] - val_out = final_fld + '/VAL_samples_' + cfield + '.shp' - val_out_check = final_fld + '/VAL_samples_' + cfield + '_check.shp' - val_out_tmp_list = classify(val_list, 'p'+cfield, stat_file, model_file, final_fld, '_' + cfield + '_tmp', var_list, compute_confidence=comp_conf) - keepFields(val_out_tmp_list[0],val_out,['Segment_ID',cfield,'p'+cfield,'confidence']) + val_out = final_fld + '/VAL_samples_' + '_'.join(cfieldlist) + '.shp' + val_out_check = final_fld + '/VAL_samples_' + '_'.join(cfieldlist) + '_check.shp' + val_out_tmp_list = Hclassify(val_list,stat_file,h_model_fld,var_list,final_fld,'_h_classif') + keepFields(val_out_tmp_list[0],val_out,['Segment_ID']+cfieldlist+['p'+c for c in cfieldlist]+['confidence']) shpd.DeleteDataSource(val_out_tmp_list[0]) - txt_out = final_fld + '/VAL_samples.' 
+ cfield + '.report.txt' - if val_mode == 0: - pixelValidation(ref_shp, val_out, seg_output, txt_out, cfield) - elif val_mode == 1: - classes,cm,acc,kappa,prf = surfaceValidation(ref_shp, val_out, val_out_check, cfield) - formatValidationTxt(classes, cm, acc, kappa, prf, txt_out) + for cfield in cfieldlist: + txt_out = final_fld + '/VAL_samples.' + cfield + '.report.txt' + if val_mode == 0: + pixelValidation(ref_shp, val_out, seg_output, txt_out, cfield) + elif val_mode == 1: + classes,cm,acc,kappa,prf = surfaceValidation(ref_shp, val_out, val_out_check, cfield) + formatValidationTxt(classes, cm, acc, kappa, prf, txt_out) shp_list = glob.glob(test_fld + '/segmentation_*.shp') if ch_mode > 0 or (ch_mode < 0 and len(shp_list) > 0): if not os.path.exists(final_fld + '/MAPS'): os.mkdir(final_fld + '/MAPS') - if not os.path.exists(final_fld + '/MAPS/VECTOR_' + cfield): - os.mkdir(final_fld + '/MAPS/VECTOR_' + cfield) + if not os.path.exists(final_fld + '/MAPS/VECTOR'): + os.mkdir(final_fld + '/MAPS/VECTOR') map_list = [] ref_list = [] - for cshp in shp_list: ref_list.append(cshp.replace('.shp', '.tif')) - map_tmp_list = classify(shp_list,'p'+cfield,stat_file,model_file,final_fld + '/MAPS/VECTOR_' + cfield + '/','_' + cfield + '_tmp',var_list,Nproc=N_proc,compute_confidence=comp_conf) - + map_tmp_list = Hclassify(shp_list,stat_file,h_model_fld,var_list,os.path.join(final_fld, 'MAPS', 'VECTOR'),'_hclassif_tmp') + class_field = [] + for cfield in cfieldlist: + class_field.append('p'+cfield) for cshp in map_tmp_list: map_out = cshp.replace('_tmp.shp','.shp') - keepFields(cshp,map_out,['Segment_ID','p'+cfield,'confidence']) + keepFields(cshp,map_out,['Segment_ID']+class_field+['confidence']) shpd.DeleteDataSource(cshp) map_list.append(map_out) if rasout == 'VRT': - if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield): - os.mkdir(final_fld + '/MAPS/RASTER_' + cfield) - ras_list = [] - cmd_list = [] - for map,ref in zip(map_list,ref_list): - ras_list.append(final_fld + 
'/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif')) - cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]] - cmd_list.append(cmd) - queuedProcess(cmd_list,N_processes=N_proc,shell=sh) - - cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list - subprocess.call(cmd, shell=sh) - + for cfield in cfieldlist : + if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield): + os.mkdir(final_fld + '/MAPS/RASTER_' + cfield) + ras_list = [] + cmd_list = [] + for map,ref in zip(map_list,ref_list): + ras_list.append(final_fld + '/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif')) + cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]] + cmd_list.append(cmd) + queuedProcess(cmd_list,N_processes=N_proc,shell=sh) + cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list + subprocess.call(cmd, shell=sh) if single_step: sys.exit("Single step mode. Exiting.") diff --git a/validationFramework.py b/validationFramework.py index 55fbf56956af810f39ba600db3993f94960057ed..890de2a5c54a3c53b63dadb93e15a815f421e1f5 100644 --- a/validationFramework.py +++ b/validationFramework.py @@ -100,7 +100,7 @@ def surfaceValidation(ref_shp,val_shp,out,cfield,pfield=None): acc = accuracy_score(y_true, y_pred, sample_weight=y_wght) kappa = cohen_kappa_score(y_true, y_pred, sample_weight=y_wght) prf = precision_recall_fscore_support(y_true, y_pred, sample_weight=y_wght) - classes = sorted(np.unique(y_true)) + classes = sorted(np.unique(y_true+y_pred)) return classes,cm,acc,kappa,prf