Commit e56fd9f3 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: check return value step-wise

parent b13ca93b
No related merge requests found
Showing with 18 additions and 10 deletions
+18 -10
......@@ -2,6 +2,7 @@ import os
import json
import pickle
import glob
import sys
from Workflows.operations import *
from Learning.ObjectBased import ObjectBasedClassifier
from Postprocessing import Report, MapFormatting
......@@ -12,6 +13,11 @@ def unroll_file_list(lst):
out_lst.extend(sorted(glob.glob(f)))
return out_lst
def check_step(step, val):
if not val:
print("[MORINGA-ERR] : ***** ERROR ON STEP {} *****".format(step))
sys.exit(step)
def process_timeseries(oroot, d, ts_lst_pkl):
ts_lst = []
for ts in d['timeseries']:
......@@ -49,7 +55,7 @@ def process_timeseries(oroot, d, ts_lst_pkl):
raise ValueError('TimeSeries type not yet supported.')
with open(ts_lst_pkl, 'wb') as ts_save:
pickle.dump(ts_lst, ts_save)
return
return os.path.exists(ts_lst_pkl)
def perform_segmentation(ofn, d):
print('[MORINGA-INFO] : Performing segmentation')
......@@ -64,7 +70,7 @@ def perform_segmentation(ofn, d):
roi=d['roi'],
n_proc=d['segmentation']['n_proc'],
light=d['segmentation']['light_mode'])
return
return os.path.exists(ofn)
def train_valid_workflow(seg, ts_lst_pkl, d, m_file):
assert (os.path.exists(seg))
......@@ -79,6 +85,7 @@ def train_valid_workflow(seg, ts_lst_pkl, d, m_file):
ref_class_field=d['ref_db']['fields'])
obc.gen_k_folds(5, class_field=d['ref_db']['fields'][-1])
ok = True
for i,cf in enumerate(d['ref_db']['fields']):
if d['training']['classifier'] == 'rf':
......@@ -88,7 +95,8 @@ def train_valid_workflow(seg, ts_lst_pkl, d, m_file):
os.makedirs(os.path.dirname(m_file[i]), exist_ok=True)
with open(m_file[i], 'wb') as mf:
pickle.dump(m_dict, mf)
return
ok = ok and os.path.exists(m_file[i])
return ok
def classify(seg, ts_lst_pkl, m_files, d, map_files):
assert (os.path.exists(seg))
......@@ -108,7 +116,7 @@ def classify(seg, ts_lst_pkl, m_files, d, map_files):
models.append(m_dict['model'])
perc = [m_dict['perc2'], m_dict['perc98']]
obc.classify(models, perc=perc, output_files=map_files)
return
return all([os.path.exists(x) for x in map_files])
def report(map_files, m_files, d, report_files):
print('[MORINGA-INFO] : Generating report(s)')
......@@ -126,7 +134,7 @@ def report(map_files, m_files, d, report_files):
os.path.splitext(report_file)[0]+'_figures',
d['chain_name'])
Report.generate_pdf(of, report_file, d['chain_name'])
return
return all([os.path.exists(x) for x in report_files])
def basic(cfg, runlevel=1, single_step=False):
os.environ['OTB_LOGGER_LEVEL'] = 'CRITICAL'
......@@ -143,7 +151,7 @@ def basic(cfg, runlevel=1, single_step=False):
ts_lst_pkl = os.path.join(oside, 'time_series_list.pkl')
if step == 1:
print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
process_timeseries(oroot, d, ts_lst_pkl)
check_step(step, process_timeseries(oroot, d, ts_lst_pkl))
step += 1
if single_step:
return
......@@ -152,7 +160,7 @@ def basic(cfg, runlevel=1, single_step=False):
seg = os.path.join(oroot, 'segmentation/{}_obj_layer.tif'.format(d['chain_name']))
if step == 2:
print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
perform_segmentation(seg, d)
check_step(step, perform_segmentation(seg, d))
step += 1
if single_step:
return
......@@ -163,7 +171,7 @@ def basic(cfg, runlevel=1, single_step=False):
m_files.append(os.path.join(oroot, 'model/model_{}.pkl'.format(cf)))
if step == 3:
print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
train_valid_workflow(seg, ts_lst_pkl, d, m_files)
check_step(step, train_valid_workflow(seg, ts_lst_pkl, d, m_files))
step += 1
if single_step:
return
......@@ -174,7 +182,7 @@ def basic(cfg, runlevel=1, single_step=False):
map_files.append(os.path.join(oroot, 'maps/{}_map_{}.tif'.format(d['chain_name'],cf)))
if step == 4:
print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
classify(seg, ts_lst_pkl, m_files, d, map_files)
check_step(step, classify(seg, ts_lst_pkl, m_files, d, map_files))
for m,p in zip(map_files, d['map_output']['palette_files']):
MapFormatting.create_qgs_style(m,p)
step += 1
......@@ -187,7 +195,7 @@ def basic(cfg, runlevel=1, single_step=False):
report_fn.append(os.path.join(oroot, 'reports/{}_report_{}.pdf'.format(d['chain_name'],cf)))
if step == 5:
print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
report(map_files, m_files, d, report_fn)
check_step(step, report(map_files, m_files, d, report_fn))
print("[MORINGA-INFO] : ***** PROCESS FINISHED *****".format(step))
return
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment