diff --git a/include/evalhyd/detail/probabilist/brier.hpp b/include/evalhyd/detail/probabilist/brier.hpp index fec0f092903bd9004366f11e1205b42456532628..338c08bfc8184208c1d698da967c73da9eac60f8 100644 --- a/include/evalhyd/detail/probabilist/brier.hpp +++ b/include/evalhyd/detail/probabilist/brier.hpp @@ -87,7 +87,7 @@ namespace evalhyd return bar_o; } - // Determine forecast probability of threshold(s) exceedance to occur. + // Determine number of forecast members exceeding threshold(s) // // \param q_prd // Streamflow predictions. @@ -95,30 +95,42 @@ namespace evalhyd // \param q_thr // Streamflow exceedance threshold(s). // shape: (thresholds,) - // \param n_mbr - // Number of ensemble members. // \return - // Event probability forecast. + // Number of forecast members exceeding threshold(s). // shape: (thresholds, time) template<class XV2D4, class XV1D2> - inline xt::xtensor<double, 2> calc_y_k( + inline xt::xtensor<double, 2> calc_sum_f_k( const XV2D4& q_prd, - const XV1D2& q_thr, - std::size_t n_mbr + const XV1D2& q_thr ) { // determine if members have exceeded threshold(s) - auto e_frc = - q_prd >= xt::view(q_thr, xt::all(), - xt::newaxis(), xt::newaxis()); + auto f_k = q_prd >= + xt::view(q_thr, xt::all(), xt::newaxis(), xt::newaxis()); // calculate how many members have exceeded threshold(s) - auto n_frc = xt::sum(e_frc, 1); + return xt::sum(f_k, 1); + } + // Determine forecast probability of threshold(s) exceedance to occur. + // + // \param sum_f_k + // Number of forecast members exceeding threshold(s). + // shape: (thresholds,) + // \param n_mbr + // Number of ensemble members. + // \return + // Event probability forecast. + // shape: (thresholds, time) + inline xt::xtensor<double, 2> calc_y_k( + const xt::xtensor<double, 2>& sum_f_k, + std::size_t n_mbr + ) + { // determine probability of threshold(s) exceedance // /!\ probability calculation dividing by n (the number of // members), not n+1 (the number of ranks) like in other metrics - return xt::cast<double>(n_frc) / n_mbr; + return xt::cast<double>(sum_f_k) / n_mbr; } } diff --git a/include/evalhyd/detail/probabilist/contingency.hpp b/include/evalhyd/detail/probabilist/contingency.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6ef68ff7f9fd0e2d40415e604c27162dd7073347 --- /dev/null +++ b/include/evalhyd/detail/probabilist/contingency.hpp @@ -0,0 +1,492 @@ +#ifndef EVALHYD_PROBABILIST_CONTINGENCY_HPP +#define EVALHYD_PROBABILIST_CONTINGENCY_HPP + +#include <xtensor/xtensor.hpp> +#include <xtensor/xview.hpp> +#include <xtensor/xmasked_view.hpp> +#include <xtensor/xmath.hpp> + +// NOTE ------------------------------------------------------------------------ +// All equations in metrics below are following notations from +// "Wilks, D. S. (2011). Statistical methods in the atmospheric sciences. +// Amsterdam; Boston: Elsevier Academic Press. ISBN: 9780123850225". +// In particular, pp. 302-303, 332-333. +// ----------------------------------------------------------------------------- + + +namespace evalhyd +{ + namespace probabilist + { + namespace elements + { + // Contingency table: + // + // OBS + // Y N + // +-----+-----+ a: hits + // Y | a | b | b: false alarms + // PRD +-----+-----+ c: misses + // N | c | d | d: correct rejections + // +-----+-----+ + // + + // Determine alerts based on forecast. + // + // \param sum_f_k + // Observed event outcome. + // shape: (thresholds, time) + // \param n_mbr + // Number of ensemble members. + // \return + // Alerts based on forecast. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 2> calc_a_k( + const xt::xtensor<double, 2>& sum_f_k, + std::size_t n_mbr + ) + { + // compute range of alert levels $alert_lvl$ + // (i.e. number of members that must forecast event + // for alert to be raised) + auto alert_lvl = xt::arange<double>(double(n_mbr + 1)); + + // determine whether forecast yield alert + return sum_f_k >= + xt::view(alert_lvl, xt::all(), xt::newaxis(), xt::newaxis()); + } + + // Determine hits ('a' in contingency table). + // + // \param o_k + // Observed event outcome. + // shape: (thresholds, time) + // \param a_k + // Alerts based on forecast. + // shape: (levels, thresholds, time) + // \return + // Hits. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_ct_a( + const xt::xtensor<double, 2>& o_k, + const xt::xtensor<double, 2>& a_k + ) + { + return xt::equal(o_k, 1.) && xt::equal(a_k, 1.); + } + + // Determine false alarms ('b' in contingency table). + // + // \param o_k + // Observed event outcome. + // shape: (thresholds, time) + // \param y_k + // Event probability forecast. + // shape: (thresholds, time) + // \return + // False alarms. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_ct_b( + const xt::xtensor<double, 2>& o_k, + const xt::xtensor<double, 2>& a_k + ) + { + return xt::equal(o_k, 0.) && xt::equal(a_k, 1.); + } + + // Determine misses ('c' in contingency table). + // + // \param o_k + // Observed event outcome. + // shape: (thresholds, time) + // \param y_k + // Event probability forecast. + // shape: (thresholds, time) + // \return + // Misses. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_ct_c( + const xt::xtensor<double, 2>& o_k, + const xt::xtensor<double, 2>& a_k + ) + { + return xt::equal(o_k, 1.) && xt::equal(a_k, 0.); + } + + // Determine correct rejections ('d' in contingency table). + // + // \param o_k + // Observed event outcome. + // shape: (thresholds, time) + // \param y_k + // Event probability forecast. + // shape: (thresholds, time) + // \return + // Correct rejections. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_ct_d( + const xt::xtensor<double, 2>& o_k, + const xt::xtensor<double, 2>& a_k + ) + { + return xt::equal(o_k, 0.) && xt::equal(a_k, 0.); + } + } + + namespace intermediate + { + // Compute the probability of detection for each time step. + // + // \param ct_a + // Hits. + // shape: (levels, thresholds, time) + // \param ct_c + // Misses. + // shape: (levels, thresholds, time) + // \return + // Probability of detection for each time step. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_pod( + const xt::xtensor<double, 3>& ct_a, + const xt::xtensor<double, 3>& ct_c + ) + { + return ct_a / (ct_a + ct_c); + } + + // Compute the probability of false detection for each time step. + // + // \param ct_b + // False alarms. + // shape: (levels, thresholds, time) + // \param ct_d + // Correct rejections. + // shape: (levels, thresholds, time) + // \return + // Probability of false detection for each time step. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_pofd( + const xt::xtensor<double, 3>& ct_b, + const xt::xtensor<double, 3>& ct_d + ) + { + return ct_b / (ct_b + ct_d); + } + + // Compute the false alarm ratio for each time step. + // + // \param ct_a + // Hits. + // shape: (levels, thresholds, time) + // \param ct_b + // False alarms. + // shape: (levels, thresholds, time) + // \return + // False alarm ratio for each time step. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_far( + const xt::xtensor<double, 3>& ct_a, + const xt::xtensor<double, 3>& ct_b + ) + { + return ct_b / (ct_a + ct_b); + } + + // Compute the critical success index for each time step. + // + // \param ct_a + // Hits. + // shape: (levels, thresholds, time) + // \param ct_b + // False alarms. + // shape: (levels, thresholds, time) + // \param ct_c + // Misses. + // shape: (levels, thresholds, time) + // \return + // Critical success index for each time step. + // shape: (levels, thresholds, time) + inline xt::xtensor<double, 3> calc_csi( + const xt::xtensor<double, 3>& ct_a, + const xt::xtensor<double, 3>& ct_b, + const xt::xtensor<double, 3>& ct_c + ) + { + return ct_b / (ct_a + ct_b + ct_c); + } + } + + namespace metrics + { + // ----------------------------------------------------------------- + // Accuracy + // ----------------------------------------------------------------- + + namespace detail + { + template <class XV1D2> + inline xt::xtensor<double, 4> calc_METRIC_from_metric( + const xt::xtensor<double, 3>& metric, + const XV1D2& q_thr, + const xt::xtensor<bool, 2>& t_msk, + const std::vector<xt::xkeep_slice<int>>& b_exp, + std::size_t n_thr, + std::size_t n_mbr, + std::size_t n_msk, + std::size_t n_exp + ) + { + // initialise output variable + // shape: (subsets, thresholds) + xt::xtensor<double, 4> METRIC = + xt::zeros<double>({n_msk, n_exp, n_mbr, n_thr}); + + // compute variable one mask at a time to minimise memory imprint + for (std::size_t m = 0; m < n_msk; m++) + { + // apply the mask + // (using NaN workaround until reducers work on masked_view) + auto metric_masked = xt::where(xt::row(t_msk, m), metric, NAN); + + // compute variable one sample at a time + for (std::size_t e = 0; e < n_exp; e++) + { + // apply the bootstrap sampling + auto metric_masked_sampled = + xt::view(metric_masked, xt::all(), xt::all(), b_exp[e]); + + // calculate the mean over the time steps + xt::view(METRIC, m, e, xt::all(), xt::all()) = + xt::nanmean(metric_masked_sampled, -1); + } + } + + // assign NaN where thresholds were not provided (i.e. set as NaN) + xt::masked_view( + METRIC, + xt::isnan(xt::view(q_thr, + xt::newaxis(), xt::newaxis(), + xt::newaxis(), xt::all())) + ) = NAN; + + return METRIC; + } + } + + // Compute the probability of detection (POD), + // also known as 'hit rate'. + // + // \param pod + // Probability of detection for each time step. + // shape: (levels, thresholds, time) + // \param q_thr + // Streamflow exceedance threshold(s). + // shape: (thresholds,) + // \param t_msk + // Temporal subsets of the whole record. + // shape: (subsets, time) + // \param b_exp + // Boostrap samples. + // shape: (samples, time slice) + // \param n_thr + // Number of thresholds. + // \param n_mbr + // Number of ensemble members. + // \param n_msk + // Number of temporal subsets. + // \param n_exp + // Number of bootstrap samples. + // \return + // Probabilities of detection. + // shape: (subsets, samples, levels, thresholds) + template <class XV1D2> + inline xt::xtensor<double, 4> calc_POD( + const xt::xtensor<double, 3>& pod, + const XV1D2& q_thr, + const xt::xtensor<bool, 2>& t_msk, + const std::vector<xt::xkeep_slice<int>>& b_exp, + std::size_t n_thr, + std::size_t n_mbr, + std::size_t n_msk, + std::size_t n_exp + ) + { + return detail::calc_METRIC_from_metric( + pod, q_thr, t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp + ); + } + + // Compute the probability of detection (POFD), + // also known as 'false alarm rate'. + // + // \param pofd + // Probability of false detection for each time step. + // shape: (levels, thresholds, time) + // \param q_thr + // Streamflow exceedance threshold(s). + // shape: (thresholds,) + // \param t_msk + // Temporal subsets of the whole record. + // shape: (subsets, time) + // \param b_exp + // Boostrap samples. + // shape: (samples, time slice) + // \param n_thr + // Number of thresholds. + // \param n_mbr + // Number of ensemble members. + // \param n_msk + // Number of temporal subsets. + // \param n_exp + // Number of bootstrap samples. + // \return + // Probabilities of false detection. + // shape: (subsets, samples, levels, thresholds) + template <class XV1D2> + inline xt::xtensor<double, 4> calc_POFD( + const xt::xtensor<double, 3>& pofd, + const XV1D2& q_thr, + const xt::xtensor<bool, 2>& t_msk, + const std::vector<xt::xkeep_slice<int>>& b_exp, + std::size_t n_thr, + std::size_t n_mbr, + std::size_t n_msk, + std::size_t n_exp + ) + { + return detail::calc_METRIC_from_metric( + pofd, q_thr, t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp + ); + } + + // Compute the false alarm ratio (FAR). + // + // \param far + // False alarm ratio for each time step. + // shape: (levels, thresholds, time) + // \param q_thr + // Streamflow exceedance threshold(s). + // shape: (thresholds,) + // \param t_msk + // Temporal subsets of the whole record. + // shape: (subsets, time) + // \param b_exp + // Boostrap samples. + // shape: (samples, time slice) + // \param n_thr + // Number of thresholds. + // \param n_mbr + // Number of ensemble members. + // \param n_msk + // Number of temporal subsets. + // \param n_exp + // Number of bootstrap samples. + // \return + // False alarm ratios. + // shape: (subsets, samples, levels, thresholds) + template <class XV1D2> + inline xt::xtensor<double, 4> calc_FAR( + const xt::xtensor<double, 3>& far, + const XV1D2& q_thr, + const xt::xtensor<bool, 2>& t_msk, + const std::vector<xt::xkeep_slice<int>>& b_exp, + std::size_t n_thr, + std::size_t n_mbr, + std::size_t n_msk, + std::size_t n_exp + ) + { + return detail::calc_METRIC_from_metric( + far, q_thr, t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp + ); + } + + // Compute the critical success index (CSI). + // + // \param csi + // Critical success index for each time step. + // shape: (levels, thresholds, time) + // \param q_thr + // Streamflow exceedance threshold(s). + // shape: (thresholds,) + // \param t_msk + // Temporal subsets of the whole record. + // shape: (subsets, time) + // \param b_exp + // Boostrap samples. + // shape: (samples, time slice) + // \param n_thr + // Number of thresholds. + // \param n_mbr + // Number of ensemble members. + // \param n_msk + // Number of temporal subsets. + // \param n_exp + // Number of bootstrap samples. + // \return + // Critical success indices. + // shape: (subsets, samples, levels, thresholds) + template <class XV1D2> + inline xt::xtensor<double, 4> calc_CSI( + const xt::xtensor<double, 3>& csi, + const XV1D2& q_thr, + const xt::xtensor<bool, 2>& t_msk, + const std::vector<xt::xkeep_slice<int>>& b_exp, + std::size_t n_thr, + std::size_t n_mbr, + std::size_t n_msk, + std::size_t n_exp + ) + { + return detail::calc_METRIC_from_metric( + csi, q_thr, t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp + ); + } + + // Compute the relative operating characteristic skill score (ROCSS). + // + // \param pod + // Probabilities of detection. + // shape: (subsets, samples, levels, thresholds) + // \param pofd + // Probabilities of false detection. + // shape: (subsets, samples, levels, thresholds) + // \param q_thr + // Streamflow exceedance threshold(s). + // shape: (thresholds,) + // \param t_msk + // Temporal subsets of the whole record. + // shape: (subsets, time) + // \param b_exp + // Boostrap samples. + // shape: (samples, time slice) + // \param n_thr + // Number of thresholds. + // \param n_mbr + // Number of ensemble members. + // \param n_msk + // Number of temporal subsets. + // \param n_exp + // Number of bootstrap samples. + // \return + // Critical success indices. + // shape: (subsets, samples, levels, thresholds) + inline xt::xtensor<double, 4> calc_ROCSS( + const xt::xtensor<double, 4>& POD, + const xt::xtensor<double, 4>& POFD + ) + { + // compute the area under the ROC curve + // xt::trapz(y, x, axis=2) + auto A = xt::trapz(POD, POFD, 2); + + // compute the ROC skill score + // $SS_{ROC} = \frac{A - A_{random}}{A_{perfect} - A_{random}}$ + // $SS_{ROC} = \frac{A - 0.5}{1. - 0.5} = 2A - 1$ + return (2. * A) - 1.; + } + } + } +} + +#endif //EVALHYD_PROBABILIST_CONTINGENCY_HPP \ No newline at end of file diff --git a/include/evalhyd/detail/probabilist/evaluator.hpp b/include/evalhyd/detail/probabilist/evaluator.hpp index e58fb3e77e89ed84766482f7bbde9c80215c2a6c..ba047af80b48919706189b1e36f7c85c4c39b2cc 100644 --- a/include/evalhyd/detail/probabilist/evaluator.hpp +++ b/include/evalhyd/detail/probabilist/evaluator.hpp @@ -10,6 +10,7 @@ #include "brier.hpp" #include "quantiles.hpp" +#include "contingency.hpp" namespace evalhyd @@ -65,22 +66,41 @@ namespace evalhyd // members for computational elements xtl::xoptional<xt::xtensor<double, 2>, bool> o_k; xtl::xoptional<xt::xtensor<double, 3>, bool> bar_o; + xtl::xoptional<xt::xtensor<double, 2>, bool> sum_f_k; xtl::xoptional<xt::xtensor<double, 2>, bool> y_k; xtl::xoptional<xt::xtensor<double, 2>, bool> q_qnt; + xtl::xoptional<xt::xtensor<double, 3>, bool> a_k; + xtl::xoptional<xt::xtensor<double, 3>, bool> ct_a; + xtl::xoptional<xt::xtensor<double, 3>, bool> ct_b; + xtl::xoptional<xt::xtensor<double, 3>, bool> ct_c; + xtl::xoptional<xt::xtensor<double, 3>, bool> ct_d; // members for intermediate evaluation metrics // (i.e. before the reduction along the temporal axis) xtl::xoptional<xt::xtensor<double, 2>, bool> bs; xtl::xoptional<xt::xtensor<double, 2>, bool> qs; xtl::xoptional<xt::xtensor<double, 2>, bool> crps; + xtl::xoptional<xt::xtensor<double, 3>, bool> pod; + xtl::xoptional<xt::xtensor<double, 3>, bool> pofd; + xtl::xoptional<xt::xtensor<double, 3>, bool> far; + xtl::xoptional<xt::xtensor<double, 3>, bool> csi; // members for evaluation metrics + // > Brier-based xtl::xoptional<xt::xtensor<double, 3>, bool> BS; xtl::xoptional<xt::xtensor<double, 4>, bool> BS_CRD; xtl::xoptional<xt::xtensor<double, 4>, bool> BS_LBD; xtl::xoptional<xt::xtensor<double, 3>, bool> BSS; + // > Quantiles-based xtl::xoptional<xt::xtensor<double, 3>, bool> QS; xtl::xoptional<xt::xtensor<double, 2>, bool> CRPS; + // > Contingency table-based + xtl::xoptional<xt::xtensor<double, 4>, bool> POD; + xtl::xoptional<xt::xtensor<double, 4>, bool> POFD; + xtl::xoptional<xt::xtensor<double, 4>, bool> FAR; + xtl::xoptional<xt::xtensor<double, 4>, bool> CSI; + xtl::xoptional<xt::xtensor<double, 3>, bool> ROCSS; + // methods to compute elements xt::xtensor<double, 2> get_o_k() @@ -105,13 +125,24 @@ namespace evalhyd return bar_o.value(); }; + xt::xtensor<double, 2> get_sum_f_k() + { + if (!sum_f_k.has_value()) + { + sum_f_k = elements::calc_sum_f_k<view2d_xtensor4d_double_type, + view1d_xtensor2d_double_type>( + q_prd, q_thr + ); + } + return sum_f_k.value(); + }; + xt::xtensor<double, 2> get_y_k() { if (!y_k.has_value()) { - y_k = elements::calc_y_k<view2d_xtensor4d_double_type, - view1d_xtensor2d_double_type>( - q_prd, q_thr, n_mbr + y_k = elements::calc_y_k( + get_sum_f_k(), n_mbr ); } return y_k.value(); @@ -128,6 +159,61 @@ namespace evalhyd return q_qnt.value(); }; + xt::xtensor<double, 3> get_a_k() + { + if (!a_k.has_value()) + { + a_k = elements::calc_a_k( + get_sum_f_k(), n_mbr + ); + } + return a_k.value(); + }; + + xt::xtensor<double, 2> get_ct_a() + { + if (!ct_a.has_value()) + { + ct_a = elements::calc_ct_a( + get_o_k(), get_a_k() + ); + } + return ct_a.value(); + }; + + xt::xtensor<double, 2> get_ct_b() + { + if (!ct_b.has_value()) + { + ct_b = elements::calc_ct_b( + get_o_k(), get_a_k() + ); + } + return ct_b.value(); + }; + + xt::xtensor<double, 2> get_ct_c() + { + if (!ct_c.has_value()) + { + ct_c = elements::calc_ct_c( + get_o_k(), get_a_k() + ); + } + return ct_c.value(); + }; + + xt::xtensor<double, 2> get_ct_d() + { + if (!ct_d.has_value()) + { + ct_d = elements::calc_ct_d( + get_o_k(), get_a_k() + ); + } + return ct_d.value(); + }; + // methods to compute intermediate metrics xt::xtensor<double, 2> get_bs() { @@ -162,6 +248,50 @@ namespace evalhyd return crps.value(); }; + xt::xtensor<double, 4> get_pod() + { + if (!pod.has_value()) + { + pod = intermediate::calc_pod( + get_ct_a(), get_ct_c() + ); + } + return pod.value(); + }; + + xt::xtensor<double, 4> get_pofd() + { + if (!pofd.has_value()) + { + pofd = intermediate::calc_pofd( + get_ct_b(), get_ct_d() + ); + } + return pofd.value(); + }; + + xt::xtensor<double, 4> get_far() + { + if (!far.has_value()) + { + far = intermediate::calc_far( + get_ct_a(), get_ct_b() + ); + } + return far.value(); + }; + + xt::xtensor<double, 4> get_csi() + { + if (!csi.has_value()) + { + csi = intermediate::calc_csi( + get_ct_a(), get_ct_b(), get_ct_c() + ); + } + return csi.value(); + }; + public: // constructor method Evaluator(const view1d_xtensor2d_double_type& obs, @@ -271,6 +401,65 @@ namespace evalhyd } return CRPS.value(); }; + + xt::xtensor<double, 4> get_POD() + { + if (!POD.has_value()) + { + POD = metrics::calc_POD( + get_pod(), q_thr, t_msk, b_exp, + n_thr, n_mbr, n_msk, n_exp + ); + } + return POD.value(); + }; + + xt::xtensor<double, 4> get_POFD() + { + if (!POFD.has_value()) + { + POFD = metrics::calc_POFD( + get_pofd(), q_thr, t_msk, b_exp, + n_thr, n_mbr, n_msk, n_exp + ); + } + return POFD.value(); + }; + + xt::xtensor<double, 4> get_FAR() + { + if (!FAR.has_value()) + { + FAR = metrics::calc_FAR( + get_far(), q_thr, t_msk, b_exp, + n_thr, n_mbr, n_msk, n_exp + ); + } + return FAR.value(); + }; + + xt::xtensor<double, 4> get_CSI() + { + if (!CSI.has_value()) + { + CSI = metrics::calc_CSI( + get_csi(), q_thr, t_msk, b_exp, + n_thr, n_mbr, n_msk, n_exp + ); + } + return CSI.value(); + }; + + xt::xtensor<double, 3> get_ROCSS() + { + if (!ROCSS.has_value()) + { + ROCSS = metrics::calc_ROCSS( + get_POD(), get_POFD() + ); + } + return ROCSS.value(); + }; }; } } diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 1a9e14bb7ac4e7c7364638883336e06deef48bb7..900c272f06474dc0aba047177dd1145871191872 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -180,7 +180,9 @@ namespace evalhyd // check that the metrics to be computed are valid utils::check_metrics( metrics, - {"BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS"} + {"BS", "BSS", "BS_CRD", "BS_LBD", + "QS", "CRPS", + "POD", "POFD", "FAR", "CSI", "ROCSS"} ); // check that optional parameters are given as arguments @@ -282,6 +284,11 @@ namespace evalhyd dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3}; dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr}; dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp}; + dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; + dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; + dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; + dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; + dim["ROCSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr}; // generate masks from conditions if provided auto gen_msk = [&]() @@ -407,6 +414,36 @@ namespace evalhyd xt::view(r[m], s, l, xt::all(), xt::all()) = uncertainty::summarise(evaluator.get_CRPS(), summary); } + else if ( metric == "POD" ) + { + // (sites, lead times, subsets, samples, levels, thresholds) + xt::view(r[m], s, l, xt::all(), xt::all()) = + uncertainty::summarise(evaluator.get_POD(), summary); + } + else if ( metric == "POFD" ) + { + // (sites, lead times, subsets, samples, levels, thresholds) + xt::view(r[m], s, l, xt::all(), xt::all()) = + uncertainty::summarise(evaluator.get_POFD(), summary); + } + else if ( metric == "FAR" ) + { + // (sites, lead times, subsets, samples, levels, thresholds) + xt::view(r[m], s, l, xt::all(), xt::all()) = + uncertainty::summarise(evaluator.get_FAR(), summary); + } + else if ( metric == "CSI" ) + { + // (sites, lead times, subsets, samples, levels, thresholds) + xt::view(r[m], s, l, xt::all(), xt::all()) = + uncertainty::summarise(evaluator.get_CSI(), summary); + } + else if ( metric == "ROCSS" ) + { + // (sites, lead times, subsets, samples, thresholds) + xt::view(r[m], s, l, xt::all(), xt::all()) = + uncertainty::summarise(evaluator.get_ROCSS(), summary); + } } } }