diff --git a/include/evalhyd/detail/probabilist/brier.hpp b/include/evalhyd/detail/probabilist/brier.hpp index 2beef726ad61be34c3e54cc2eda8eeece2cd490f..febdd626cba84dc011b0c731e46c7a816ae6cf1a 100644 --- a/include/evalhyd/detail/probabilist/brier.hpp +++ b/include/evalhyd/detail/probabilist/brier.hpp @@ -38,10 +38,10 @@ namespace evalhyd inline xt::xtensor<double, 2> calc_o_k( const XV1D2& q_obs, const XV1D2& q_thr, - bool high_flow_events + bool is_high_flow_event ) { - if (high_flow_events) + if (is_high_flow_event) { // observations above threshold(s) return q_obs >= xt::view(q_thr, xt::all(), xt::newaxis()); @@ -123,10 +123,10 @@ namespace evalhyd inline xt::xtensor<double, 2> calc_sum_f_k( const XV2D4& q_prd, const XV1D2& q_thr, - bool high_flow_events + bool is_high_flow_event ) { - if (high_flow_events) + if (is_high_flow_event) { // determine if members are above threshold(s) auto f_k = q_prd >= diff --git a/include/evalhyd/detail/probabilist/evaluator.hpp b/include/evalhyd/detail/probabilist/evaluator.hpp index 90c65d037838a86f9fda98c842dbed90c88155c4..2c0b3a421c4a575008cd420aa5d5e66c1a686607 100644 --- a/include/evalhyd/detail/probabilist/evaluator.hpp +++ b/include/evalhyd/detail/probabilist/evaluator.hpp @@ -56,8 +56,9 @@ namespace evalhyd // members for input data const view1d_xtensor2d_double_type& q_obs; const view2d_xtensor4d_double_type& q_prd; - const view1d_xtensor2d_double_type& q_thr; - const bool high_flow_events; + // members for optional input data + const view1d_xtensor2d_double_type& _q_thr; + xtl::xoptional<const std::string, bool> _events; xt::xtensor<bool, 2> t_msk; const std::vector<xt::xkeep_slice<int>>& b_exp; @@ -106,6 +107,48 @@ namespace evalhyd xtl::xoptional<xt::xtensor<double, 4>, bool> CSI; xtl::xoptional<xt::xtensor<double, 3>, bool> ROCSS; + // methods to get optional parameters + auto get_q_thr() + { + if (_q_thr.size() < 1) + { + throw std::runtime_error( + "threshold-based metric requested, " + "but *q_thr* not provided" + ); + } + else{ + return _q_thr; + } + } + + bool is_high_flow_event() + { + if (_events.has_value()) + { + if (_events.value() == "high") + { + return true; + } + else if (_events.value() == "low") + { + return false; + } + else + { + throw std::runtime_error( + "invalid value for *events*: " + _events.value() + ); + } + } + else + { + throw std::runtime_error( + "threshold-based metric requested, " + "but *events* not provided" + ); + } + } // methods to compute elements xt::xtensor<double, 2> get_o_k() @@ -113,7 +156,7 @@ namespace evalhyd if (!o_k.has_value()) { o_k = elements::calc_o_k( - q_obs, q_thr, high_flow_events + q_obs, get_q_thr(), is_high_flow_event() ); } return o_k.value(); @@ -135,7 +178,7 @@ namespace evalhyd if (!sum_f_k.has_value()) { sum_f_k = elements::calc_sum_f_k( - q_prd, q_thr, high_flow_events + q_prd, get_q_thr(), is_high_flow_event() ); } return sum_f_k.value(); @@ -301,12 +344,11 @@ namespace evalhyd Evaluator(const view1d_xtensor2d_double_type& obs, const view2d_xtensor4d_double_type& prd, const view1d_xtensor2d_double_type& thr, - const bool high_flow_events, + xtl::xoptional<const std::string, bool> events, const view2d_xtensor4d_bool_type& msk, const std::vector<xt::xkeep_slice<int>>& exp) : - q_obs{obs}, q_prd{prd}, q_thr{thr}, - high_flow_events{high_flow_events}, - t_msk(msk), b_exp(exp) + q_obs{obs}, q_prd{prd}, + _q_thr{thr}, _events{events}, t_msk(msk), b_exp(exp) { // initialise a mask if none provided // (corresponding to no temporal subset) @@ -319,7 +361,7 @@ namespace evalhyd n = q_obs.size(); n_msk = t_msk.shape(0); n_mbr = q_prd.shape(0); - n_thr = q_thr.size(); + n_thr = _q_thr.size(); n_exp = b_exp.size(); // drop time steps where observations and/or predictions are NaN @@ -345,7 +387,8 @@ namespace evalhyd if (!BS.has_value()) { BS = metrics::calc_BS( - get_bs(), q_thr, t_msk, b_exp, n_thr, n_msk, n_exp + get_bs(), get_q_thr(), t_msk, b_exp, + n_thr, n_msk, n_exp ); } return BS.value(); @@ -356,8 +399,8 @@ namespace evalhyd if (!BS_CRD.has_value()) { BS_CRD = metrics::calc_BS_CRD( - q_thr, get_o_k(), get_y_k(), get_bar_o(), t_msk, - b_exp, n_thr, n_mbr, n_msk, n_exp + get_q_thr(), get_o_k(), get_y_k(), get_bar_o(), + t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp ); } return BS_CRD.value(); @@ -368,8 +411,8 @@ namespace evalhyd if (!BS_LBD.has_value()) { BS_LBD = metrics::calc_BS_LBD( - q_thr, get_o_k(), get_y_k(), t_msk, - b_exp, n_thr, n_msk, n_exp + get_q_thr(), get_o_k(), get_y_k(), + t_msk, b_exp, n_thr, n_msk, n_exp ); } return BS_LBD.value(); @@ -380,7 +423,7 @@ namespace evalhyd if (!BSS.has_value()) { BSS = metrics::calc_BSS( - get_bs(), q_thr, get_o_k(), get_bar_o(), t_msk, + get_bs(), get_q_thr(), get_o_k(), get_bar_o(), t_msk, b_exp, n_thr, n_msk, n_exp ); } @@ -414,7 +457,7 @@ namespace evalhyd if (!POD.has_value()) { POD = metrics::calc_POD( - get_pod(), q_thr, t_msk, b_exp, + get_pod(), get_q_thr(), t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp ); } @@ -426,7 +469,7 @@ namespace evalhyd if (!POFD.has_value()) { POFD = metrics::calc_POFD( - get_pofd(), q_thr, t_msk, b_exp, + get_pofd(), get_q_thr(), t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp ); } @@ -438,7 +481,7 @@ namespace evalhyd if (!FAR.has_value()) { FAR = metrics::calc_FAR( - get_far(), q_thr, t_msk, b_exp, + get_far(), get_q_thr(), t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp ); } @@ -450,7 +493,7 @@ namespace evalhyd if (!CSI.has_value()) { CSI = metrics::calc_CSI( - get_csi(), q_thr, t_msk, b_exp, + get_csi(), get_q_thr(), t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp ); } diff --git a/include/evalhyd/detail/utils.hpp b/include/evalhyd/detail/utils.hpp index 1c55572163fa0200ff91771a0746f17fe356ef6e..620f5b3b9aaf7f030e7f51d6b98424807bd33100 100644 --- a/include/evalhyd/detail/utils.hpp +++ b/include/evalhyd/detail/utils.hpp @@ -105,85 +105,6 @@ namespace evalhyd xt::random::seed(bootstrap.find("seed")->second); } } - - namespace evalp - { - // Procedure to check that optional parameters are provided - // as arguments when required metrics need them. - // - // \param metrics - // Vector of strings for the metric(s) to be computed. - // \param thresholds - // Array of thresholds for metrics based on exceedance events. - // \param events - // Kind of streamflow exceedance events. - // \return - // Whether high flow events are considered. - inline bool check_optionals ( - const std::vector<std::string>& metrics, - const xt::xtensor<double, 2>& thresholds, - xtl::xoptional<const std::string, bool> events - ) - { - std::vector<std::string>threshold_metrics = - {"BS", "BS_CRD", "BS_LBD", "BSS", - "POD", "POFD", "FAR", "CSI", "ROCSS"}; - - bool thresholds_required_and_provided = false; - for (const auto& metric : metrics) - { - if (std::find(threshold_metrics.begin(), threshold_metrics.end(), - metric) != threshold_metrics.end()) - { - // check thresholds - if (thresholds.size() < 1) - { - throw std::runtime_error( - "missing thresholds *q_thr* required to " - "compute " + metric - ); - } - else - { - thresholds_required_and_provided = true; - break; - } - } - } - - // check events - if (thresholds_required_and_provided) - { - if (events.has_value()) - { - if (events.value() == "high") - { - return true; - } - else if (events.value() == "low") - { - return false; - } - else - { - throw std::runtime_error( - "invalid value for streamflow *events*" - ); - } - } - else - { - throw std::runtime_error( - "*q_thr* provided but *events* is missing" - ); - } - } - else - { - return true; - } - } - } } } diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 5582e1f1090404751a054dc4cf92980b74c9c409..50977f9f7a6cb8ebb4d33d95d8609db7cec0e633 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -201,10 +201,7 @@ namespace evalhyd "POD", "POFD", "FAR", "CSI", "ROCSS"} ); - // check that optional parameters are given as arguments - bool high_flow_events = - utils::evalp::check_optionals(metrics, q_thr_, events); - + // check optional parameters if (bootstrap.has_value()) { utils::check_bootstrap(bootstrap.value()); @@ -388,9 +385,7 @@ namespace evalhyd xt::view(t_msk_, s, l, xt::all(), xt::all())); probabilist::Evaluator<XD2, XD4, XB4> evaluator( - q_obs_v, q_prd_v, q_thr_v, - high_flow_events, - t_msk_v, b_exp + q_obs_v, q_prd_v, q_thr_v, events, t_msk_v, b_exp ); // retrieve or compute requested metrics