Commit d3b5095d authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

add multi-variate metric ES

1 merge request!1release v0.1.0.0
Pipeline #46226 passed with stage
in 5 minutes
Showing with 36 additions and 7 deletions
+36 -7
Subproject commit f58170a5fcaeae1143bd1a626453722306609c00 Subproject commit 31cf8da51a415301a7a1842b5dbc0ecdd3fff1b1
587.2254970444062
...@@ -24,7 +24,9 @@ _all_metrics = ( ...@@ -24,7 +24,9 @@ _all_metrics = (
# ranks-based # ranks-based
'RANK_HIST', 'DS', 'AS', 'RANK_HIST', 'DS', 'AS',
# intervals # intervals
'CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS' 'CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS',
# multivariate
'ES'
) )
# list all available deterministic diagnostics # list all available deterministic diagnostics
...@@ -83,6 +85,13 @@ class TestMetrics(unittest.TestCase): ...@@ -83,6 +85,13 @@ class TestMetrics(unittest.TestCase):
) for metric in ('CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS') ) for metric in ('CR', 'AW', 'AWN', 'AWI', 'WS', 'WSS')
} }
expected_mvr = {
metric: (
numpy.genfromtxt(f"./expected/evalp/{metric}.csv", delimiter=',')
[numpy.newaxis, numpy.newaxis, numpy.newaxis, numpy.newaxis, ...]
) for metric in ('ES',)
}
def test_thresholds_metrics(self): def test_thresholds_metrics(self):
thr = numpy.array([[690, 534, 445, numpy.nan]]) thr = numpy.array([[690, 534, 445, numpy.nan]])
for metric in self.expected_thr.keys(): for metric in self.expected_thr.keys():
...@@ -139,6 +148,19 @@ class TestMetrics(unittest.TestCase): ...@@ -139,6 +148,19 @@ class TestMetrics(unittest.TestCase):
self.expected_itv[metric] self.expected_itv[metric]
) )
def test_multivariate_metrics(self):
n_sit = 5
multi_obs = numpy.repeat(_obs, repeats=n_sit, axis=0)
multi_prd = numpy.repeat(_prd, repeats=n_sit, axis=0)
for metric in self.expected_mvr.keys():
with self.subTest(metric=metric):
numpy.testing.assert_almost_equal(
evalhyd.evalp(multi_obs, multi_prd, [metric], seed=7)[0],
self.expected_mvr[metric]
)
class TestDecomposition(unittest.TestCase): class TestDecomposition(unittest.TestCase):
...@@ -325,10 +347,13 @@ class TestMultiDimensional(unittest.TestCase): ...@@ -325,10 +347,13 @@ class TestMultiDimensional(unittest.TestCase):
multi_prd = numpy.repeat(_prd, repeats=n_sit, axis=0) multi_prd = numpy.repeat(_prd, repeats=n_sit, axis=0)
multi_thr = numpy.repeat(self.thr, repeats=n_sit, axis=0) multi_thr = numpy.repeat(self.thr, repeats=n_sit, axis=0)
# skip multisite metrics because their result is not the sum of sites
metrics = [m for m in self.metrics if m not in ("ES",)]
multi = evalhyd.evalp( multi = evalhyd.evalp(
multi_obs, multi_obs,
multi_prd, multi_prd,
self.metrics, metrics,
q_thr=multi_thr, q_thr=multi_thr,
events=self.events, events=self.events,
c_lvl=self.lvl, c_lvl=self.lvl,
...@@ -338,14 +363,14 @@ class TestMultiDimensional(unittest.TestCase): ...@@ -338,14 +363,14 @@ class TestMultiDimensional(unittest.TestCase):
mono = evalhyd.evalp( mono = evalhyd.evalp(
_obs, _obs,
_prd, _prd,
self.metrics, metrics,
q_thr=self.thr, q_thr=self.thr,
events=self.events, events=self.events,
c_lvl=self.lvl, c_lvl=self.lvl,
seed=self.seed seed=self.seed
) )
for m, metric in enumerate(self.metrics): for m, metric in enumerate(metrics):
for site in range(n_sit): for site in range(n_sit):
with self.subTest(metric=metric, site=site): with self.subTest(metric=metric, site=site):
numpy.testing.assert_almost_equal( numpy.testing.assert_almost_equal(
...@@ -400,17 +425,20 @@ class TestMultiDimensional(unittest.TestCase): ...@@ -400,17 +425,20 @@ class TestMultiDimensional(unittest.TestCase):
multi_thr = numpy.repeat(self.thr, repeats=n_sit, axis=0) multi_thr = numpy.repeat(self.thr, repeats=n_sit, axis=0)
# skip multisite metrics because their result is not the sum of sites
metrics = [m for m in self.metrics if m not in ("ES",)]
multi = evalhyd.evalp( multi = evalhyd.evalp(
multi_obs, multi_obs,
multi_prd, multi_prd,
self.metrics, metrics,
q_thr=multi_thr, q_thr=multi_thr,
events=self.events, events=self.events,
c_lvl=self.lvl, c_lvl=self.lvl,
seed=self.seed seed=self.seed
) )
for m, metric in enumerate(self.metrics): for m, metric in enumerate(metrics):
for sit in range(n_sit): for sit in range(n_sit):
for ldt in range(n_ldt): for ldt in range(n_ldt):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment