From 07c5d714013d4083b541bf2a62ac1bd58cedcb23 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@inrae.fr>
Date: Tue, 16 May 2023 10:06:21 +0200
Subject: [PATCH] TEST: compare no-data inference with baseline

---
 test/nodata_test.py | 46 ++++++++++++++++++++++++++++++++++++---------
 1 file changed, 37 insertions(+), 9 deletions(-)

diff --git a/test/nodata_test.py b/test/nodata_test.py
index 93e165b..c389215 100644
--- a/test/nodata_test.py
+++ b/test/nodata_test.py
@@ -1,15 +1,30 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+import otbApplication
 import pytest
+import tensorflow as tf
 import unittest
+
 import otbtf
-import otbApplication
-import tensorflow as tf
-from test_utils import resolve_paths
+from test_utils import resolve_paths, compare
+
 
 class NodataInferenceTest(unittest.TestCase):
 
     def test_infersimple(self):
+        """
+        In this test, we create a synthetic image:
+            f(x, y) = x * y if x > y else 0
+
+        Then we use an input no-data value (`source1.nodata 0`) and a
+        background value for the output (`output.bv 1024`).
+
+        We use the l2_norm SavedModel, forcing otbtf to use a tiling scheme
+        of 4x4. If the test succeeds, the output pixels in 4x4 areas where
+        there is at least one no-data pixel (i.e. 0), should be filled with
+        the `bv` value (i.e. 1024).
+
+        """
         sm_dir = resolve_paths("$TMPDIR/l2_norm_savedmodel")
 
         # Create model
@@ -18,11 +33,11 @@ class NodataInferenceTest(unittest.TestCase):
         model = tf.keras.Model(inputs={"x": x}, outputs={"y": y})
         model.save(sm_dir)
 
-        # OTB pipeline
+        # Input image: f(x, y) = x * y if x > y else 0
         bmx = otbApplication.Registry.CreateApplication("BandMathX")
         bmx.SetParameterString("exp", "{idxX>idxY?idxX*idxY:0}")
         bmx.SetParameterStringList(
-            "il", [resolve_paths("$DATADIR/fake_spot6.jp2")]
+            "il", [resolve_paths("$DATADIR/xs_subset.tif")]
         )
         bmx.Execute()
 
@@ -36,13 +51,26 @@ class NodataInferenceTest(unittest.TestCase):
         )
         infer.SetParameterFloat("source1.nodata", 0.0)
         for param in [
-            "source1.rfieldx", "source1.rfieldy", "output.efieldx", "output.efieldy"
+            "source1.rfieldx",
+            "source1.rfieldy",
+            "output.efieldx",
+            "output.efieldy",
+            "optim.tilesizex",
+            "optim.tilesizey",
         ]:
-            infer.SetParameterInt(param, 16)
+            infer.SetParameterInt(param, 4)
+
+        infer.SetParameterFloat("output.bv", 1024)
         infer.SetParameterString("out", resolve_paths("$TMPDIR/nd_out.tif"))
         infer.ExecuteAndWriteOutput()
 
+        self.assertTrue(
+            compare(
+                raster1=resolve_paths("$TMPDIR/nd_out.tif"),
+                raster2=resolve_paths("$DATADIR/nd_out.tif"),
+            )
+        )
+
 
 if __name__ == '__main__':
-    NodataInferenceTest().test_infersimple()
-    #unittest.main()
+    unittest.main()
-- 
GitLab