diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py index ee96958c9244630fd2309536bb1ed841769d810b..72dd82141c94b00c34782ea1d837094552808b6a 100644 --- a/Learning/ObjectBased.py +++ b/Learning/ObjectBased.py @@ -1,4 +1,5 @@ import glob +import pickle import numpy as np import pandas as pd @@ -119,6 +120,11 @@ class ObjectBasedClassifier: self.obia_base.populate_map(t, L, c, output_file, compress) prg.update(1) return + + def save_training_base(self, fn): + with open(fn, 'wb') as f: + pickle.dump(self.training_base, f) + return #TEST CODE def run_test(sample_folder): diff --git a/Workflows/basic.py b/Workflows/basic.py index 9e0205c134e3cb249e7b60ef1107ec82f68506a3..53975525b9e8d06d593baeea7e9ea23a1d17f6ac 100644 --- a/Workflows/basic.py +++ b/Workflows/basic.py @@ -85,6 +85,11 @@ def train_valid_workflow(seg, ts_lst_pkl, d, m_file): ref_class_field=d['ref_db']['fields']) obc.gen_k_folds(5, class_field=d['ref_db']['fields'][-1]) + + if 'export_training_base' in d['training'].keys() and d['training']['export_training_base'] is True: + obc.save_training_base('{}/_side/training_base.pkl'.format(os.path.join(d['output_path'], d['chain_name']))) + print('[MORINGA-INFO] : Training base export completed.') + ok = True for i,cf in enumerate(d['ref_db']['fields']): diff --git a/Workflows/basic_config.json b/Workflows/basic_config.json index e3a2d792effd9757044f574fd02c88b577969a5b..147d8a6c14fd5f69db3bef4b8dca7badf9635779 100644 --- a/Workflows/basic_config.json +++ b/Workflows/basic_config.json @@ -43,6 +43,7 @@ ], "training": { + "export_training_base": false, "classifier": "rf", "parameters": { "n_trees": 400