Commit f6ca9c92 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

move tests on optional q_thr/events inside Evaluator

1 merge request!3release v0.1.0
Pipeline #43392 passed with stage
in 2 minutes and 15 seconds
Showing with 68 additions and 109 deletions
+68 -109
......@@ -38,10 +38,10 @@ namespace evalhyd
inline xt::xtensor<double, 2> calc_o_k(
const XV1D2& q_obs,
const XV1D2& q_thr,
bool high_flow_events
bool is_high_flow_event
)
{
if (high_flow_events)
if (is_high_flow_event)
{
// observations above threshold(s)
return q_obs >= xt::view(q_thr, xt::all(), xt::newaxis());
......@@ -123,10 +123,10 @@ namespace evalhyd
inline xt::xtensor<double, 2> calc_sum_f_k(
const XV2D4& q_prd,
const XV1D2& q_thr,
bool high_flow_events
bool is_high_flow_event
)
{
if (high_flow_events)
if (is_high_flow_event)
{
// determine if members are above threshold(s)
auto f_k = q_prd >=
......
......@@ -56,8 +56,9 @@ namespace evalhyd
// members for input data
const view1d_xtensor2d_double_type& q_obs;
const view2d_xtensor4d_double_type& q_prd;
const view1d_xtensor2d_double_type& q_thr;
const bool high_flow_events;
// members for optional input data
const view1d_xtensor2d_double_type& _q_thr;
xtl::xoptional<const std::string, bool> _events;
xt::xtensor<bool, 2> t_msk;
const std::vector<xt::xkeep_slice<int>>& b_exp;
......@@ -106,6 +107,48 @@ namespace evalhyd
xtl::xoptional<xt::xtensor<double, 4>, bool> CSI;
xtl::xoptional<xt::xtensor<double, 3>, bool> ROCSS;
// methods to get optional parameters
auto get_q_thr()
{
if (_q_thr.size() < 1)
{
throw std::runtime_error(
"threshold-based metric requested, "
"but *q_thr* not provided"
);
}
else{
return _q_thr;
}
}
bool is_high_flow_event()
{
if (_events.has_value())
{
if (_events.value() == "high")
{
return true;
}
else if (_events.value() == "low")
{
return false;
}
else
{
throw std::runtime_error(
"invalid value for *events*: " + _events.value()
);
}
}
else
{
throw std::runtime_error(
"threshold-based metric requested, "
"but *events* not provided"
);
}
}
// methods to compute elements
xt::xtensor<double, 2> get_o_k()
......@@ -113,7 +156,7 @@ namespace evalhyd
if (!o_k.has_value())
{
o_k = elements::calc_o_k(
q_obs, q_thr, high_flow_events
q_obs, get_q_thr(), is_high_flow_event()
);
}
return o_k.value();
......@@ -135,7 +178,7 @@ namespace evalhyd
if (!sum_f_k.has_value())
{
sum_f_k = elements::calc_sum_f_k(
q_prd, q_thr, high_flow_events
q_prd, get_q_thr(), is_high_flow_event()
);
}
return sum_f_k.value();
......@@ -301,12 +344,11 @@ namespace evalhyd
Evaluator(const view1d_xtensor2d_double_type& obs,
const view2d_xtensor4d_double_type& prd,
const view1d_xtensor2d_double_type& thr,
const bool high_flow_events,
xtl::xoptional<const std::string, bool> events,
const view2d_xtensor4d_bool_type& msk,
const std::vector<xt::xkeep_slice<int>>& exp) :
q_obs{obs}, q_prd{prd}, q_thr{thr},
high_flow_events{high_flow_events},
t_msk(msk), b_exp(exp)
q_obs{obs}, q_prd{prd},
_q_thr{thr}, _events{events}, t_msk(msk), b_exp(exp)
{
// initialise a mask if none provided
// (corresponding to no temporal subset)
......@@ -319,7 +361,7 @@ namespace evalhyd
n = q_obs.size();
n_msk = t_msk.shape(0);
n_mbr = q_prd.shape(0);
n_thr = q_thr.size();
n_thr = _q_thr.size();
n_exp = b_exp.size();
// drop time steps where observations and/or predictions are NaN
......@@ -345,7 +387,8 @@ namespace evalhyd
if (!BS.has_value())
{
BS = metrics::calc_BS(
get_bs(), q_thr, t_msk, b_exp, n_thr, n_msk, n_exp
get_bs(), get_q_thr(), t_msk, b_exp,
n_thr, n_msk, n_exp
);
}
return BS.value();
......@@ -356,8 +399,8 @@ namespace evalhyd
if (!BS_CRD.has_value())
{
BS_CRD = metrics::calc_BS_CRD(
q_thr, get_o_k(), get_y_k(), get_bar_o(), t_msk,
b_exp, n_thr, n_mbr, n_msk, n_exp
get_q_thr(), get_o_k(), get_y_k(), get_bar_o(),
t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp
);
}
return BS_CRD.value();
......@@ -368,8 +411,8 @@ namespace evalhyd
if (!BS_LBD.has_value())
{
BS_LBD = metrics::calc_BS_LBD(
q_thr, get_o_k(), get_y_k(), t_msk,
b_exp, n_thr, n_msk, n_exp
get_q_thr(), get_o_k(), get_y_k(),
t_msk, b_exp, n_thr, n_msk, n_exp
);
}
return BS_LBD.value();
......@@ -380,7 +423,7 @@ namespace evalhyd
if (!BSS.has_value())
{
BSS = metrics::calc_BSS(
get_bs(), q_thr, get_o_k(), get_bar_o(), t_msk,
get_bs(), get_q_thr(), get_o_k(), get_bar_o(), t_msk,
b_exp, n_thr, n_msk, n_exp
);
}
......@@ -414,7 +457,7 @@ namespace evalhyd
if (!POD.has_value())
{
POD = metrics::calc_POD(
get_pod(), q_thr, t_msk, b_exp,
get_pod(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
);
}
......@@ -426,7 +469,7 @@ namespace evalhyd
if (!POFD.has_value())
{
POFD = metrics::calc_POFD(
get_pofd(), q_thr, t_msk, b_exp,
get_pofd(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
);
}
......@@ -438,7 +481,7 @@ namespace evalhyd
if (!FAR.has_value())
{
FAR = metrics::calc_FAR(
get_far(), q_thr, t_msk, b_exp,
get_far(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
);
}
......@@ -450,7 +493,7 @@ namespace evalhyd
if (!CSI.has_value())
{
CSI = metrics::calc_CSI(
get_csi(), q_thr, t_msk, b_exp,
get_csi(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
);
}
......
......@@ -105,85 +105,6 @@ namespace evalhyd
xt::random::seed(bootstrap.find("seed")->second);
}
}
namespace evalp
{
// Procedure to check that optional parameters are provided
// as arguments when required metrics need them.
//
// \param metrics
// Vector of strings for the metric(s) to be computed.
// \param thresholds
// Array of thresholds for metrics based on exceedance events.
// \param events
// Kind of streamflow exceedance events.
// \return
// Whether high flow events are considered.
inline bool check_optionals (
const std::vector<std::string>& metrics,
const xt::xtensor<double, 2>& thresholds,
xtl::xoptional<const std::string, bool> events
)
{
std::vector<std::string>threshold_metrics =
{"BS", "BS_CRD", "BS_LBD", "BSS",
"POD", "POFD", "FAR", "CSI", "ROCSS"};
bool thresholds_required_and_provided = false;
for (const auto& metric : metrics)
{
if (std::find(threshold_metrics.begin(), threshold_metrics.end(),
metric) != threshold_metrics.end())
{
// check thresholds
if (thresholds.size() < 1)
{
throw std::runtime_error(
"missing thresholds *q_thr* required to "
"compute " + metric
);
}
else
{
thresholds_required_and_provided = true;
break;
}
}
}
// check events
if (thresholds_required_and_provided)
{
if (events.has_value())
{
if (events.value() == "high")
{
return true;
}
else if (events.value() == "low")
{
return false;
}
else
{
throw std::runtime_error(
"invalid value for streamflow *events*"
);
}
}
else
{
throw std::runtime_error(
"*q_thr* provided but *events* is missing"
);
}
}
else
{
return true;
}
}
}
}
}
......
......@@ -201,10 +201,7 @@ namespace evalhyd
"POD", "POFD", "FAR", "CSI", "ROCSS"}
);
// check that optional parameters are given as arguments
bool high_flow_events =
utils::evalp::check_optionals(metrics, q_thr_, events);
// check optional parameters
if (bootstrap.has_value())
{
utils::check_bootstrap(bootstrap.value());
......@@ -388,9 +385,7 @@ namespace evalhyd
xt::view(t_msk_, s, l, xt::all(), xt::all()));
probabilist::Evaluator<XD2, XD4, XB4> evaluator(
q_obs_v, q_prd_v, q_thr_v,
high_flow_events,
t_msk_v, b_exp
q_obs_v, q_prd_v, q_thr_v, events, t_msk_v, b_exp
);
// retrieve or compute requested metrics
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment