# basic.py (6.3 KB)
# Repository blame view navigation: "Newer" / "Older"
import os
import json
import pickle
# Blame annotation: committed by Gaetano Raffaele
from Workflows.operations import preprocess_s2, run_segmentation
from Learning.ObjectBased import ObjectBasedClassifier
from Postprocessing import Report, MapFormatting

def process_timeseries(oroot, d, ts_lst_pkl):
    """Preprocess every configured time series and pickle the results.

    Args:
        oroot: Root output directory. Each series gets its own folder under
            ``timeseries/``, named by concatenating its type and provider.
        d: Configuration dictionary; ``d['timeseries']`` and ``d['roi']``
            are read here.
        ts_lst_pkl: Path of the pickle file that receives the list of values
            returned by ``preprocess_s2``.

    Raises:
        ValueError: If a series has a type other than ``'s2'``. The pickle
            file is not written in that case.
    """
    preprocessed = []
    for series in d['timeseries']:
        kind, provider = series['type'], series['provider']
        print('[MORINGA-INFO] : Preprocessing {} from {}'.format(kind, provider))

        # Only Sentinel-2 style ('s2') series are handled so far.
        if kind != 's2':
            raise ValueError('TimeSeries type not yet supported.')

        out_dir = os.path.join(oroot, 'timeseries/' + kind + provider)
        os.makedirs(out_dir, exist_ok=True)
        result = preprocess_s2(
            series['path'],
            out_dir,
            roi=d['roi'],
            output_dates_file=series['output_dates_file'],
            provider=provider,
        )
        preprocessed.append(result)

    # Persist the list so later steps can reload it without re-running this one.
    with open(ts_lst_pkl, 'wb') as handle:
        pickle.dump(preprocessed, handle)

def perform_segmentation(ofn, d):
    """Run the segmentation step and write its result to ``ofn``.

    Args:
        ofn: Output file path; its parent directory is created if missing.
        d: Configuration dictionary. The ``d['segmentation']`` entries
            (``src``, ``th``, ``cw``, ``sw``, ``n_first_iter``, ``margin``,
            ``n_proc``, ``lightmode``) and ``d['roi']`` are forwarded
            unchanged to ``run_segmentation``.
    """
    print('[MORINGA-INFO] : Performing segmentation')

    out_dir = os.path.dirname(ofn)
    os.makedirs(out_dir, exist_ok=True)

    params = d['segmentation']
    run_segmentation(
        params['src'],
        params['th'],
        params['cw'],
        params['sw'],
        ofn,
        n_first_iter=params['n_first_iter'],
        margin=params['margin'],
        roi=d['roi'],
        n_proc=params['n_proc'],
        light=params['lightmode'],
    )

def train_valid_workflow(seg, ts_lst_pkl, d, m_file):
    """Train and validate one model per reference class field, then save them.

    For each field in ``d['ref_db']['fields']`` a random-forest model is
    trained through ``ObjectBasedClassifier.train_RF`` and pickled, together
    with its validation results, summary and the ``perc2``/``perc98`` entries
    of the classifier's training base, to the matching path in ``m_file``.

    Args:
        seg: Path to the segmentation file; must exist.
        ts_lst_pkl: Path to the pickled time-series list; must exist.
        d: Configuration dictionary; reads ``d['training']``,
            ``d['userfeat']`` and ``d['ref_db']``.
        m_file: Sequence of output model paths, indexed in the same order as
            ``d['ref_db']['fields']``.

    Raises:
        AssertionError: If ``seg`` or ``ts_lst_pkl`` does not exist.
        ValueError: If ``d['training']['classifier']`` is not ``'rf'``.
    """
    assert (os.path.exists(seg))
    assert (os.path.exists(ts_lst_pkl))

    # Fail fast on an unsupported classifier. Previously this case skipped
    # training silently and produced no model files, so the error only showed
    # up later as a bare assertion failure in the classification step.
    classifier = d['training']['classifier']
    if classifier != 'rf':
        raise ValueError('Classifier type not yet supported: {!r}'.format(classifier))

    print('[MORINGA-INFO] : Running Training/Validation Workflow')
    # NOTE: pickle.load can execute arbitrary code; this file is expected to
    # be the one written by the preprocessing step of the same chain.
    with open(ts_lst_pkl, 'rb') as ts_save:
        ts_lst = pickle.load(ts_save)

    class_fields = d['ref_db']['fields']
    obc = ObjectBasedClassifier(seg,
                                ts_lst,
                                d['userfeat'],
                                reference_data=d['ref_db']['path'],
                                ref_class_field=class_fields)

    # Folds are generated once, on the last class field, and shared by all
    # the models trained below.
    obc.gen_k_folds(5, class_field=class_fields[-1])

    n_trees = d['training']['parameters']['n_trees']
    for i, cf in enumerate(class_fields):
        m, s, r = obc.train_RF(n_trees, class_field=cf, return_true_vs_pred=True)
        m_dict = {'model': m, 'results': r, 'summary': s,
                  'perc2': obc.training_base['perc2'],
                  'perc98': obc.training_base['perc98']}
        os.makedirs(os.path.dirname(m_file[i]), exist_ok=True)
        with open(m_file[i], 'wb') as mf:
            pickle.dump(m_dict, mf)
    return

