Commit f6ca9c92 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

move tests on optional q_thr/events inside Evaluator

1 merge request!3release v0.1.0
Pipeline #43392 passed with stage
in 2 minutes and 15 seconds
Showing with 68 additions and 109 deletions
+68 -109
...@@ -38,10 +38,10 @@ namespace evalhyd ...@@ -38,10 +38,10 @@ namespace evalhyd
inline xt::xtensor<double, 2> calc_o_k( inline xt::xtensor<double, 2> calc_o_k(
const XV1D2& q_obs, const XV1D2& q_obs,
const XV1D2& q_thr, const XV1D2& q_thr,
bool high_flow_events bool is_high_flow_event
) )
{ {
if (high_flow_events) if (is_high_flow_event)
{ {
// observations above threshold(s) // observations above threshold(s)
return q_obs >= xt::view(q_thr, xt::all(), xt::newaxis()); return q_obs >= xt::view(q_thr, xt::all(), xt::newaxis());
...@@ -123,10 +123,10 @@ namespace evalhyd ...@@ -123,10 +123,10 @@ namespace evalhyd
inline xt::xtensor<double, 2> calc_sum_f_k( inline xt::xtensor<double, 2> calc_sum_f_k(
const XV2D4& q_prd, const XV2D4& q_prd,
const XV1D2& q_thr, const XV1D2& q_thr,
bool high_flow_events bool is_high_flow_event
) )
{ {
if (high_flow_events) if (is_high_flow_event)
{ {
// determine if members are above threshold(s) // determine if members are above threshold(s)
auto f_k = q_prd >= auto f_k = q_prd >=
......
...@@ -56,8 +56,9 @@ namespace evalhyd ...@@ -56,8 +56,9 @@ namespace evalhyd
// members for input data // members for input data
const view1d_xtensor2d_double_type& q_obs; const view1d_xtensor2d_double_type& q_obs;
const view2d_xtensor4d_double_type& q_prd; const view2d_xtensor4d_double_type& q_prd;
const view1d_xtensor2d_double_type& q_thr; // members for optional input data
const bool high_flow_events; const view1d_xtensor2d_double_type& _q_thr;
xtl::xoptional<const std::string, bool> _events;
xt::xtensor<bool, 2> t_msk; xt::xtensor<bool, 2> t_msk;
const std::vector<xt::xkeep_slice<int>>& b_exp; const std::vector<xt::xkeep_slice<int>>& b_exp;
...@@ -106,6 +107,48 @@ namespace evalhyd ...@@ -106,6 +107,48 @@ namespace evalhyd
xtl::xoptional<xt::xtensor<double, 4>, bool> CSI; xtl::xoptional<xt::xtensor<double, 4>, bool> CSI;
xtl::xoptional<xt::xtensor<double, 3>, bool> ROCSS; xtl::xoptional<xt::xtensor<double, 3>, bool> ROCSS;
// methods to get optional parameters
auto get_q_thr()
{
if (_q_thr.size() < 1)
{
throw std::runtime_error(
"threshold-based metric requested, "
"but *q_thr* not provided"
);
}
else{
return _q_thr;
}
}
bool is_high_flow_event()
{
if (_events.has_value())
{
if (_events.value() == "high")
{
return true;
}
else if (_events.value() == "low")
{
return false;
}
else
{
throw std::runtime_error(
"invalid value for *events*: " + _events.value()
);
}
}
else
{
throw std::runtime_error(
"threshold-based metric requested, "
"but *events* not provided"
);
}
}
// methods to compute elements // methods to compute elements
xt::xtensor<double, 2> get_o_k() xt::xtensor<double, 2> get_o_k()
...@@ -113,7 +156,7 @@ namespace evalhyd ...@@ -113,7 +156,7 @@ namespace evalhyd
if (!o_k.has_value()) if (!o_k.has_value())
{ {
o_k = elements::calc_o_k( o_k = elements::calc_o_k(
q_obs, q_thr, high_flow_events q_obs, get_q_thr(), is_high_flow_event()
); );
} }
return o_k.value(); return o_k.value();
...@@ -135,7 +178,7 @@ namespace evalhyd ...@@ -135,7 +178,7 @@ namespace evalhyd
if (!sum_f_k.has_value()) if (!sum_f_k.has_value())
{ {
sum_f_k = elements::calc_sum_f_k( sum_f_k = elements::calc_sum_f_k(
q_prd, q_thr, high_flow_events q_prd, get_q_thr(), is_high_flow_event()
); );
} }
return sum_f_k.value(); return sum_f_k.value();
...@@ -301,12 +344,11 @@ namespace evalhyd ...@@ -301,12 +344,11 @@ namespace evalhyd
Evaluator(const view1d_xtensor2d_double_type& obs, Evaluator(const view1d_xtensor2d_double_type& obs,
const view2d_xtensor4d_double_type& prd, const view2d_xtensor4d_double_type& prd,
const view1d_xtensor2d_double_type& thr, const view1d_xtensor2d_double_type& thr,
const bool high_flow_events, xtl::xoptional<const std::string, bool> events,
const view2d_xtensor4d_bool_type& msk, const view2d_xtensor4d_bool_type& msk,
const std::vector<xt::xkeep_slice<int>>& exp) : const std::vector<xt::xkeep_slice<int>>& exp) :
q_obs{obs}, q_prd{prd}, q_thr{thr}, q_obs{obs}, q_prd{prd},
high_flow_events{high_flow_events}, _q_thr{thr}, _events{events}, t_msk(msk), b_exp(exp)
t_msk(msk), b_exp(exp)
{ {
// initialise a mask if none provided // initialise a mask if none provided
// (corresponding to no temporal subset) // (corresponding to no temporal subset)
...@@ -319,7 +361,7 @@ namespace evalhyd ...@@ -319,7 +361,7 @@ namespace evalhyd
n = q_obs.size(); n = q_obs.size();
n_msk = t_msk.shape(0); n_msk = t_msk.shape(0);
n_mbr = q_prd.shape(0); n_mbr = q_prd.shape(0);
n_thr = q_thr.size(); n_thr = _q_thr.size();
n_exp = b_exp.size(); n_exp = b_exp.size();
// drop time steps where observations and/or predictions are NaN // drop time steps where observations and/or predictions are NaN
...@@ -345,7 +387,8 @@ namespace evalhyd ...@@ -345,7 +387,8 @@ namespace evalhyd
if (!BS.has_value()) if (!BS.has_value())
{ {
BS = metrics::calc_BS( BS = metrics::calc_BS(
get_bs(), q_thr, t_msk, b_exp, n_thr, n_msk, n_exp get_bs(), get_q_thr(), t_msk, b_exp,
n_thr, n_msk, n_exp
); );
} }
return BS.value(); return BS.value();
...@@ -356,8 +399,8 @@ namespace evalhyd ...@@ -356,8 +399,8 @@ namespace evalhyd
if (!BS_CRD.has_value()) if (!BS_CRD.has_value())
{ {
BS_CRD = metrics::calc_BS_CRD( BS_CRD = metrics::calc_BS_CRD(
q_thr, get_o_k(), get_y_k(), get_bar_o(), t_msk, get_q_thr(), get_o_k(), get_y_k(), get_bar_o(),
b_exp, n_thr, n_mbr, n_msk, n_exp t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp
); );
} }
return BS_CRD.value(); return BS_CRD.value();
...@@ -368,8 +411,8 @@ namespace evalhyd ...@@ -368,8 +411,8 @@ namespace evalhyd
if (!BS_LBD.has_value()) if (!BS_LBD.has_value())
{ {
BS_LBD = metrics::calc_BS_LBD( BS_LBD = metrics::calc_BS_LBD(
q_thr, get_o_k(), get_y_k(), t_msk, get_q_thr(), get_o_k(), get_y_k(),
b_exp, n_thr, n_msk, n_exp t_msk, b_exp, n_thr, n_msk, n_exp
); );
} }
return BS_LBD.value(); return BS_LBD.value();
...@@ -380,7 +423,7 @@ namespace evalhyd ...@@ -380,7 +423,7 @@ namespace evalhyd
if (!BSS.has_value()) if (!BSS.has_value())
{ {
BSS = metrics::calc_BSS( BSS = metrics::calc_BSS(
get_bs(), q_thr, get_o_k(), get_bar_o(), t_msk, get_bs(), get_q_thr(), get_o_k(), get_bar_o(), t_msk,
b_exp, n_thr, n_msk, n_exp b_exp, n_thr, n_msk, n_exp
); );
} }
...@@ -414,7 +457,7 @@ namespace evalhyd ...@@ -414,7 +457,7 @@ namespace evalhyd
if (!POD.has_value()) if (!POD.has_value())
{ {
POD = metrics::calc_POD( POD = metrics::calc_POD(
get_pod(), q_thr, t_msk, b_exp, get_pod(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp n_thr, n_mbr, n_msk, n_exp
); );
} }
...@@ -426,7 +469,7 @@ namespace evalhyd ...@@ -426,7 +469,7 @@ namespace evalhyd
if (!POFD.has_value()) if (!POFD.has_value())
{ {
POFD = metrics::calc_POFD( POFD = metrics::calc_POFD(
get_pofd(), q_thr, t_msk, b_exp, get_pofd(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp n_thr, n_mbr, n_msk, n_exp
); );
} }
...@@ -438,7 +481,7 @@ namespace evalhyd ...@@ -438,7 +481,7 @@ namespace evalhyd
if (!FAR.has_value()) if (!FAR.has_value())
{ {
FAR = metrics::calc_FAR( FAR = metrics::calc_FAR(
get_far(), q_thr, t_msk, b_exp, get_far(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp n_thr, n_mbr, n_msk, n_exp
); );
} }
...@@ -450,7 +493,7 @@ namespace evalhyd ...@@ -450,7 +493,7 @@ namespace evalhyd
if (!CSI.has_value()) if (!CSI.has_value())
{ {
CSI = metrics::calc_CSI( CSI = metrics::calc_CSI(
get_csi(), q_thr, t_msk, b_exp, get_csi(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp n_thr, n_mbr, n_msk, n_exp
); );
} }
......
...@@ -105,85 +105,6 @@ namespace evalhyd ...@@ -105,85 +105,6 @@ namespace evalhyd
xt::random::seed(bootstrap.find("seed")->second); xt::random::seed(bootstrap.find("seed")->second);
} }
} }
namespace evalp
{
// Procedure to check that optional parameters are provided
// as arguments when required metrics need them.
//
// \param metrics
// Vector of strings for the metric(s) to be computed.
// \param thresholds
// Array of thresholds for metrics based on exceedance events.
// \param events
// Kind of streamflow exceedance events.
// \return
// Whether high flow events are considered.
inline bool check_optionals (
const std::vector<std::string>& metrics,
const xt::xtensor<double, 2>& thresholds,
xtl::xoptional<const std::string, bool> events
)
{
std::vector<std::string>threshold_metrics =
{"BS", "BS_CRD", "BS_LBD", "BSS",
"POD", "POFD", "FAR", "CSI", "ROCSS"};
bool thresholds_required_and_provided = false;
for (const auto& metric : metrics)
{
if (std::find(threshold_metrics.begin(), threshold_metrics.end(),
metric) != threshold_metrics.end())
{
// check thresholds
if (thresholds.size() < 1)
{
throw std::runtime_error(
"missing thresholds *q_thr* required to "
"compute " + metric
);
}
else
{
thresholds_required_and_provided = true;
break;
}
}
}
// check events
if (thresholds_required_and_provided)
{
if (events.has_value())
{
if (events.value() == "high")
{
return true;
}
else if (events.value() == "low")
{
return false;
}
else
{
throw std::runtime_error(
"invalid value for streamflow *events*"
);
}
}
else
{
throw std::runtime_error(
"*q_thr* provided but *events* is missing"
);
}
}
else
{
return true;
}
}
}
} }
} }
......
...@@ -201,10 +201,7 @@ namespace evalhyd ...@@ -201,10 +201,7 @@ namespace evalhyd
"POD", "POFD", "FAR", "CSI", "ROCSS"} "POD", "POFD", "FAR", "CSI", "ROCSS"}
); );
// check that optional parameters are given as arguments // check optional parameters
bool high_flow_events =
utils::evalp::check_optionals(metrics, q_thr_, events);
if (bootstrap.has_value()) if (bootstrap.has_value())
{ {
utils::check_bootstrap(bootstrap.value()); utils::check_bootstrap(bootstrap.value());
...@@ -388,9 +385,7 @@ namespace evalhyd ...@@ -388,9 +385,7 @@ namespace evalhyd
xt::view(t_msk_, s, l, xt::all(), xt::all())); xt::view(t_msk_, s, l, xt::all(), xt::all()));
probabilist::Evaluator<XD2, XD4, XB4> evaluator( probabilist::Evaluator<XD2, XD4, XB4> evaluator(
q_obs_v, q_prd_v, q_thr_v, q_obs_v, q_prd_v, q_thr_v, events, t_msk_v, b_exp
high_flow_events,
t_msk_v, b_exp
); );
// retrieve or compute requested metrics // retrieve or compute requested metrics
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment