Commit 0cf08f50 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

expose evalhyd conditional masking functionality

No related merge requests found
Pipeline #38794 passed with stage
in 2 minutes and 1 second
Showing with 35 additions and 9 deletions
+35 -9
Subproject commit 5edca4b512d4258dd44414f3aa5c7a81db7b6cef Subproject commit bb555aee25de67b568390f4d76ae9b9a1678a065
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <array>
#define STRINGIFY(x) #x #define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x) #define MACRO_STRINGIFY(x) STRINGIFY(x)
...@@ -47,7 +48,7 @@ PYBIND11_MODULE(evalhyd, m) ...@@ -47,7 +48,7 @@ PYBIND11_MODULE(evalhyd, m)
metrics: `List[str]` metrics: `List[str]`
The sequence of evaluation metrics to be computed. The sequence of evaluation metrics to be computed.
transform: `str`, optional transform: `str`, optional
The transformation to apply to both streamflow observations The transformation to apply to both streamflow observations
and predictions prior to the calculation of the *metrics*. and predictions prior to the calculation of the *metrics*.
...@@ -94,9 +95,10 @@ PYBIND11_MODULE(evalhyd, m) ...@@ -94,9 +95,10 @@ PYBIND11_MODULE(evalhyd, m)
shape: [(components,)+] shape: [(components,)+]
)pbdoc", )pbdoc",
py::arg("q_obs"), py::arg("q_prd"), py::arg("metrics"), py::arg("q_obs"), py::arg("q_prd"), py::arg("metrics"),
py::arg("transform")="none", py::arg("exponent")=1, py::arg("transform") = "none", py::arg("exponent") = 1,
py::arg("epsilon")=-9, py::arg("epsilon") = -9,
py::arg("t_msk") = xt::pytensor<bool, 1>({}) py::arg("t_msk") = xt::pytensor<bool, 1>({}),
py::arg("m_cdt") = xt::pytensor<std::array<char, 32>, 1>({})
); );
m.def( m.def(
"evald", evalhyd::evald<2>, "evald", evalhyd::evald<2>,
...@@ -170,9 +172,10 @@ PYBIND11_MODULE(evalhyd, m) ...@@ -170,9 +172,10 @@ PYBIND11_MODULE(evalhyd, m)
shape: [(1+, components), ...] shape: [(1+, components), ...]
)pbdoc", )pbdoc",
py::arg("q_obs"), py::arg("q_prd"), py::arg("metrics"), py::arg("q_obs"), py::arg("q_prd"), py::arg("metrics"),
py::arg("transform")="none", py::arg("exponent")=1, py::arg("transform") = "none", py::arg("exponent") = 1,
py::arg("epsilon")=-9, py::arg("epsilon") = -9,
py::arg("t_msk") = xt::pytensor<bool, 2>({0}) py::arg("t_msk") = xt::pytensor<bool, 2>({0}),
py::arg("m_cdt") = xt::pytensor<std::array<char, 32>, 2>({0})
); );
// probabilistic evaluation // probabilistic evaluation
...@@ -227,7 +230,8 @@ PYBIND11_MODULE(evalhyd, m) ...@@ -227,7 +230,8 @@ PYBIND11_MODULE(evalhyd, m)
)pbdoc", )pbdoc",
py::arg("q_obs"), py::arg("q_prd"), py::arg("metrics"), py::arg("q_obs"), py::arg("q_prd"), py::arg("metrics"),
py::arg("q_thr") = xt::pytensor<double, 2>({0}), py::arg("q_thr") = xt::pytensor<double, 2>({0}),
py::arg("t_msk") = xt::pytensor<bool, 3>({0}) py::arg("t_msk") = xt::pytensor<bool, 3>({0}),
py::arg("m_cdt") = xt::pytensor<std::array<char, 32>, 2>({0})
); );
#ifdef VERSION_INFO #ifdef VERSION_INFO
......
...@@ -95,6 +95,17 @@ class TestMasking(unittest.TestCase): ...@@ -95,6 +95,17 @@ class TestMasking(unittest.TestCase):
evalhyd.evald(_obs[..., 99:], _prd[..., 99:], ["NSE"])[0] evalhyd.evald(_obs[..., 99:], _prd[..., 99:], ["NSE"])[0]
) )
def test_conditions(self):
cdt = numpy.array([["q{<2000,>3000}"]], dtype='|S32')
obs = _obs[..., (_obs[0] < 2000) | (_obs[0] > 3000)]
prd = _prd[..., (_obs[0] < 2000) | (_obs[0] > 3000)]
numpy.testing.assert_almost_equal(
evalhyd.evald(_obs, _prd, ["NSE"], m_cdt=cdt)[0],
evalhyd.evald(obs, prd, ["NSE"])[0]
)
class TestMissingData(unittest.TestCase): class TestMissingData(unittest.TestCase):
......
...@@ -86,6 +86,17 @@ class TestMasking(unittest.TestCase): ...@@ -86,6 +86,17 @@ class TestMasking(unittest.TestCase):
evalhyd.evalp(_obs[..., 99:], _prd[..., 99:], ["QS"])[0] evalhyd.evalp(_obs[..., 99:], _prd[..., 99:], ["QS"])[0]
) )
def test_conditions(self):
cdt = numpy.array([["q{<2000,>3000}"]], dtype='|S32')
obs = _obs[..., (_obs[0] < 2000) | (_obs[0] > 3000)]
prd = _prd[..., (_obs[0] < 2000) | (_obs[0] > 3000)]
numpy.testing.assert_almost_equal(
evalhyd.evalp(_obs, _prd, ["QS"], m_cdt=cdt)[0],
evalhyd.evalp(obs, prd, ["QS"])[0]
)
class TestMissingData(unittest.TestCase): class TestMissingData(unittest.TestCase):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment