Commit 07c5d714 authored by Remi Cresson's avatar Remi Cresson
Browse files

TEST: compare no-data inference with baseline

Showing with 37 additions and 9 deletions
+37 -9
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import otbApplication
import pytest import pytest
import tensorflow as tf
import unittest import unittest
import otbtf import otbtf
import otbApplication from test_utils import resolve_paths, compare
import tensorflow as tf
from test_utils import resolve_paths
class NodataInferenceTest(unittest.TestCase): class NodataInferenceTest(unittest.TestCase):
def test_infersimple(self): def test_infersimple(self):
"""
In this test, we create a synthetic image:
f(x, y) = x * y if x > y else 0
Then we use an input no-data value (`source1.nodata 0`) and a
background value for the output (`output.bv 1024`).
We use the l2_norm SavedModel, forcing otbtf to use a tiling scheme
of 4x4. If the test succeeds, the output pixels in 4x4 areas where
there is at least one no-data pixel (i.e. 0), should be filled with
the `bv` value (i.e. 1024).
"""
sm_dir = resolve_paths("$TMPDIR/l2_norm_savedmodel") sm_dir = resolve_paths("$TMPDIR/l2_norm_savedmodel")
# Create model # Create model
...@@ -18,11 +33,11 @@ class NodataInferenceTest(unittest.TestCase): ...@@ -18,11 +33,11 @@ class NodataInferenceTest(unittest.TestCase):
model = tf.keras.Model(inputs={"x": x}, outputs={"y": y}) model = tf.keras.Model(inputs={"x": x}, outputs={"y": y})
model.save(sm_dir) model.save(sm_dir)
# OTB pipeline # Input image: f(x, y) = x * y if x > y else 0
bmx = otbApplication.Registry.CreateApplication("BandMathX") bmx = otbApplication.Registry.CreateApplication("BandMathX")
bmx.SetParameterString("exp", "{idxX>idxY?idxX*idxY:0}") bmx.SetParameterString("exp", "{idxX>idxY?idxX*idxY:0}")
bmx.SetParameterStringList( bmx.SetParameterStringList(
"il", [resolve_paths("$DATADIR/fake_spot6.jp2")] "il", [resolve_paths("$DATADIR/xs_subset.tif")]
) )
bmx.Execute() bmx.Execute()
...@@ -36,13 +51,26 @@ class NodataInferenceTest(unittest.TestCase): ...@@ -36,13 +51,26 @@ class NodataInferenceTest(unittest.TestCase):
) )
infer.SetParameterFloat("source1.nodata", 0.0) infer.SetParameterFloat("source1.nodata", 0.0)
for param in [ for param in [
"source1.rfieldx", "source1.rfieldy", "output.efieldx", "output.efieldy" "source1.rfieldx",
"source1.rfieldy",
"output.efieldx",
"output.efieldy",
"optim.tilesizex",
"optim.tilesizey",
]: ]:
infer.SetParameterInt(param, 16) infer.SetParameterInt(param, 4)
infer.SetParameterFloat("output.bv", 1024)
infer.SetParameterString("out", resolve_paths("$TMPDIR/nd_out.tif")) infer.SetParameterString("out", resolve_paths("$TMPDIR/nd_out.tif"))
infer.ExecuteAndWriteOutput() infer.ExecuteAndWriteOutput()
self.assertTrue(
compare(
raster1=resolve_paths("$TMPDIR/nd_out.tif"),
raster2=resolve_paths("$DATADIR/nd_out.tif"),
)
)
if __name__ == '__main__': if __name__ == '__main__':
NodataInferenceTest().test_infersimple() unittest.main()
#unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment