From d604263cb3b726d2f41424c4ed428e056b3c7cc7 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Thu, 2 Jun 2022 17:08:18 +0200
Subject: [PATCH] TEST: fix tests

---
 tests/decloud_unittest.py              |  5 ++++-
 tests/train_from_tfrecords_unittest.py | 22 ++++++++++++----------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/tests/decloud_unittest.py b/tests/decloud_unittest.py
index 657ea08..c973340 100644
--- a/tests/decloud_unittest.py
+++ b/tests/decloud_unittest.py
@@ -15,7 +15,10 @@ class DecloudTest(ABC, unittest.TestCase):
     DECLOUD_DATA_DIR = get_env_var("DECLOUD_DATA_DIR")
 
     def get_path(self, path):
-        return os.path.join(self.DECLOUD_DATA_DIR, path)
+        pth = os.path.join(self.DECLOUD_DATA_DIR, path)
+        if not os.path.exists(pth):
+            raise FileNotFoundError(f"Directory {pth} not found!")
+        return pth
 
     def compare_images(self, image, reference, mae_threshold=0.01):
 
diff --git a/tests/train_from_tfrecords_unittest.py b/tests/train_from_tfrecords_unittest.py
index 12e2c72..c3752d7 100644
--- a/tests/train_from_tfrecords_unittest.py
+++ b/tests/train_from_tfrecords_unittest.py
@@ -8,6 +8,7 @@ from .decloud_unittest import DecloudTest
 
 SAVEDMODEL_FILENAME = "saved_model.pb"
 
+
 def is_savedmodel_written(args_list):
     out_savedmodel = "/tmp/savedmodel"
     base_args = ["--logdir", "/tmp/logdir",
@@ -25,8 +26,9 @@ def is_savedmodel_written(args_list):
 
 
 OS2_TFREC_PTH = "baseline/TFRecord/CRGA"
-OS2_ALL_BANDS_TFREC_PTH = "/baseline/TFRecord/CRGA_all_bands"
-MERANER_ALL_BANDS_TFREC_PTH = "/baseline/TFRecord/CRGA_all_bands"
+OS2_ALL_BANDS_TFREC_PTH = "baseline/TFRecord/CRGA_all_bands"
+MERANER_ALL_BANDS_TFREC_PTH = "baseline/TFRecord/CRGA_all_bands"
+ERRMSG = f"File {SAVEDMODEL_FILENAME} not found !"
 
 
 class TrainFromTFRecordsTest(DecloudTest):
@@ -34,42 +36,42 @@ class TrainFromTFRecordsTest(DecloudTest):
     def test_trainFromTFRecords_os1_unet(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_TFREC_PTH),
                                                "--model", "crga_os1_unet"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_os2_david(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_TFREC_PTH),
                                                "--model", "crga_os2_david"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_os2_unet(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_TFREC_PTH),
                                                "--model", "crga_os2_unet"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_os1_unet_all_bands(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_ALL_BANDS_TFREC_PTH),
                                                "--model", "crga_os1_unet_all_bands"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_os2_david_all_bands(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_ALL_BANDS_TFREC_PTH),
                                                "--model", "crga_os2_david_all_bands"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_os2_unet_all_bands(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_ALL_BANDS_TFREC_PTH),
                                                "--model", "crga_os2_unet_all_bands"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_meraner_unet(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(MERANER_ALL_BANDS_TFREC_PTH),
                                                "--model", "meraner_unet"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
     def test_trainFromTFRecords_meraner_unet_all_bands(self):
         self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(MERANER_ALL_BANDS_TFREC_PTH),
                                                "--model", "meraner_unet_all_bands"]),
-                        "File {} not found !".format(SAVEDMODEL_FILENAME))
+                        ERRMSG)
 
 
 if __name__ == '__main__':
-- 
GitLab