diff --git a/changelog.rst b/changelog.rst index b043a7dc45afebdb50205cd62f89174ca1c26d12..c5c7235cb566e0740fc53894857465679ac1b2d6 100644 --- a/changelog.rst +++ b/changelog.rst @@ -12,6 +12,7 @@ Yet to be versioned and released. Only available from *dev* branch until then. * add `"KGENP"` and `"KGENP_D"` as deterministic evaluation metrics since stable sorting is now available in ``xtensor`` (`CPP#5 <https://gitlab.irstea.fr/HYCAR-Hydro/evalhyd/evalhyd-cpp/-/issues/5>`_) +* add `"CONT_TBL"` as probabilistic evaluation metric .. rubric:: Bug fixes diff --git a/include/evalhyd/detail/probabilist/contingency.hpp b/include/evalhyd/detail/probabilist/contingency.hpp index a4685e4aec83615ebad1c612d3fbcdb72b42f97b..0960708455f7a0fa409f41187031fb74c3389f3f 100644 --- a/include/evalhyd/detail/probabilist/contingency.hpp +++ b/include/evalhyd/detail/probabilist/contingency.hpp @@ -297,6 +297,84 @@ namespace evalhyd } } + /// Compute the contingency table (CONT_TBL), i.e. 'hits', + /// 'false alarms', 'misses', 'correct rejections', in this order. + /// + /// \param ct_a + /// Hits. + /// shape: (sites, lead times, levels, thresholds, time) + /// \param ct_b + /// False alarms. + /// shape: (sites, lead times, levels, thresholds, time) + /// \param ct_c + /// Misses. + /// shape: (sites, lead times, levels, thresholds, time) + /// \param ct_d + /// Correct rejections. + /// shape: (sites, lead times, levels, thresholds, time) + /// \param q_thr + /// Streamflow exceedance threshold(s). + /// shape: (sites, thresholds) + /// \param t_msk + /// Temporal subsets of the whole record. + /// shape: (sites, lead times, subsets, time) + /// \param b_exp + /// Boostrap samples. + /// shape: (samples, time slice) + /// \param n_sit + /// Number of sites. + /// \param n_ldt + /// Number of lead times. + /// \param n_thr + /// Number of thresholds. + /// \param n_mbr + /// Number of ensemble members. + /// \param n_msk + /// Number of temporal subsets. + /// \param n_exp + /// Number of bootstrap samples. + /// \return + /// Contingency table. + /// shape: (sites, lead times, subsets, samples, levels, thresholds, cells) + template <class XD2> + inline xt::xtensor<double, 7> calc_CONT_TBL( + const xt::xtensor<double, 5>& ct_a, + const xt::xtensor<double, 5>& ct_b, + const xt::xtensor<double, 5>& ct_c, + const xt::xtensor<double, 5>& ct_d, + const XD2& q_thr, + const xt::xtensor<bool, 4>& t_msk, + const std::vector<xt::xkeep_slice<int>>& b_exp, + std::size_t n_sit, + std::size_t n_ldt, + std::size_t n_thr, + std::size_t n_mbr, + std::size_t n_msk, + std::size_t n_exp + ) + { + // initialise output variable + xt::xtensor<double, 7> CONT_TBL = + xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp, + n_mbr + 1, n_thr, std::size_t {4}}); + + // compute table one cell at a time + std::size_t i = 0; + for (auto cell: {ct_a, ct_b, ct_c, ct_d}) + { + xt::view(CONT_TBL, xt::all(), xt::all(), xt::all(), + xt::all(), xt::all(), xt::all(), i) = + detail::calc_METRIC_from_metric( + cell, q_thr, t_msk, b_exp, + n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp + ); + + i++; + } + + return CONT_TBL; + } + /// Compute the probability of detection (POD), /// also known as 'hit rate'. /// diff --git a/include/evalhyd/detail/probabilist/evaluator.hpp b/include/evalhyd/detail/probabilist/evaluator.hpp index 976145c58817f529c2786a246665c7eac55c5eb9..3f94fd6d85dcdd94fa5da2224bad5dc914a8397f 100644 --- a/include/evalhyd/detail/probabilist/evaluator.hpp +++ b/include/evalhyd/detail/probabilist/evaluator.hpp @@ -110,6 +110,7 @@ namespace evalhyd xtl::xoptional<xt::xtensor<double, 5>, bool> QS; xtl::xoptional<xt::xtensor<double, 4>, bool> CRPS_FROM_QS; // > Contingency table-based + xtl::xoptional<xt::xtensor<double, 7>, bool> CONT_TBL; xtl::xoptional<xt::xtensor<double, 6>, bool> POD; xtl::xoptional<xt::xtensor<double, 6>, bool> POFD; xtl::xoptional<xt::xtensor<double, 6>, bool> FAR; @@ -651,6 +652,19 @@ namespace evalhyd return CRPS_FROM_QS.value(); }; + xt::xtensor<double, 7> get_CONT_TBL() + { + if (!CONT_TBL.has_value()) + { + CONT_TBL = metrics::calc_CONT_TBL( + get_ct_a(), get_ct_b(), get_ct_c(), get_ct_d(), + get_q_thr(), t_msk, b_exp, + n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp + ); + } + return CONT_TBL.value(); + }; + xt::xtensor<double, 6> get_POD() { if (!POD.has_value()) diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 3b5f831d26cb72966f325f9ba5a8468778b680c9..f7f4340f210a2027d49d30d6fcf0a8505c0e4110 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -240,7 +240,7 @@ namespace evalhyd {"BS", "BSS", "BS_CRD", "BS_LBD", "REL_DIAG", "CRPS_FROM_BS", "CRPS_FROM_ECDF", "QS", "CRPS_FROM_QS", - "POD", "POFD", "FAR", "CSI", "ROCSS", + "CONT_TBL", "POD", "POFD", "FAR", "CSI", "ROCSS", "RANK_HIST", "DS", "AS", "CR", "AW", "AWN", "WS", "ES"} @@ -483,6 +483,12 @@ namespace evalhyd uncertainty::summarise_p(evaluator.get_CRPS_FROM_QS(), summary) ); } + else if ( metric == "CONT_TBL" ) + { + r.emplace_back( + uncertainty::summarise_p(evaluator.get_CONT_TBL(), summary) + ); + } else if ( metric == "POD" ) { r.emplace_back( diff --git a/tests/expected/evalp/CONT_TBL.csv b/tests/expected/evalp/CONT_TBL.csv new file mode 100644 index 0000000000000000000000000000000000000000..3740d4a3e7accc3423bed038d2c8c022cdf5ba36 --- /dev/null +++ b/tests/expected/evalp/CONT_TBL.csv @@ -0,0 +1,208 @@ +0.446945,0.553055,0.,0. +0.33119,0.66881,0.,0. +0.273312,0.726688,0.,0. +nan,nan,nan,nan +0.385852,0.048232,0.061093,0.504823 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.048232,0.061093,0.504823 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.048232,0.061093,0.504823 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.048232,0.061093,0.504823 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.385852,0.045016,0.061093,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.379421,0.045016,0.067524,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.379421,0.045016,0.067524,0.508039 +0.289389,0.025723,0.041801,0.643087 +0.205788,0.019293,0.067524,0.707395 +nan,nan,nan,nan +0.379421,0.045016,0.067524,0.508039 +0.286174,0.025723,0.045016,0.643087 +0.205788,0.016077,0.067524,0.710611 +nan,nan,nan,nan diff --git a/tests/test_probabilist.cpp b/tests/test_probabilist.cpp index 05061b147f3fd49ea358213818d10c31473031d3..bdcd175333f978d8e58ee2c87a581b2b0916441a 100644 --- a/tests/test_probabilist.cpp +++ b/tests/test_probabilist.cpp @@ -33,7 +33,7 @@ std::vector<std::string> all_metrics_p = { "BS", "BSS", "BS_CRD", "BS_LBD", "REL_DIAG", "CRPS_FROM_BS", "CRPS_FROM_ECDF", "QS", "CRPS_FROM_QS", - "POD", "POFD", "FAR", "CSI", "ROCSS", + "CONT_TBL", "POD", "POFD", "FAR", "CSI", "ROCSS", "RANK_HIST", "DS", "AS", "CR", "AW", "AWN", "WS", "ES" @@ -192,7 +192,7 @@ TEST(ProbabilistTests, TestContingency) // compute scores xt::xtensor<double, 2> thresholds = {{690, 534, 445, NAN}}; - std::vector<std::string> metrics = {"POD", "POFD", "FAR", "CSI", "ROCSS"}; + std::vector<std::string> metrics = {"CONT_TBL", "POD", "POFD", "FAR", "CSI", "ROCSS"}; std::vector<xt::xarray<double>> results = evalhyd::evalp( @@ -208,8 +208,18 @@ TEST(ProbabilistTests, TestContingency) // check results for (std::size_t m = 0; m < metrics.size(); m++) { + if (metrics[m] == "CONT_TBL") + { + // /!\ stacked-up thresholds and cells in CSV file because 7D metric, + // so need to resize array accordingly + expected[metrics[m]].resize( + {std::size_t {1}, std::size_t {1}, std::size_t {1}, std::size_t {1}, + predicted.shape(0) + 1, thresholds.shape(1), std::size_t {4}} + ); + } + EXPECT_TRUE(xt::all(xt::isclose( - results[m], expected[metrics[m]], 1e-05, 1e-08, true + results[m], expected[metrics[m]], 1e-04, 1e-07, true ))) << "Failure for (" << metrics[m] << ")"; } }