From f31664ddca06d65a8e0e9e64e91492fda6b94b6a Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Tue, 13 Sep 2022 16:42:48 +0200
Subject: [PATCH] allow masking conditions to be specified on predictions

An earlier implementation of the masking conditions assumed that the
conditions on streamflow would only be on the observations, but this is
not always the case. For example, reliability scores cannot be done on
the observed streamflow and need to be performed on the predicted
streamflow. So this is now possible as the condition syntax is changed
and now *q_obs*/*q_prd_median*/*q_prd_mean* in place of *q*.
---
 include/evalhyd/evald.hpp   | 14 +++---
 include/evalhyd/evalp.hpp   | 53 ++++++++++++---------
 src/masks.hpp               | 94 +++++++++++++++++++++++++++----------
 src/probabilist/evaluator.h | 23 ++++-----
 tests/test_determinist.cpp  |  2 +-
 tests/test_probabilist.cpp  | 12 ++---
 6 files changed, 126 insertions(+), 72 deletions(-)

diff --git a/include/evalhyd/evald.hpp b/include/evalhyd/evald.hpp
index 121e259..f4097aa 100644
--- a/include/evalhyd/evald.hpp
+++ b/include/evalhyd/evald.hpp
@@ -104,7 +104,7 @@ namespace evalhyd
     ///    m_cdt: ``xt::xtensor<std::array<char, 32>, N>``, optional
     ///       Masking conditions to use to generate temporal subsets. Each
     ///       condition consists in a string and can be specified on observed
-    ///       streamflow values or on time indices. If provided in combination
+    ///       streamflow values, or on time indices. If provided in combination
     ///       with *t_msk*, the latter takes precedence. If not provided and
     ///       neither is *t_msk*, no subset is performed. If provided, there
     ///       must be as many conditions as there are time series of
@@ -179,8 +179,8 @@ namespace evalhyd
                 // flatten arrays to bypass n-dim considerations
                 // (possible because shapes are constrained to be the same)
                 if (m_cdt.shape(m_cdt.dimension() - 1) != 1)
