From 6f17c3e0f32f935c0a4f34dc1d3c6a9ee780497b Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Tue, 26 Dec 2023 17:55:06 +0100
Subject: [PATCH] add contingency table as probabilistic evaluation metric

---
 changelog.rst                                 |   1 +
 .../detail/probabilist/contingency.hpp        |  78 +++++++
 .../evalhyd/detail/probabilist/evaluator.hpp  |  14 ++
 include/evalhyd/evalp.hpp                     |   8 +-
 tests/expected/evalp/CONT_TBL.csv             | 208 ++++++++++++++++++
 tests/test_probabilist.cpp                    |  16 +-
 6 files changed, 321 insertions(+), 4 deletions(-)
 create mode 100644 tests/expected/evalp/CONT_TBL.csv

diff --git a/changelog.rst b/changelog.rst
index b043a7d..c5c7235 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -12,6 +12,7 @@ Yet to be versioned and released. Only available from *dev* branch until then.
 * add `"KGENP"` and `"KGENP_D"` as deterministic evaluation metrics since
   stable sorting is now available in ``xtensor``
   (`CPP#5 <https://gitlab.irstea.fr/HYCAR-Hydro/evalhyd/evalhyd-cpp/-/issues/5>`_)
+* add `"CONT_TBL"` as probabilistic evaluation metric
 
 .. rubric:: Bug fixes
 
diff --git a/include/evalhyd/detail/probabilist/contingency.hpp b/include/evalhyd/detail/probabilist/contingency.hpp
index a4685e4..0960708 100644
--- a/include/evalhyd/detail/probabilist/contingency.hpp
+++ b/include/evalhyd/detail/probabilist/contingency.hpp
@@ -297,6 +297,84 @@ namespace evalhyd
                 }
             }
 
+            /// Compute the contingency table (CONT_TBL), i.e. 'hits',
+            /// 'false alarms', 'misses', 'correct rejections', in this order.
+            ///
+            /// \param ct_a
+            ///     Hits.
+            ///     shape: (sites, lead times, levels, thresholds, time)
+            /// \param ct_b
+            ///     False alarms.
+            ///     shape: (sites, lead times, levels, thresholds, time)
+            /// \param ct_c
+            ///     Misses.
+            ///     shape: (sites, lead times, levels, thresholds, time)
+            /// \param ct_d
+            ///     Correct rejections.
+            ///     shape: (sites, lead times, levels, thresholds, time)
+            /// \param q_thr
+            ///     Streamflow exceedance threshold(s).
+            ///     shape: (sites, thresholds)
+            /// \param t_msk
+            ///     Temporal subsets of the whole record.
+            ///     shape: (sites, lead times, subsets, time)
+            /// \param b_exp
+            ///     Boostrap samples.
+            ///     shape: (samples, time slice)
+            /// \param n_sit
+            ///     Number of sites.
+            /// \param n_ldt
+            ///     Number of lead times.
+            /// \param n_thr
+            ///     Number of thresholds.
+            /// \param n_mbr
+            ///     Number of ensemble members.
+            /// \param n_msk
+            ///     Number of temporal subsets.
+            /// \param n_exp
+            ///     Number of bootstrap samples.
+            /// \return
+            ///     Contingency table.
+            ///     shape: (sites, lead times, subsets, samples, levels, thresholds, cells)
+            template <class XD2>
+            inline xt::xtensor<double, 7> calc_CONT_TBL(
+                    const xt::xtensor<double, 5>& ct_a,
+                    const xt::xtensor<double, 5>& ct_b,
+                    const xt::xtensor<double, 5>& ct_c,
+                    const xt::xtensor<double, 5>& ct_d,
+                    const XD2& q_thr,
+                    const xt::xtensor<bool, 4>& t_msk,
+                    const std::vector<xt::xkeep_slice<int>>& b_exp,
+                    std::size_t n_sit,
+                    std::size_t n_ldt,
+                    std::size_t n_thr,
+                    std::size_t n_mbr,
+                    std::size_t n_msk,
+                    std::size_t n_exp
+            )
+            {
+                // initialise output variable
+                xt::xtensor<double, 7> CONT_TBL =
+                        xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp,
+                                           n_mbr + 1, n_thr, std::size_t {4}});
+
+                // compute table one cell at a time
+                std::size_t i = 0;
+                for (auto cell: {ct_a, ct_b, ct_c, ct_d})
+                {
+                    xt::view(CONT_TBL, xt::all(), xt::all(), xt::all(),
+                             xt::all(), xt::all(), xt::all(), i) =
+                        detail::calc_METRIC_from_metric(
+                                cell, q_thr, t_msk, b_exp,
+                                n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
+                        );
+
+                    i++;
+                }
+
+                return CONT_TBL;
+            }
+
             /// Compute the probability of detection (POD),
             /// also known as 'hit rate'.
             ///
diff --git a/include/evalhyd/detail/probabilist/evaluator.hpp b/include/evalhyd/detail/probabilist/evaluator.hpp
index 976145c..3f94fd6 100644
--- a/include/evalhyd/detail/probabilist/evaluator.hpp
+++ b/include/evalhyd/detail/probabilist/evaluator.hpp
@@ -110,6 +110,7 @@ namespace evalhyd
             xtl::xoptional<xt::xtensor<double, 5>, bool> QS;
             xtl::xoptional<xt::xtensor<double, 4>, bool> CRPS_FROM_QS;
             // > Contingency table-based
