From c3272e93a6128333b9d4f9a4dae8dc12a92f8f0d Mon Sep 17 00:00:00 2001
From: Raffaele Gaetano <raffaele.gaetano@cirad.fr>
Date: Thu, 14 Sep 2023 12:15:44 +0200
Subject: [PATCH] ENH: option to save training base in pkl file.

---
 Learning/ObjectBased.py     | 6 ++++++
 Workflows/basic.py          | 5 +++++
 Workflows/basic_config.json | 1 +
 3 files changed, 12 insertions(+)

diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py
index ee96958..72dd821 100644
--- a/Learning/ObjectBased.py
+++ b/Learning/ObjectBased.py
@@ -1,4 +1,5 @@
 import glob
+import pickle
 
 import numpy as np
 import pandas as pd
@@ -119,6 +120,11 @@ class ObjectBasedClassifier:
                     self.obia_base.populate_map(t, L, c, output_file, compress)
                 prg.update(1)
         return
+    
+    def save_training_base(self, fn):
+        with open(fn, 'wb') as f:
+            pickle.dump(self.training_base, f)
+        return
 
 #TEST CODE
 def run_test(sample_folder):
diff --git a/Workflows/basic.py b/Workflows/basic.py
index 9e0205c..5397552 100644
--- a/Workflows/basic.py
+++ b/Workflows/basic.py
@@ -85,6 +85,11 @@ def train_valid_workflow(seg, ts_lst_pkl, d, m_file):
                                 ref_class_field=d['ref_db']['fields'])
 
     obc.gen_k_folds(5, class_field=d['ref_db']['fields'][-1])
+
+    if 'export_training_base' in d['training'].keys() and d['training']['export_training_base'] is True:
+        obc.save_training_base('{}/_side/training_base.pkl'.format(os.path.join(d['output_path'], d['chain_name'])))
+        print('[MORINGA-INFO] : Training base export completed.')
+
     ok = True
 
     for i,cf in enumerate(d['ref_db']['fields']):
diff --git a/Workflows/basic_config.json b/Workflows/basic_config.json
index e3a2d79..147d8a6 100644
--- a/Workflows/basic_config.json
+++ b/Workflows/basic_config.json
@@ -43,6 +43,7 @@
 	],
 	
 	"training": {
+		"export_training_base": false,
 		"classifier": "rf",
 		"parameters": {
 			"n_trees": 400
-- 
GitLab