-                    throw std::runtime_error("length of last axis in masking conditions "
-                                             "must be equal to one");
+                    throw std::runtime_error("length of last axis in masking "
+                                             "conditions must be equal to one");
                 for (int a = 0; a < m_cdt.dimension() - 1; a++)
                     if (q_obs.shape(a) != m_cdt.shape(a))
                         throw std::runtime_error("masking conditions and observations "
@@ -216,15 +216,15 @@ namespace evalhyd
         if (q_obs.dimension() != q_prd.dimension())
         {
             throw std::runtime_error(
-                    "observations and predictions feature different numbers "
-                    "of dimensions"
+                    "observations and predictions feature "
+                    "different numbers of dimensions"
             );
         }
         if (q_obs.dimension() != msk.dimension())
         {
             throw std::runtime_error(
-                    "observations and masks feature different numbers "
-                    "of dimensions"
+                    "observations and masks feature "
+                    "different numbers of dimensions"
             );
         }
 
diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp
index aea53d2..3231866 100644
--- a/include/evalhyd/evalp.hpp
+++ b/include/evalhyd/evalp.hpp
@@ -5,6 +5,7 @@
 #include <unordered_map>
 #include <vector>
 #include <array>
+#include <stdexcept>
 #include <xtensor/xtensor.hpp>
 #include <xtensor/xarray.hpp>
 #include <xtensor/xview.hpp>
@@ -44,26 +45,26 @@ namespace evalhyd
     ///       Streamflow exceedance threshold(s).
     ///       shape: (sites, thresholds)
     ///
-    ///    t_msk: ``xt::xtensor<bool, 3>``, optional
+    ///    t_msk: ``xt::xtensor<bool, 4>``, optional
     ///       Mask(s) used to generate temporal subsets of the whole streamflow
     ///       time series (where True/False is used for the time steps to
     ///       include/discard in a given subset). If not provided and neither
     ///       is *m_cdt*, no subset is performed and only one set of metrics is
     ///       returned corresponding to the whole time series. If provided, as
     ///       many sets of metrics are returned as they are masks provided.
-    ///       shape: (sites, subsets, time)
+    ///       shape: (sites, lead times, subsets, time)
     ///
     ///       .. seealso:: :doc:`../../functionalities/temporal-masking`
     ///
     ///    m_cdt: ``xt::xtensor<std::array<char, 32>, 2>``, optional
     ///       Masking conditions to use to generate temporal subsets. Each
     ///       condition consists in a string and can be specified on observed
-    ///       streamflow values or on time indices. If provided in combination
-    ///       with *t_msk*, the latter takes precedence. If not provided and
-    ///       neither is *t_msk*, no subset is performed and only one set of
-    ///       metrics is returned corresponding to the whole time series. If
-    ///       provided, as many sets of metrics are returned as they are
-    ///       conditions provided.
+    ///       or predicted streamflow values, or on time indices. If provided
+    ///       in combination with *t_msk*, the latter takes precedence. If not
+    ///       provided and neither is *t_msk*, no subset is performed and only
+    ///       one set of metrics is returned corresponding to the whole time
+    ///       series. If provided, as many sets of metrics are returned as they
+    ///       are conditions provided.
     ///       shape: (sites, subsets)
     ///
     ///       .. seealso:: :doc:`../../functionalities/conditional-masking`
@@ -107,7 +108,7 @@ namespace evalhyd
             const xt::xtensor<double, 4>& q_prd,
             const std::vector<std::string>& metrics,
             const xt::xtensor<double, 2>& q_thr = {},
-            const xt::xtensor<bool, 3>& t_msk = {},
+            const xt::xtensor<bool, 4>& t_msk = {},
             const xt::xtensor<std::array<char, 32>, 2>& m_cdt = {}
     )
     {
@@ -128,11 +129,18 @@ namespace evalhyd
                     "temporal lengths"
             );
         if (t_msk.size() > 0)
-            if (q_obs.shape(1) != t_msk.shape(2))
+            if (q_obs.shape(1) != t_msk.shape(3))
                 throw std::runtime_error(
                         "observations and masks feature different "
                         "temporal lengths"
                 );
+        // > leadtimes
+        if (t_msk.size() > 0)
+            if (q_prd.shape(1) != t_msk.shape(1))
+                throw std::runtime_error(
+                        "predictions and temporal masks feature different "
+                        "numbers of lead times"
+                );
         // > sites
         if (q_obs.shape(0) != q_prd.shape(0))
             throw std::runtime_error(
@@ -158,7 +166,7 @@ namespace evalhyd
         std::size_t n_mbr = q_prd.shape(2);
         std::size_t n_tim = q_prd.shape(3);
         std::size_t n_thr = q_thr.shape(1);
-        std::size_t n_msk = t_msk.size() > 0 ? t_msk.shape(1) :
+        std::size_t n_msk = t_msk.size() > 0 ? t_msk.shape(2) :
                 (m_cdt.size() > 0 ? m_cdt.shape(1) : 1);
 
         // register metrics number of dimensions
@@ -196,17 +204,20 @@ namespace evalhyd
 
         // generate masks from conditions if provided
         auto gen_msk = [&]() {
-            xt::xtensor<bool, 3> c_msk = xt::zeros<bool>({n_sit, n_msk, n_tim});
+            xt::xtensor<bool, 4> c_msk = xt::zeros<bool>({n_sit, n_ltm, n_msk, n_tim});
             if (m_cdt.size() > 0)
                 for (int s = 0; s < n_sit; s++)
-                    for (int m = 0; m < n_msk; m++)
-                        xt::view(c_msk, s, m) =
-                                eh::masks::generate_mask_from_conditions(
-                                        xt::view(m_cdt, s, m), xt::view(q_obs, s)
-                                );
+                    for (int l = 0; l < n_ltm; l++)
+                        for (int m = 0; m < n_msk; m++)
+                            xt::view(c_msk, s, l, m) =
+                                    eh::masks::generate_mask_from_conditions(
+                                            xt::view(m_cdt, s, m),
+                                            xt::view(q_obs, s),
+                                            xt::view(q_prd, s, l)
+                                    );
             return c_msk;
         };
-        const xt::xtensor<bool, 3> c_msk = gen_msk();
+        const xt::xtensor<bool, 4> c_msk = gen_msk();
 
         // initialise data structure for outputs
         std::vector<xt::xarray<double>> r;
@@ -224,10 +235,10 @@ namespace evalhyd
                 const auto q_thr_v = xt::view(q_thr, s, xt::all());
                 const auto t_msk_v =
                         t_msk.size() > 0 ?
-                        xt::view(t_msk, s, xt::all(), xt::all()) :
+                        xt::view(t_msk, s, l, xt::all(), xt::all()) :
                         (m_cdt.size() > 0 ?
-                         xt::view(c_msk, s, xt::all(), xt::all()) :
-                         xt::view(t_msk, s, xt::all(), xt::all()));
+                         xt::view(c_msk, s, l, xt::all(), xt::all()) :
+                         xt::view(t_msk, s, l, xt::all(), xt::all()));
 
                 eh::probabilist::Evaluator evaluator(
                         q_obs_v, q_prd_v, q_thr_v, t_msk_v
diff --git a/src/masks.hpp b/src/masks.hpp
index 3991790..c61bb0f 100644
--- a/src/masks.hpp
+++ b/src/masks.hpp
@@ -7,10 +7,12 @@
 #include <array>
 #include <string>
 #include <regex>
+#include <stdexcept>
 
 #include <xtensor/xexpression.hpp>
 #include <xtensor/xtensor.hpp>
 #include <xtensor/xview.hpp>
+#include <xtensor/xsort.hpp>
 #include <xtensor/xindex_view.hpp>
 
 typedef std::map<std::string, std::vector<std::vector<std::string>>> msk_tree;
@@ -24,9 +26,12 @@ namespace evalhyd
         {
             msk_tree subset;
 
-            // pattern supported to specify conditions to generate masks on streamflow
+            // pattern supported to specify conditions to generate masks on
+            // observed or predicted (median or mean for probabilist) streamflow
             // e.g. q{>9.} q{<9} q{>=99.0} q{<=99} q{>9,<99} q{==9} q{!=9}
-            std::regex exp_q (R"(([q])\{((([><!=]?=?[0-9]+\.?[0-9]*),*)+)\})");
+            std::regex exp_q (
+                    R"((q_obs|q_prd_median|q_prd_mean)\{((([><!=]?=?[0-9]+\.?[0-9]*),*)+)\})"
+            );
 
             for (std::sregex_iterator i =
                     std::sregex_iterator(msk_str.begin(), msk_str.end(), exp_q);
@@ -132,13 +137,21 @@ namespace evalhyd
 
         /// Function to generate temporal mask based on masking conditions
         inline xt::xtensor<bool, 1> generate_mask_from_conditions(
-                const std::array<char, 32>& msk_char_arr, const xt::xtensor<double, 1>& q_obs
+                const std::array<char, 32>& msk_char_arr,
+                const xt::xtensor<double, 1>& q_obs,
+                const xt::xtensor<double, 2>& q_prd = {}
         )
         {
             // parse string to identify masking conditions
             std::string msk_str(msk_char_arr.begin(), msk_char_arr.end());
             msk_tree subset = parse_masking_conditions(msk_str);
 
+            // check if conditions were found in parsing
+            if (subset.empty())
+                throw std::runtime_error(
+                        "no valid condition found to generate mask(s)"
+                );
+
             // initialise a boolean expression for the masks
             xt::xtensor<bool, 1> t_msk = xt::zeros<bool>(q_obs.shape());
 
@@ -149,8 +162,37 @@ namespace evalhyd
                 auto cond = var_cond.second;
 
                 // condition on streamflow
-                if (var == "q")
+                if ((var == "q_obs") or (var == "q_prd_median")
+                    or (var == "q_prd_mean"))
                 {
+                    // preprocess streamflow depending on kind
+                    auto get_q = [&]() {
+                        if (var == "q_obs") {
+                            return q_obs;
+                        }
+                        else if (var == "q_prd_median") {
+                            if (q_prd.size() < 1)
+                                throw std::runtime_error(
+                                        "condition on streamflow predictions "
+                                        "not allowed for generating masks"
+                                );
+                            xt::xtensor<double, 1> q_prd_median =
+                                    xt::median(q_prd, 0);
+                            return q_prd_median;
+                        }
+                        else {  // i.e. (var == "q_prd_mean")
+                            if (q_prd.size() < 1)
+                                throw std::runtime_error(
+                                        "condition on streamflow predictions "
+                                        "not allowed for generating masks"
+                                );
+                            xt::xtensor<double, 1> q_prd_mean =
+                                    xt::mean(q_prd, 0);
+                            return q_prd_mean;
+                        }
+                    };
+                    auto q = get_q();
+
                     // preprocess conditions to identify special cases
                     // within/without
                     bool within = false;
@@ -191,57 +233,57 @@ namespace evalhyd
                     if (within)
                     {
                         if ((opr1 == "<") and (opr2 == ">"))
-                            t_msk = xt::where((q_obs < val1) & (q_obs > val2),
+                            t_msk = xt::where((q < val1) & (q > val2),
                                               1, t_msk);
                         else if ((opr1 == "<=") and (opr2 == ">"))
-                            t_msk = xt::where((q_obs <= val1) & (q_obs > val2),
+                            t_msk = xt::where((q <= val1) & (q > val2),
                                               1, t_msk);
                         else if ((opr1 == "<") and (opr2 == ">="))
-                            t_msk = xt::where((q_obs < val1) & (q_obs >= val2),
+                            t_msk = xt::where((q < val1) & (q >= val2),
                                               1, t_msk);
                         else if ((opr1 == "<=") and (opr2 == ">="))
-                            t_msk = xt::where((q_obs <= val1) & (q_obs >= val2),
+                            t_msk = xt::where((q <= val1) & (q >= val2),
                                               1, t_msk);
 
                         if ((opr2 == "<") and (opr1 == ">"))
-                            t_msk = xt::where((q_obs < val2) & (q_obs > val1),
+                            t_msk = xt::where((q < val2) & (q > val1),
                                               1, t_msk);
                         else if ((opr2 == "<=") and (opr1 == ">"))
-                            t_msk = xt::where((q_obs <= val2) & (q_obs > val1),
+                            t_msk = xt::where((q <= val2) & (q > val1),
                                               1, t_msk);
                         else if ((opr2 == "<") and (opr1 == ">="))
-                            t_msk = xt::where((q_obs < val2) & (q_obs >= val1),
+                            t_msk = xt::where((q < val2) & (q >= val1),
                                               1, t_msk);
                         else if ((opr2 == "<=") and (opr1 == ">="))
-                            t_msk = xt::where((q_obs <= val2) & (q_obs >= val1),
+                            t_msk = xt::where((q <= val2) & (q >= val1),
                                               1, t_msk);
                     }
                     else if (without)
                     {
                         if ((opr1 == "<") and (opr2 == ">"))
-                            t_msk = xt::where((q_obs < val1) | (q_obs > val2),
+                            t_msk = xt::where((q < val1) | (q > val2),
                                               1, t_msk);
                         else if ((opr1 == "<=") and (opr2 == ">"))
-                            t_msk = xt::where((q_obs <= val1) | (q_obs > val2),
+                            t_msk = xt::where((q <= val1) | (q > val2),
                                               1, t_msk);
                         else if ((opr1 == "<") and (opr2 == ">="))
-                            t_msk = xt::where((q_obs < val1) | (q_obs >= val2),
+                            t_msk = xt::where((q < val1) | (q >= val2),
                                               1, t_msk);
                         else if ((opr1 == "<=") and (opr2 == ">="))
-                            t_msk = xt::where((q_obs <= val1) & (q_obs >= val2),
+                            t_msk = xt::where((q <= val1) & (q >= val2),
                                               1, t_msk);
 
                         if ((opr2 == "<") and (opr1 == ">"))
-                            t_msk = xt::where((q_obs < val2) | (q_obs > val1),
+                            t_msk = xt::where((q < val2) | (q > val1),
                                               1, t_msk);
                         else if ((opr2 == "<=") and (opr1 == ">"))
-                            t_msk = xt::where((q_obs <= val2) | (q_obs > val1),
+                            t_msk = xt::where((q <= val2) | (q > val1),
                                               1, t_msk);
                         else if ((opr2 == "<") and (opr1 == ">="))
-                            t_msk = xt::where((q_obs < val2) | (q_obs >= val1),
+                            t_msk = xt::where((q < val2) | (q >= val1),
                                               1, t_msk);
                         else if ((opr2 == "<=") and (opr1 == ">="))
-                            t_msk = xt::where((q_obs <= val2) | (q_obs >= val1),
+                            t_msk = xt::where((q <= val2) | (q >= val1),
                                               1, t_msk);
                     }
                     else
@@ -256,27 +298,27 @@ namespace evalhyd
                             // apply masking condition to given subset
                             if (opr == "<")
                                 t_msk = xt::where(
-                                        q_obs < val, 1, t_msk
+                                        q < val, 1, t_msk
                                 );
                             else if (opr == ">")
                                 t_msk = xt::where(
-                                        q_obs > val, 1, t_msk
+                                        q > val, 1, t_msk
                                 );
                             else if (opr == "<=")
                                 t_msk = xt::where(
-                                        q_obs <= val, 1, t_msk
+                                        q <= val, 1, t_msk
                                 );
                             else if (opr == ">=")
                                 t_msk = xt::where(
-                                        q_obs >= val, 1, t_msk
+                                        q >= val, 1, t_msk
                                 );
                             else if (opr == "==")
                                 t_msk = xt::where(
-                                        xt::equal(q_obs, val), 1, t_msk
+                                        xt::equal(q, val), 1, t_msk
                                 );
                             else if (opr == "!=")
                                 t_msk = xt::where(
-                                        xt::not_equal(q_obs, val), 1, t_msk
+                                        xt::not_equal(q, val), 1, t_msk
                                 );
                         }
                     }
diff --git a/src/probabilist/evaluator.h b/src/probabilist/evaluator.h
index c37d074..86d35bf 100644
--- a/src/probabilist/evaluator.h
+++ b/src/probabilist/evaluator.h
@@ -6,7 +6,7 @@
 #include <xtensor/xview.hpp>
 #include <xtensor/xslice.hpp>
 
-using view1d_xtensor2d_type = decltype(
+using view1d_xtensor2d_double_type = decltype(
     xt::view(
             std::declval<const xt::xtensor<double, 2>&>(),
             std::declval<int>(),
@@ -14,7 +14,7 @@ using view1d_xtensor2d_type = decltype(
     )
 );
 
-using view2d_xtensor4d_type = decltype(
+using view2d_xtensor4d_double_type = decltype(
     xt::view(
             std::declval<const xt::xtensor<double, 4>&>(),
             std::declval<int>(),
@@ -24,9 +24,10 @@ using view2d_xtensor4d_type = decltype(
     )
 );
 
-using view2d_xtensor3d_type = decltype(
+using view2d_xtensor4d_bool_type = decltype(
     xt::view(
-            std::declval<const xt::xtensor<bool, 3>&>(),
+            std::declval<const xt::xtensor<bool, 4>&>(),
+            std::declval<int>(),
             std::declval<int>(),
             xt::all(),
             xt::all()
@@ -41,9 +42,9 @@ namespace evalhyd
         {
         private:
             // members for input data
-            const view1d_xtensor2d_type& q_obs;
-            const view2d_xtensor4d_type& q_prd;
-            const view1d_xtensor2d_type& q_thr;
+            const view1d_xtensor2d_double_type& q_obs;
+            const view2d_xtensor4d_double_type& q_prd;
+            const view1d_xtensor2d_double_type& q_thr;
             xt::xtensor<bool, 2> t_msk;
 
             // members for dimensions
@@ -60,10 +61,10 @@ namespace evalhyd
 
         public:
             // constructor method
-            Evaluator(const view1d_xtensor2d_type& obs,
-                      const view2d_xtensor4d_type& prd,
-                      const view1d_xtensor2d_type& thr,
-                      const view2d_xtensor3d_type& msk) :
+            Evaluator(const view1d_xtensor2d_double_type& obs,
+                      const view2d_xtensor4d_double_type& prd,
+                      const view1d_xtensor2d_double_type& thr,
+                      const view2d_xtensor4d_bool_type& msk) :
                     q_obs{obs}, q_prd{prd}, q_thr{thr}, t_msk(msk)
             {
                 // initialise a mask if none provided
diff --git a/tests/test_determinist.cpp b/tests/test_determinist.cpp
index bf753fe..703dd7a 100644
--- a/tests/test_determinist.cpp
+++ b/tests/test_determinist.cpp
@@ -221,7 +221,7 @@ TEST(DeterministTests, TestMaskingConditions)
 
     // compute scores using masking conditions on streamflow to subset whole record
     xt::xtensor<std::array<char, 32>, 2> q_conditions ={
-            {{"q{<2000,>3000}"}}
+            {{"q_obs{<2000,>3000}"}}
     };
 
     std::vector<xt::xarray<double>> metrics_q_conditioned =
diff --git a/tests/test_probabilist.cpp b/tests/test_probabilist.cpp
index 5906cbc..c2c0497 100644
--- a/tests/test_probabilist.cpp
+++ b/tests/test_probabilist.cpp
@@ -132,10 +132,10 @@ TEST(ProbabilistTests, TestMasks)
     ifs.close();
 
     // generate temporal subset by dropping 20 first time steps
-    xt::xtensor<double, 3> masks =
-            xt::ones<bool>({std::size_t {1}, std::size_t {1},
+    xt::xtensor<double, 4> masks =
+            xt::ones<bool>({std::size_t {1}, std::size_t {1}, std::size_t {1},
                             std::size_t {observed.size()}});
-    xt::view(masks, 0, 0, xt::range(0, 20)) = 0;
+    xt::view(masks, 0, xt::all(), 0, xt::range(0, 20)) = 0;
 
     // compute scores using masks to subset whole record
     xt::xtensor<double, 2> thresholds = {{690, 534, 445}};
@@ -150,7 +150,7 @@ TEST(ProbabilistTests, TestMasks)
                     xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                     metrics,
                     thresholds,
-                    // shape: (subsets [1], time [t])
+                    // shape: (sites [1], lead times [1], subsets [1], time [t])
                     masks
             );
 
@@ -191,13 +191,13 @@ TEST(ProbabilistTests, TestMaskingConditions)
     ifs.close();
 
     // generate dummy empty masks required to access next optional argument
-    xt::xtensor<bool, 3> masks;
+    xt::xtensor<bool, 4> masks;
 
     // conditions on streamflow ________________________________________________
 
     // compute scores using masking conditions on streamflow to subset whole record
     xt::xtensor<std::array<char, 32>, 2> q_conditions = {
-            {{"q{<2000,>3000}"}}
+            {{"q_obs{<2000,>3000}"}}
     };
 
     std::vector<xt::xarray<double>> metrics_q_conditioned =
-- 
GitLab