From f7b3523e4d2c16a703029032f5cf70d51d463400 Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Wed, 1 Feb 2023 18:55:56 +0100
Subject: [PATCH] add unit test for conditional masking on time indices

---
 tests/test_determinist.py | 17 +++++++++++++++++
 tests/test_probabilist.py | 13 +++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/tests/test_determinist.py b/tests/test_determinist.py
index 0b4118f..fed9d87 100644
--- a/tests/test_determinist.py
+++ b/tests/test_determinist.py
@@ -140,6 +140,23 @@ class TestMasking(unittest.TestCase):
                 evalhyd.evald(obs, prd, ["NSE"])[0]
             )
 
+        with self.subTest(conditions="time indices"):
+            cdt = numpy.array([["t{20:311}"],
+                               ["t{20:100,100:311}"],
+                               ["t{20,21,22,23,24:311}"],
+                               ["t{20,21,22,23:309,309,310}"],
+                               ["t{20:80,80,81,82,83:311}"]],
+                              dtype='|S32')
+
+            # TODO: figure out why passing views would not work
+            obs = _obs[..., 20:].copy()
+            prd = _prd[..., 20:].copy()
+
+            numpy.testing.assert_almost_equal(
+                evalhyd.evald(_obs, _prd, ["NSE"], m_cdt=cdt)[0],
+                evalhyd.evald(obs, prd, ["NSE"])[0]
+            )
+
 
 class TestMissingData(unittest.TestCase):
 
diff --git a/tests/test_probabilist.py b/tests/test_probabilist.py
index 7f74f1f..38f4752 100644
--- a/tests/test_probabilist.py
+++ b/tests/test_probabilist.py
@@ -205,6 +205,19 @@ class TestMasking(unittest.TestCase):
                 evalhyd.evalp(obs, prd, ["QS"])[0]
             )
 
+        with self.subTest(conditions="time indices"):
+            cdt = numpy.array([["t{20:80,80,81,82,83:311}"]],
+                              dtype='|S32')
+
+            # TODO: figure out why passing views would not work
+            obs = _obs[..., 20:].copy()
+            prd = _prd[..., 20:].copy()
+
+            numpy.testing.assert_almost_equal(
+                evalhyd.evalp(_obs, _prd, ["QS"], m_cdt=cdt)[0],
+                evalhyd.evalp(obs, prd, ["QS"])[0]
+            )
+
 
 class TestMissingData(unittest.TestCase):
 
-- 
GitLab