From d3b5095d9946a7d2fdb9fec157564b7c698b9f0e Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Thu, 13 Apr 2023 14:25:35 +0200
Subject: [PATCH] add multi-variate metric ES

---
 deps/evalhyd                |  2 +-
 tests/expected/evalp/ES.csv |  1 +
 tests/test_probabilist.py   | 40 +++++++++++++++++++++++++++++++------
 3 files changed, 36 insertions(+), 7 deletions(-)
 create mode 100644 tests/expected/evalp/ES.csv

diff --git a/deps/evalhyd b/deps/evalhyd
index f58170a..31cf8da 160000
--- a/deps/evalhyd
+++ b/deps/evalhyd
@@ -1 +1 @@
-Subproject commit f58170a5fcaeae1143bd1a626453722306609c00
+Subproject commit 31cf8da51a415301a7a1842b5dbc0ecdd3fff1b1
diff --git a/tests/expected/evalp/ES.csv b/tests/expected/evalp/ES.csv
new file mode 100644
index 0000000..2cb138c
--- /dev/null
+++ b/tests/expected/evalp/ES.csv
@@ -0,0 +1 @@
+587.2254970444062
diff --git a/tests/test_probabilist.py b/tests/test_probabilist.py
index 72502eb..9c9b945 100644
--- a/tests/test_probabilist.py
+++ b/tests/test_probabilist.py
@@ -24,7 +24,9 @@ _all_metrics = (
     # ranks-based
     'RANK_HIST', 'DS', 'AS',
     # intervals
-    'CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS'
+    'CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS',
+    # multivariate
+    'ES'
 )
 
 # list all available deterministic diagnostics
@@ -83,6 +85,13 @@ class TestMetrics(unittest.TestCase):
         ) for metric in ('CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS')
     }
 
+    expected_mvr = {
+        metric: (
+            numpy.genfromtxt(f"./expected/evalp/{metric}.csv", delimiter=',')
+            [numpy.newaxis, numpy.newaxis, numpy.newaxis, numpy.newaxis, ...]
+        ) for metric in ('ES',)
+    }
+
     def test_thresholds_metrics(self):
         thr = numpy.array([[690, 534, 445, numpy.nan]])
         for metric in self.expected_thr.keys():
@@ -139,6 +148,19 @@ class TestMetrics(unittest.TestCase):
                     self.expected_itv[metric]
                 )
 
+    def test_multivariate_metrics(self):
+        n_sit = 5
+
+        multi_obs = numpy.repeat(_obs, repeats=n_sit, axis=0)
+        multi_prd = numpy.repeat(_prd, repeats=n_sit, axis=0)
+
+        for metric in self.expected_mvr.keys():
+            with self.subTest(metric=metric):
+                numpy.testing.assert_almost_equal(
+                    evalhyd.evalp(multi_obs, multi_prd, [metric], seed=7)[0],
+                    self.expected_mvr[metric]
+                )
+
 
 class TestDecomposition(unittest.TestCase):
 
@@ -325,10 +347,13 @@ class TestMultiDimensional(unittest.TestCase):
         multi_prd = numpy.repeat(_prd, repeats=n_sit, axis=0)
         multi_thr = numpy.repeat(self.thr, repeats=n_sit, axis=0)
 
+        # skip multisite metrics because their result is not the sum of sites
+        metrics = [m for m in self.metrics if m not in ("ES",)]
+
         multi = evalhyd.evalp(
             multi_obs,
             multi_prd,
-            self.metrics,
+            metrics,
             q_thr=multi_thr,
             events=self.events,
             c_lvl=self.lvl,
@@ -338,14 +363,14 @@ class TestMultiDimensional(unittest.TestCase):
         mono = evalhyd.evalp(
             _obs,
             _prd,
-            self.metrics,
+            metrics,
             q_thr=self.thr,
             events=self.events,
             c_lvl=self.lvl,
             seed=self.seed
         )
 
-        for m, metric in enumerate(self.metrics):
+        for m, metric in enumerate(metrics):
             for site in range(n_sit):
                 with self.subTest(metric=metric, site=site):
                     numpy.testing.assert_almost_equal(
@@ -400,17 +425,20 @@ class TestMultiDimensional(unittest.TestCase):
 
         multi_thr = numpy.repeat(self.thr, repeats=n_sit, axis=0)
 
+        # skip multisite metrics because their result is not the sum of sites
+        metrics = [m for m in self.metrics if m not in ("ES",)]
+
         multi = evalhyd.evalp(
             multi_obs,
             multi_prd,
-            self.metrics,
+            metrics,
             q_thr=multi_thr,
             events=self.events,
             c_lvl=self.lvl,
             seed=self.seed
         )
 
-        for m, metric in enumerate(self.metrics):
+        for m, metric in enumerate(metrics):
             for sit in range(n_sit):
                 for ldt in range(n_ldt):
 
-- 
GitLab