+            xtl::xoptional<xt::xtensor<double, 7>, bool> CONT_TBL;
             xtl::xoptional<xt::xtensor<double, 6>, bool> POD;
             xtl::xoptional<xt::xtensor<double, 6>, bool> POFD;
             xtl::xoptional<xt::xtensor<double, 6>, bool> FAR;
@@ -651,6 +652,19 @@ namespace evalhyd
                 return CRPS_FROM_QS.value();
             };
 
+            xt::xtensor<double, 7> get_CONT_TBL()
+            {
+                if (!CONT_TBL.has_value())
+                {
+                    CONT_TBL = metrics::calc_CONT_TBL(
+                            get_ct_a(), get_ct_b(), get_ct_c(), get_ct_d(),
+                            get_q_thr(), t_msk, b_exp,
+                            n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
+                    );
+                }
+                return CONT_TBL.value();
+            };
+
             xt::xtensor<double, 6> get_POD()
             {
                 if (!POD.has_value())
diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp
index 3b5f831..f7f4340 100644
--- a/include/evalhyd/evalp.hpp
+++ b/include/evalhyd/evalp.hpp
@@ -240,7 +240,7 @@ namespace evalhyd
                 {"BS", "BSS", "BS_CRD", "BS_LBD", "REL_DIAG", "CRPS_FROM_BS",
                  "CRPS_FROM_ECDF",
                  "QS", "CRPS_FROM_QS",
-                 "POD", "POFD", "FAR", "CSI", "ROCSS",
+                 "CONT_TBL", "POD", "POFD", "FAR", "CSI", "ROCSS",
                  "RANK_HIST", "DS", "AS",
                  "CR", "AW", "AWN", "WS",
                  "ES"}
@@ -483,6 +483,12 @@ namespace evalhyd
                         uncertainty::summarise_p(evaluator.get_CRPS_FROM_QS(), summary)
                 );
             }
+            else if ( metric == "CONT_TBL" )
+            {
+                r.emplace_back(
+                        uncertainty::summarise_p(evaluator.get_CONT_TBL(), summary)
+                );
+            }
             else if ( metric == "POD" )
             {
                 r.emplace_back(
diff --git a/tests/expected/evalp/CONT_TBL.csv b/tests/expected/evalp/CONT_TBL.csv
new file mode 100644
index 0000000..3740d4a
--- /dev/null
+++ b/tests/expected/evalp/CONT_TBL.csv
@@ -0,0 +1,208 @@
+0.446945,0.553055,0.,0.
+0.33119,0.66881,0.,0.
+0.273312,0.726688,0.,0.
+nan,nan,nan,nan
+0.385852,0.048232,0.061093,0.504823
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.048232,0.061093,0.504823
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.048232,0.061093,0.504823
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.048232,0.061093,0.504823
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.385852,0.045016,0.061093,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.379421,0.045016,0.067524,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.379421,0.045016,0.067524,0.508039
+0.289389,0.025723,0.041801,0.643087
+0.205788,0.019293,0.067524,0.707395
+nan,nan,nan,nan
+0.379421,0.045016,0.067524,0.508039
+0.286174,0.025723,0.045016,0.643087
+0.205788,0.016077,0.067524,0.710611
+nan,nan,nan,nan
diff --git a/tests/test_probabilist.cpp b/tests/test_probabilist.cpp
index 05061b1..bdcd175 100644
--- a/tests/test_probabilist.cpp
+++ b/tests/test_probabilist.cpp
@@ -33,7 +33,7 @@ std::vector<std::string> all_metrics_p = {
         "BS", "BSS", "BS_CRD", "BS_LBD", "REL_DIAG", "CRPS_FROM_BS",
         "CRPS_FROM_ECDF",
         "QS", "CRPS_FROM_QS",
-        "POD", "POFD", "FAR", "CSI", "ROCSS",
+        "CONT_TBL", "POD", "POFD", "FAR", "CSI", "ROCSS",
         "RANK_HIST", "DS", "AS",
         "CR", "AW", "AWN", "WS",
         "ES"
@@ -192,7 +192,7 @@ TEST(ProbabilistTests, TestContingency)
 
     // compute scores
     xt::xtensor<double, 2> thresholds = {{690, 534, 445, NAN}};
-    std::vector<std::string> metrics = {"POD", "POFD", "FAR", "CSI", "ROCSS"};
+    std::vector<std::string> metrics = {"CONT_TBL", "POD", "POFD", "FAR", "CSI", "ROCSS"};
 
     std::vector<xt::xarray<double>> results =
             evalhyd::evalp(
@@ -208,8 +208,18 @@ TEST(ProbabilistTests, TestContingency)
     // check results
     for (std::size_t m = 0; m < metrics.size(); m++)
     {
+        if (metrics[m] == "CONT_TBL")
+        {
+            // /!\ stacked-up thresholds and cells in CSV file because 7D metric,
+            //     so need to resize array accordingly
+            expected[metrics[m]].resize(
+                {std::size_t {1}, std::size_t {1}, std::size_t {1}, std::size_t {1},
+                 predicted.shape(0) + 1, thresholds.shape(1), std::size_t {4}}
+            );
+        }
+
         EXPECT_TRUE(xt::all(xt::isclose(
-                results[m], expected[metrics[m]], 1e-05, 1e-08, true
+                results[m], expected[metrics[m]], 1e-04, 1e-07, true
         ))) << "Failure for (" << metrics[m] << ")";
     }
 }
-- 
GitLab