From ad3bdc97af2be084456b318c0da7534626b968a3 Mon Sep 17 00:00:00 2001 From: Thibault Hallouin <thibault.hallouin@inrae.fr> Date: Thu, 12 Jan 2023 11:02:30 +0100 Subject: [PATCH] fix dimension mismatches in contingency table-based metrics --- .../evalhyd/detail/probabilist/contingency.hpp | 15 +++++++-------- include/evalhyd/detail/probabilist/evaluator.hpp | 16 ++++++++-------- include/evalhyd/evalp.hpp | 8 ++++---- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/include/evalhyd/detail/probabilist/contingency.hpp b/include/evalhyd/detail/probabilist/contingency.hpp index 7934602..55c17bc 100644 --- a/include/evalhyd/detail/probabilist/contingency.hpp +++ b/include/evalhyd/detail/probabilist/contingency.hpp @@ -45,7 +45,7 @@ namespace evalhyd // \return // Alerts based on forecast. // shape: (levels, thresholds, time) - inline xt::xtensor<double, 2> calc_a_k( + inline xt::xtensor<double, 3> calc_a_k( const xt::xtensor<double, 2>& sum_f_k, std::size_t n_mbr ) @@ -73,7 +73,7 @@ namespace evalhyd // shape: (levels, thresholds, time) inline xt::xtensor<double, 3> calc_ct_a( const xt::xtensor<double, 2>& o_k, - const xt::xtensor<double, 2>& a_k + const xt::xtensor<double, 3>& a_k ) { return xt::equal(o_k, 1.) && xt::equal(a_k, 1.); @@ -92,7 +92,7 @@ namespace evalhyd // shape: (levels, thresholds, time) inline xt::xtensor<double, 3> calc_ct_b( const xt::xtensor<double, 2>& o_k, - const xt::xtensor<double, 2>& a_k + const xt::xtensor<double, 3>& a_k ) { return xt::equal(o_k, 0.) && xt::equal(a_k, 1.); @@ -111,7 +111,7 @@ namespace evalhyd // shape: (levels, thresholds, time) inline xt::xtensor<double, 3> calc_ct_c( const xt::xtensor<double, 2>& o_k, - const xt::xtensor<double, 2>& a_k + const xt::xtensor<double, 3>& a_k ) { return xt::equal(o_k, 1.) && xt::equal(a_k, 0.); @@ -130,7 +130,7 @@ namespace evalhyd // shape: (levels, thresholds, time) inline xt::xtensor<double, 3> calc_ct_d( const xt::xtensor<double, 2>& o_k, - const xt::xtensor<double, 2>& a_k + const xt::xtensor<double, 3>& a_k ) { return xt::equal(o_k, 0.) && xt::equal(a_k, 0.); @@ -241,9 +241,8 @@ namespace evalhyd ) { // initialise output variable - // shape: (subsets, thresholds) xt::xtensor<double, 4> METRIC = - xt::zeros<double>({n_msk, n_exp, n_mbr, n_thr}); + xt::zeros<double>({n_msk, n_exp, n_mbr + 1, n_thr}); // compute variable one mask at a time to minimise memory imprint for (std::size_t m = 0; m < n_msk; m++) @@ -475,7 +474,7 @@ namespace evalhyd // \return // Critical success indices. // shape: (subsets, samples, levels, thresholds) - inline xt::xtensor<double, 4> calc_ROCSS( + inline xt::xtensor<double, 3> calc_ROCSS( const xt::xtensor<double, 4>& POD, const xt::xtensor<double, 4>& POFD ) diff --git a/include/evalhyd/detail/probabilist/evaluator.hpp b/include/evalhyd/detail/probabilist/evaluator.hpp index e858093..d4d4c15 100644 --- a/include/evalhyd/detail/probabilist/evaluator.hpp +++ b/include/evalhyd/detail/probabilist/evaluator.hpp @@ -173,7 +173,7 @@ namespace evalhyd return a_k.value(); }; - xt::xtensor<double, 2> get_ct_a() + xt::xtensor<double, 3> get_ct_a() { if (!ct_a.has_value()) { @@ -184,7 +184,7 @@ namespace evalhyd return ct_a.value(); }; - xt::xtensor<double, 2> get_ct_b() + xt::xtensor<double, 3> get_ct_b() { if (!ct_b.has_value()) { @@ -195,7 +195,7 @@ namespace evalhyd return ct_b.value(); }; - xt::xtensor<double, 2> get_ct_c() + xt::xtensor<double, 3> get_ct_c() { if (!ct_c.has_value()) { @@ -206,7 +206,7 @@ namespace evalhyd return ct_c.value(); }; - xt::xtensor<double, 2> get_ct_d() + xt::xtensor<double, 3> get_ct_d() { if (!ct_d.has_value()) { @@ -251,7 +251,7 @@ namespace evalhyd return crps.value(); }; - xt::xtensor<double, 4> get_pod() + xt::xtensor<double, 3> get_pod() { if (!pod.has_value()) { @@ -262,7 +262,7 @@ namespace evalhyd return pod.value(); }; - xt::xtensor<double, 4> get_pofd() + xt::xtensor<double, 3> get_pofd() { if (!pofd.has_value()) { @@ -273,7 +273,7 @@ namespace evalhyd return pofd.value(); }; - xt::xtensor<double, 4> get_far() + xt::xtensor<double, 3> get_far() { if (!far.has_value()) { @@ -284,7 +284,7 @@ namespace evalhyd return far.value(); }; - xt::xtensor<double, 4> get_csi() + xt::xtensor<double, 3> get_csi() { if (!csi.has_value()) { diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 97a2264..e56e5b1 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -288,10 +288,10 @@ namespace evalhyd dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3}; dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr}; dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp}; - dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; - dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; - dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; - dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; + dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr}; + dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr}; + dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr}; + dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr}; dim["ROCSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr}; // generate masks from conditions if provided -- GitLab