def classify(seg, ts_lst_pkl, m_files, d, map_files):
    """Apply each saved model to the segmentation and write one map per model.

    Args:
        seg: Path to the segmentation file; must exist.
        ts_lst_pkl: Path to the pickled time-series list; must exist.
        m_files: Paths of pickled model dictionaries; all must exist. Each is
            expected to hold the keys ``'model'``, ``'perc2'`` and ``'perc98'``.
        d: Configuration dictionary; only ``d['userfeat']`` is read here.
        map_files: Output map paths, paired positionally with ``m_files``.

    Raises:
        AssertionError: If any of the input files is missing.
    """
    assert os.path.exists(seg)
    assert os.path.exists(ts_lst_pkl)
    assert all(os.path.exists(model_path) for model_path in m_files)

    print('[MORINGA-INFO] : Performing classification')

    with open(ts_lst_pkl, 'rb') as handle:
        series = pickle.load(handle)

    classifier = ObjectBasedClassifier(seg, series, d['userfeat'])

    for model_path, out_map in zip(m_files, map_files):
        with open(model_path, 'rb') as handle:
            bundle = pickle.load(handle)
        model = bundle['model']
        # The two percentile entries are passed on exactly as they were saved.
        percentiles = [bundle['perc2'], bundle['perc98']]
        classifier.classify(model, perc=percentiles, output_file=out_map)

def report(map_files, m_files, d, report_files):
    """Build the figures and the PDF report for each classified map.

    The four sequences (maps, palettes from ``d['map_output']['palette_files']``,
    models, reports) are walked in lockstep. Figures for a report are written
    to a sibling directory named after the report file, minus its extension,
    plus ``_figures``.

    Args:
        map_files: Paths of the classified maps; each must exist.
        m_files: Paths of the pickled model dictionaries; each must exist and
            is expected to hold the keys ``'results'`` and ``'summary'``.
        d: Configuration dictionary; reads ``d['map_output']['palette_files']``
            and ``d['chain_name']``.
        report_files: Output PDF paths.

    Raises:
        AssertionError: If a map or model file is missing.
    """
    print('[MORINGA-INFO] : Generating report(s)')

    palettes = d['map_output']['palette_files']
    for map_file, palette_fn, m_file, report_file in zip(map_files, palettes, m_files, report_files):
        assert os.path.exists(map_file)
        assert os.path.exists(m_file)

        figures_dir = os.path.splitext(report_file)[0] + '_figures'
        os.makedirs(figures_dir, exist_ok=True)

        with open(m_file, 'rb') as handle:
            stats = pickle.load(handle)

        figures = Report.generate_report_figures(
            map_file,
            palette_fn,
            stats['results'],
            stats['summary'],
            figures_dir,
            d['chain_name'],
        )
        Report.generate_pdf(figures, report_file, d['chain_name'])

def basic(cfg, runlevel=1, single_step=False):
    """Run the basic processing chain described by a JSON configuration file.

    The chain has five ordered steps: (1) time-series preprocessing,
    (2) segmentation, (3) training/validation, (4) classification and
    (5) report generation. Execution starts at ``runlevel`` and continues
    through the remaining steps unless ``single_step`` is set.

    Args:
        cfg: Path to the JSON configuration file.
        runlevel: Number of the first step to execute (1 to 5). Any other
            value runs no step.
        single_step: If true, return right after the ``runlevel`` step
            instead of chaining into the following ones.

    All outputs are placed under ``<output_path>/<chain_name>``. This
    function also sets the ``OTB_LOGGER_LEVEL`` environment variable to
    ``'CRITICAL'``.
    """
    os.environ['OTB_LOGGER_LEVEL'] = 'CRITICAL'
    with open(cfg, 'r') as f:
        d = json.load(f)

    oroot = os.path.join(d['output_path'], d['chain_name'])
    oside = os.path.join(oroot, '_side')
    os.makedirs(oside, exist_ok=True)

    class_fields = d['ref_db']['fields']
    step = runlevel

    # Output paths are computed unconditionally so that a run starting at a
    # later step can still locate the products of the earlier ones.

    # Preprocess timeseries
    ts_lst_pkl = os.path.join(oside, 'time_series_list.pkl')
    if step == 1:
        print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
        process_timeseries(oroot, d, ts_lst_pkl)
        step += 1
        if single_step:
            return

    # Segmentation
    seg = os.path.join(oroot, 'segmentation/{}_obj_layer.tif'.format(d['chain_name']))
    if step == 2:
        print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
        perform_segmentation(seg, d)
        step += 1
        if single_step:
            return

    # Training/Validation Workflow
    m_files = [os.path.join(oroot, 'model/model_{}.pkl'.format(cf))
               for cf in class_fields]
    if step == 3:
        print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
        train_valid_workflow(seg, ts_lst_pkl, d, m_files)
        step += 1
        if single_step:
            return

    # Classification
    map_files = [os.path.join(oroot, 'maps/{}_map_{}.tif'.format(d['chain_name'], cf))
                 for cf in class_fields]
    if step == 4:
        print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
        classify(seg, ts_lst_pkl, m_files, d, map_files)
        for m, p in zip(map_files, d['map_output']['palette_files']):
            MapFormatting.create_qgs_style(m, p)
        step += 1
        if single_step:
            return

    # Report
    report_fn = [os.path.join(oroot, 'reports/{}_report_{}.pdf'.format(d['chain_name'], cf))
                 for cf in class_fields]
    if step == 5:
        print("[MORINGA-INFO] : ***** BEGIN STEP {} *****".format(step))
        report(map_files, m_files, d, report_fn)

    print("[MORINGA-INFO] : ***** PROCESS FINISHED *****")