From c3272e93a6128333b9d4f9a4dae8dc12a92f8f0d Mon Sep 17 00:00:00 2001 From: Raffaele Gaetano <raffaele.gaetano@cirad.fr> Date: Thu, 14 Sep 2023 12:15:44 +0200 Subject: [PATCH] ENH: option to save training base in pkl file. --- Learning/ObjectBased.py | 6 ++++++ Workflows/basic.py | 5 +++++ Workflows/basic_config.json | 1 + 3 files changed, 12 insertions(+) diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py index ee96958..72dd821 100644 --- a/Learning/ObjectBased.py +++ b/Learning/ObjectBased.py @@ -1,4 +1,5 @@ import glob +import pickle import numpy as np import pandas as pd @@ -119,6 +120,11 @@ class ObjectBasedClassifier: self.obia_base.populate_map(t, L, c, output_file, compress) prg.update(1) return + + def save_training_base(self, fn): + with open(fn, 'wb') as f: + pickle.dump(self.training_base, f) + return #TEST CODE def run_test(sample_folder): diff --git a/Workflows/basic.py b/Workflows/basic.py index 9e0205c..5397552 100644 --- a/Workflows/basic.py +++ b/Workflows/basic.py @@ -85,6 +85,11 @@ def train_valid_workflow(seg, ts_lst_pkl, d, m_file): ref_class_field=d['ref_db']['fields']) obc.gen_k_folds(5, class_field=d['ref_db']['fields'][-1]) + + if 'export_training_base' in d['training'].keys() and d['training']['export_training_base'] is True: + obc.save_training_base('{}/_side/training_base.pkl'.format(os.path.join(d['output_path'], d['chain_name']))) + print('[MORINGA-INFO] : Training base export completed.') + ok = True for i,cf in enumerate(d['ref_db']['fields']): diff --git a/Workflows/basic_config.json b/Workflows/basic_config.json index e3a2d79..147d8a6 100644 --- a/Workflows/basic_config.json +++ b/Workflows/basic_config.json @@ -43,6 +43,7 @@ ], "training": { + "export_training_base": false, "classifier": "rf", "parameters": { "n_trees": 400 -- GitLab