Commit 7adb66dd authored by Thibault Hallouin
Browse files

pass multi-sites/multi-leadtimes tensors to probabilist Evaluator

Until now, the probabilist Evaluator was not able to compute multi-sites
and/or multi-leadtimes metrics, because it was given one site and one
lead time at a time. This was done to spare some memory, but it was not
logical to load all sites and all lead times for the input data and then
only process a small chunk at a time.

There are no multi-sites/multi-leadtimes metrics implemented yet, but
the Evaluator is ready for them now. This implementation makes full use
of the broadcasting power of xtensor.
1 merge request: !3 release v0.1.0
Pipeline #43400 passed with stage
in 2 minutes and 25 seconds
Showing with 751 additions and 578 deletions
+751 -578
This diff is collapsed.
This diff is collapsed.
......@@ -25,87 +25,67 @@ namespace evalhyd
class Evaluator
{
private:
using view1d_xtensor2d_double_type = decltype(
xt::view(
std::declval<const XD2&>(),
std::declval<std::size_t>(),
xt::all()
)
);
using view2d_xtensor4d_double_type = decltype(
xt::view(
std::declval<const XD4&>(),
std::declval<std::size_t>(),
std::declval<std::size_t>(),
xt::all(),
xt::all()
)
);
using view2d_xtensor4d_bool_type = decltype(
xt::view(
std::declval<const XB4&>(),
std::declval<std::size_t>(),
std::declval<std::size_t>(),
xt::all(),
xt::all()
)
);
// members for input data
const view1d_xtensor2d_double_type& q_obs;
const view2d_xtensor4d_double_type& q_prd;
const XD2& q_obs;
const XD4& q_prd;
// members for optional input data
const view1d_xtensor2d_double_type& _q_thr;
const XD2& _q_thr;
xtl::xoptional<const std::string, bool> _events;
xt::xtensor<bool, 2> t_msk;
XB4 t_msk;
const std::vector<xt::xkeep_slice<int>>& b_exp;
// members for dimensions
std::size_t n;
std::size_t n_sit;
std::size_t n_ldt;
std::size_t n_tim;
std::size_t n_msk;
std::size_t n_mbr;
std::size_t n_thr;
std::size_t n_exp;
// members for computational elements
xtl::xoptional<xt::xtensor<double, 2>, bool> o_k;
xtl::xoptional<xt::xtensor<double, 3>, bool> bar_o;
xtl::xoptional<xt::xtensor<double, 2>, bool> sum_f_k;
xtl::xoptional<xt::xtensor<double, 2>, bool> y_k;
xtl::xoptional<xt::xtensor<double, 2>, bool> q_qnt;
xtl::xoptional<xt::xtensor<double, 3>, bool> a_k;
xtl::xoptional<xt::xtensor<double, 3>, bool> ct_a;
xtl::xoptional<xt::xtensor<double, 3>, bool> ct_b;
xtl::xoptional<xt::xtensor<double, 3>, bool> ct_c;
xtl::xoptional<xt::xtensor<double, 3>, bool> ct_d;
// > Brier-based
xtl::xoptional<xt::xtensor<double, 3>, bool> o_k;
xtl::xoptional<xt::xtensor<double, 5>, bool> bar_o;
xtl::xoptional<xt::xtensor<double, 4>, bool> sum_f_k;
xtl::xoptional<xt::xtensor<double, 4>, bool> y_k;
// > Quantiles-based
xtl::xoptional<xt::xtensor<double, 4>, bool> q_qnt;
// > Contingency table-based
xtl::xoptional<xt::xtensor<double, 5>, bool> a_k;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_a;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_b;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_c;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_d;
// members for intermediate evaluation metrics
// (i.e. before the reduction along the temporal axis)
xtl::xoptional<xt::xtensor<double, 2>, bool> bs;
xtl::xoptional<xt::xtensor<double, 2>, bool> qs;
xtl::xoptional<xt::xtensor<double, 2>, bool> crps;
xtl::xoptional<xt::xtensor<double, 3>, bool> pod;
xtl::xoptional<xt::xtensor<double, 3>, bool> pofd;
xtl::xoptional<xt::xtensor<double, 3>, bool> far;
xtl::xoptional<xt::xtensor<double, 3>, bool> csi;
// > Brier-based
xtl::xoptional<xt::xtensor<double, 4>, bool> bs;
// > Quantiles-based
xtl::xoptional<xt::xtensor<double, 4>, bool> qs;
xtl::xoptional<xt::xtensor<double, 3>, bool> crps;
// > Contingency table-based
xtl::xoptional<xt::xtensor<double, 5>, bool> pod;
xtl::xoptional<xt::xtensor<double, 5>, bool> pofd;
xtl::xoptional<xt::xtensor<double, 5>, bool> far;
xtl::xoptional<xt::xtensor<double, 5>, bool> csi;
// members for evaluation metrics
// > Brier-based
xtl::xoptional<xt::xtensor<double, 3>, bool> BS;
xtl::xoptional<xt::xtensor<double, 4>, bool> BS_CRD;
xtl::xoptional<xt::xtensor<double, 4>, bool> BS_LBD;
xtl::xoptional<xt::xtensor<double, 3>, bool> BSS;
xtl::xoptional<xt::xtensor<double, 5>, bool> BS;
xtl::xoptional<xt::xtensor<double, 6>, bool> BS_CRD;
xtl::xoptional<xt::xtensor<double, 6>, bool> BS_LBD;
xtl::xoptional<xt::xtensor<double, 5>, bool> BSS;
// > Quantiles-based
xtl::xoptional<xt::xtensor<double, 3>, bool> QS;
xtl::xoptional<xt::xtensor<double, 2>, bool> CRPS;
xtl::xoptional<xt::xtensor<double, 5>, bool> QS;
xtl::xoptional<xt::xtensor<double, 4>, bool> CRPS;
// > Contingency table-based
xtl::xoptional<xt::xtensor<double, 4>, bool> POD;
xtl::xoptional<xt::xtensor<double, 4>, bool> POFD;
xtl::xoptional<xt::xtensor<double, 4>, bool> FAR;
xtl::xoptional<xt::xtensor<double, 4>, bool> CSI;
xtl::xoptional<xt::xtensor<double, 3>, bool> ROCSS;
xtl::xoptional<xt::xtensor<double, 6>, bool> POD;
xtl::xoptional<xt::xtensor<double, 6>, bool> POFD;
xtl::xoptional<xt::xtensor<double, 6>, bool> FAR;
xtl::xoptional<xt::xtensor<double, 6>, bool> CSI;
xtl::xoptional<xt::xtensor<double, 5>, bool> ROCSS;
// methods to get optional parameters
auto get_q_thr()
......@@ -151,7 +131,7 @@ namespace evalhyd
}
// methods to compute elements
xt::xtensor<double, 2> get_o_k()
xt::xtensor<double, 3> get_o_k()
{
if (!o_k.has_value())
{
......@@ -162,18 +142,19 @@ namespace evalhyd
return o_k.value();
};
xt::xtensor<double, 3> get_bar_o()
xt::xtensor<double, 5> get_bar_o()
{
if (!bar_o.has_value())
{
bar_o = elements::calc_bar_o(
get_o_k(), t_msk, b_exp, n_thr, n_msk, n_exp
get_o_k(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return bar_o.value();
};
xt::xtensor<double, 2> get_sum_f_k()
xt::xtensor<double, 4> get_sum_f_k()
{
if (!sum_f_k.has_value())
{
......@@ -184,7 +165,7 @@ namespace evalhyd
return sum_f_k.value();
};
xt::xtensor<double, 2> get_y_k()
xt::xtensor<double, 4> get_y_k()
{
if (!y_k.has_value())
{
......@@ -195,7 +176,7 @@ namespace evalhyd
return y_k.value();
};
xt::xtensor<double, 2> get_q_qnt()
xt::xtensor<double, 4> get_q_qnt()
{
if (!q_qnt.has_value())
{
......@@ -206,7 +187,7 @@ namespace evalhyd
return q_qnt.value();
};
xt::xtensor<double, 3> get_a_k()
xt::xtensor<double, 5> get_a_k()
{
if (!a_k.has_value())
{
......@@ -217,7 +198,7 @@ namespace evalhyd
return a_k.value();
};
xt::xtensor<double, 3> get_ct_a()
xt::xtensor<double, 5> get_ct_a()
{
if (!ct_a.has_value())
{
......@@ -228,7 +209,7 @@ namespace evalhyd
return ct_a.value();
};
xt::xtensor<double, 3> get_ct_b()
xt::xtensor<double, 5> get_ct_b()
{
if (!ct_b.has_value())
{
......@@ -239,7 +220,7 @@ namespace evalhyd
return ct_b.value();
};
xt::xtensor<double, 3> get_ct_c()
xt::xtensor<double, 5> get_ct_c()
{
if (!ct_c.has_value())
{
......@@ -250,7 +231,7 @@ namespace evalhyd
return ct_c.value();
};
xt::xtensor<double, 3> get_ct_d()
xt::xtensor<double, 5> get_ct_d()
{
if (!ct_d.has_value())
{
......@@ -262,7 +243,7 @@ namespace evalhyd
};
// methods to compute intermediate metrics
xt::xtensor<double, 2> get_bs()
xt::xtensor<double, 4> get_bs()
{
if (!bs.has_value())
{
......@@ -273,7 +254,7 @@ namespace evalhyd
return bs.value();
};
xt::xtensor<double, 2> get_qs()
xt::xtensor<double, 4> get_qs()
{
if (!qs.has_value())
{
......@@ -284,7 +265,7 @@ namespace evalhyd
return qs.value();
};;
xt::xtensor<double, 2> get_crps()
xt::xtensor<double, 3> get_crps()
{
if (!crps.has_value())
{
......@@ -295,7 +276,7 @@ namespace evalhyd
return crps.value();
};
xt::xtensor<double, 3> get_pod()
xt::xtensor<double, 5> get_pod()
{
if (!pod.has_value())
{
......@@ -306,7 +287,7 @@ namespace evalhyd
return pod.value();
};
xt::xtensor<double, 3> get_pofd()
xt::xtensor<double, 5> get_pofd()
{
if (!pofd.has_value())
{
......@@ -317,7 +298,7 @@ namespace evalhyd
return pofd.value();
};
xt::xtensor<double, 3> get_far()
xt::xtensor<double, 5> get_far()
{
if (!far.has_value())
{
......@@ -328,7 +309,7 @@ namespace evalhyd
return far.value();
};
xt::xtensor<double, 3> get_csi()
xt::xtensor<double, 5> get_csi()
{
if (!csi.has_value())
{
......@@ -341,11 +322,11 @@ namespace evalhyd
public:
// constructor method
Evaluator(const view1d_xtensor2d_double_type& obs,
const view2d_xtensor4d_double_type& prd,
const view1d_xtensor2d_double_type& thr,
xtl::xoptional<const std::string, bool> events,
const view2d_xtensor4d_bool_type& msk,
Evaluator(const XD2& obs,
const XD4& prd,
const XD2& thr,
xtl::xoptional<const std::string&, bool> events,
const XB4& msk,
const std::vector<xt::xkeep_slice<int>>& exp) :
q_obs{obs}, q_prd{prd},
_q_thr{thr}, _events{events}, t_msk(msk), b_exp(exp)
......@@ -354,158 +335,178 @@ namespace evalhyd
// (corresponding to no temporal subset)
if (msk.size() < 1)
{
t_msk = xt::ones<bool>({std::size_t {1}, q_obs.size()});
t_msk = xt::ones<bool>(
{q_prd.shape(0), q_prd.shape(1),
std::size_t {1}, q_prd.shape(3)}
);
}
// determine size for recurring dimensions
n = q_obs.size();
n_msk = t_msk.shape(0);
n_mbr = q_prd.shape(0);
n_thr = _q_thr.size();
n_sit = q_prd.shape(0);
n_ldt = q_prd.shape(1);
n_mbr = q_prd.shape(2);
n_tim = q_prd.shape(3);
n_msk = t_msk.shape(2);
n_thr = _q_thr.shape(1);
n_exp = b_exp.size();
// drop time steps where observations and/or predictions are NaN
auto obs_nan = xt::isnan(q_obs);
auto prd_nan = xt::isnan(q_prd);
auto sum_nan = xt::eval(xt::sum(prd_nan, -1));
if (xt::amin(sum_nan) != xt::amax(sum_nan))
for (std::size_t s = 0; s < n_sit; s++)
{
throw std::runtime_error(
"predictions across members feature non-equal lengths"
);
for (std::size_t l = 0; l < n_ldt; l++)
{
auto obs_nan =
xt::isnan(xt::view(q_obs, s));
auto prd_nan =
xt::isnan(xt::view(q_prd, s, l));
auto sum_nan =
xt::eval(xt::sum(prd_nan, -1));
if (xt::amin(sum_nan) != xt::amax(sum_nan))
{
throw std::runtime_error(
"predictions across members feature "
"non-equal lengths"
);
}
auto msk_nan =
xt::where(obs_nan || xt::row(prd_nan, 0))[0];
xt::view(t_msk, s, l, xt::all(), xt::keep(msk_nan)) =
false;
}
}
auto msk_nan = xt::where(obs_nan || xt::row(prd_nan, 0))[0];
xt::view(t_msk, xt::all(), xt::keep(msk_nan)) = false;
};
// methods to compute metrics
xt::xtensor<double, 3> get_BS()
xt::xtensor<double, 5> get_BS()
{
if (!BS.has_value())
{
BS = metrics::calc_BS(
get_bs(), get_q_thr(), t_msk, b_exp,
n_thr, n_msk, n_exp
n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return BS.value();
};
xt::xtensor<double, 4> get_BS_CRD()
xt::xtensor<double, 6> get_BS_CRD()
{
if (!BS_CRD.has_value())
{
BS_CRD = metrics::calc_BS_CRD(
get_q_thr(), get_o_k(), get_y_k(), get_bar_o(),
t_msk, b_exp, n_thr, n_mbr, n_msk, n_exp
t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return BS_CRD.value();
};
xt::xtensor<double, 4> get_BS_LBD()
xt::xtensor<double, 6> get_BS_LBD()
{
if (!BS_LBD.has_value())
{
BS_LBD = metrics::calc_BS_LBD(
get_q_thr(), get_o_k(), get_y_k(),
t_msk, b_exp, n_thr, n_msk, n_exp
t_msk, b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return BS_LBD.value();
};
xt::xtensor<double, 3> get_BSS()
xt::xtensor<double, 5> get_BSS()
{
if (!BSS.has_value())
{
BSS = metrics::calc_BSS(
get_bs(), get_q_thr(), get_o_k(), get_bar_o(), t_msk,
b_exp, n_thr, n_msk, n_exp
b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return BSS.value();
};
xt::xtensor<double, 3> get_QS()
xt::xtensor<double, 5> get_QS()
{
if (!QS.has_value())
{
QS = metrics::calc_QS(
get_qs(), t_msk, b_exp, n_mbr, n_msk, n_exp
get_qs(), t_msk, b_exp,
n_sit, n_ldt, n_mbr, n_msk, n_exp
);
}
return QS.value();
};
xt::xtensor<double, 2> get_CRPS()
xt::xtensor<double, 4> get_CRPS()
{
if (!CRPS.has_value())
{
CRPS = metrics::calc_CRPS(
get_crps(), t_msk, b_exp, n_msk, n_exp
get_crps(), t_msk, b_exp,
n_sit, n_ldt, n_msk, n_exp
);
}
return CRPS.value();
};
xt::xtensor<double, 4> get_POD()
xt::xtensor<double, 6> get_POD()
{
if (!POD.has_value())
{
POD = metrics::calc_POD(
get_pod(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return POD.value();
};
xt::xtensor<double, 4> get_POFD()
xt::xtensor<double, 6> get_POFD()
{
if (!POFD.has_value())
{
POFD = metrics::calc_POFD(
get_pofd(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return POFD.value();
};
xt::xtensor<double, 4> get_FAR()
xt::xtensor<double, 6> get_FAR()
{
if (!FAR.has_value())
{
FAR = metrics::calc_FAR(
get_far(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return FAR.value();
};
xt::xtensor<double, 4> get_CSI()
xt::xtensor<double, 6> get_CSI()
{
if (!CSI.has_value())
{
CSI = metrics::calc_CSI(
get_csi(), get_q_thr(), t_msk, b_exp,
n_thr, n_mbr, n_msk, n_exp
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return CSI.value();
};
xt::xtensor<double, 3> get_ROCSS()
xt::xtensor<double, 5> get_ROCSS()
{
if (!ROCSS.has_value())
{
ROCSS = metrics::calc_ROCSS(
get_POD(), get_POFD()
get_POD(), get_POFD(), get_q_thr()
);
}
return ROCSS.value();
......
......@@ -27,16 +27,16 @@ namespace evalhyd
//
// \param q_prd
// Streamflow predictions.
// shape: (members, time)
// shape: (sites, lead times, members, time)
// \return
// Streamflow forecast quantiles.
// shape: (quantiles, time)
template <class XV2D4>
inline xt::xtensor<double, 2> calc_q_qnt(
const XV2D4& q_prd
// shape: (sites, lead times, quantiles, time)
template <class XD4>
inline xt::xtensor<double, 4> calc_q_qnt(
const XD4& q_prd
)
{
return xt::sort(q_prd, 0);
return xt::sort(q_prd, 2);
}
}
......@@ -46,17 +46,17 @@ namespace evalhyd
//
// \param q_obs
// Streamflow observations.
// shape: (time,)
// shape: (sites, time)
// \param q_qnt
// Streamflow quantiles.
// shape: (quantiles, time)
// shape: (sites, lead times, quantiles, time)
// \return
// Quantile scores for each time step.
// shape: (quantiles, time)
template <class XV1D2>
inline xt::xtensor<double, 2> calc_qs(
const XV1D2 &q_obs,
const xt::xtensor<double, 2>& q_qnt,
// shape: (sites, lead times, quantiles, time)
template <class XD2>
inline xt::xtensor<double, 4> calc_qs(
const XD2 &q_obs,
const xt::xtensor<double, 4>& q_qnt,
std::size_t n_mbr
)
{
......@@ -66,13 +66,17 @@ namespace evalhyd
/ double(n_mbr + 1);
// calculate the difference
xt::xtensor<double, 2> diff = q_qnt - q_obs;
xt::xtensor<double, 4> diff =
q_qnt - xt::view(q_obs, xt::all(), xt::newaxis(),
xt::newaxis(), xt::all());
// calculate the quantile scores
xt::xtensor<double, 2> qs = xt::where(
xt::xtensor<double, 4> qs = xt::where(
diff > 0,
2 * (1 - xt::view(alpha, xt::all(), xt::newaxis())) * diff,
- 2 * xt::view(alpha, xt::all(), xt::newaxis()) * diff
2 * (1 - xt::view(alpha, xt::newaxis(), xt::newaxis(),
xt::all(), xt::newaxis())) * diff,
- 2 * xt::view(alpha, xt::newaxis(), xt::newaxis(),
xt::all(), xt::newaxis()) * diff
);
return qs;
......@@ -88,21 +92,18 @@ namespace evalhyd
//
// \param qs
// Quantile scores for each time step.
// shape: (quantiles, time)
// shape: (sites, lead times, quantiles, time)
// \return
// CRPS for each time step.
// shape: (1, time)
inline xt::xtensor<double, 2> calc_crps(
const xt::xtensor<double, 2>& qs,
// shape: (sites, lead times, time)
inline xt::xtensor<double, 3> calc_crps(
const xt::xtensor<double, 4>& qs,
std::size_t n_mbr
)
{
// integrate with trapezoidal rule
return xt::view(
// xt::trapz(y, dx=1/(n+1), axis=0)
xt::trapz(qs, 1./(double(n_mbr) + 1.), 0),
xt::newaxis(), xt::all()
);
// xt::trapz(y, dx=1/(n+1), axis=2)
return xt::trapz(qs, 1./(double(n_mbr) + 1.), 2);
}
}
......@@ -112,13 +113,17 @@ namespace evalhyd
//
// \param qs
// Quantile scores for each time step.
// shape: (quantiles, time)
// shape: (sites, lead times, quantiles, time)
// \param t_msk
// Temporal subsets of the whole record.
// shape: (subsets, time)
// shape: (sites, lead times, subsets, time)
// \param b_exp
// Bootstrap samples.
// shape: (samples, time slice)
// \param n_sit
// Number of sites.
// \param n_ldt
// Number of lead times.
// \param n_mbr
// Number of ensemble members.
// \param n_msk
......@@ -127,38 +132,45 @@ namespace evalhyd
// Number of bootstrap samples.
// \return
// Quantile scores.
// shape: (subsets, samples, quantiles)
inline xt::xtensor<double, 3> calc_QS(
const xt::xtensor<double, 2>& qs,
const xt::xtensor<bool, 2>& t_msk,
// shape: (sites, lead times, subsets, samples, quantiles)
inline xt::xtensor<double, 5> calc_QS(
const xt::xtensor<double, 4>& qs,
const xt::xtensor<bool, 4>& t_msk,
const std::vector<xt::xkeep_slice<int>>& b_exp,
std::size_t n_sit,
std::size_t n_ldt,
std::size_t n_mbr,
std::size_t n_msk,
std::size_t n_exp
)
{
// initialise output variable
// shape: (sites, lead times, subsets, samples, quantiles)
xt::xtensor<double, 3> QS =
xt::zeros<double>({n_msk, n_exp, n_mbr});
xt::xtensor<double, 5> QS =
xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp, n_mbr});
// compute variable one mask at a time to minimise memory imprint
for (std::size_t m = 0; m < n_msk; m++)
{
// apply the mask
// (using NaN workaround until reducers work on masked_view)
auto qs_masked = xt::where(xt::row(t_msk, m), qs, NAN);
auto qs_masked = xt::where(
xt::view(t_msk, xt::all(), xt::all(), m,
xt::newaxis(), xt::all()),
qs,
NAN
);
// compute variable one sample at a time
for (std::size_t e = 0; e < n_exp; e++)
{
// apply the bootstrap sampling
auto qs_masked_sampled =
xt::view(qs_masked, xt::all(), b_exp[e]);
xt::view(qs_masked, xt::all(), xt::all(),
xt::all(), b_exp[e]);
// calculate the mean over the time steps
// $QS = \frac{1}{n} \sum_{k=1}^{n} qs$
xt::view(QS, m, e, xt::all()) =
xt::view(QS, xt::all(), xt::all(), m, e, xt::all()) =
xt::nanmean(qs_masked_sampled, -1);
}
}
......@@ -171,49 +183,60 @@ namespace evalhyd
//
// \param crps
// CRPS for each time step.
// shape: (1, time)
// shape: (sites, lead times, time)
// \param t_msk
// Temporal subsets of the whole record.
// shape: (subsets, time)
// shape: (sites, lead times, subsets, time)
// \param b_exp
// Bootstrap samples.
// shape: (samples, time slice)
// \param n_sit
// Number of sites.
// \param n_ldt
// Number of lead times.
// \param n_msk
// Number of temporal subsets.
// \param n_exp
// Number of bootstrap samples.
// \return
// CRPS.
// shape: (subsets, samples)
inline xt::xtensor<double, 2> calc_CRPS(
const xt::xtensor<double, 2>& crps,
const xt::xtensor<bool, 2>& t_msk,
// shape: (sites, lead times, subsets, samples)
inline xt::xtensor<double, 4> calc_CRPS(
const xt::xtensor<double, 3>& crps,
const xt::xtensor<bool, 4>& t_msk,
const std::vector<xt::xkeep_slice<int>>& b_exp,
std::size_t n_sit,
std::size_t n_ldt,
std::size_t n_msk,
std::size_t n_exp
)
{
// initialise output variable
// shape: (sites, lead times, subsets, samples)
xt::xtensor<double, 2> CRPS = xt::zeros<double>({n_msk, n_exp});
xt::xtensor<double, 4> CRPS =
xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp});
// compute variable one mask at a time to minimise memory imprint
for (std::size_t m = 0; m < n_msk; m++)
{
// apply the mask
// (using NaN workaround until reducers work on masked_view)
auto crps_masked = xt::where(xt::row(t_msk, m), crps, NAN);
auto crps_masked = xt::where(
xt::view(t_msk, xt::all(), xt::all(), m, xt::all()),
crps,
NAN
);
// compute variable one sample at a time
for (std::size_t e = 0; e < n_exp; e++)
{
// apply the bootstrap sampling
auto crps_masked_sampled =
xt::view(crps_masked, xt::all(), b_exp[e]);
xt::view(crps_masked, xt::all(), xt::all(),
b_exp[e]);
// calculate the mean over the time steps
// $CRPS = \frac{1}{n} \sum_{k=1}^{n} crps$
xt::view(CRPS, m, e) =
xt::view(CRPS, xt::all(), xt::all(), m, e) =
xt::squeeze(xt::nanmean(crps_masked_sampled, -1));
}
}
......
......@@ -282,28 +282,9 @@ namespace evalhyd
// retrieve dimensions
std::size_t n_sit = q_prd_.shape(0);
std::size_t n_ltm = q_prd_.shape(1);
std::size_t n_mbr = q_prd_.shape(2);
std::size_t n_tim = q_prd_.shape(3);
std::size_t n_thr = q_thr_.shape(1);
std::size_t n_msk = t_msk_.size() > 0 ? t_msk_.shape(2) :
(m_cdt.size() > 0 ? m_cdt.shape(1) : 1);
std::size_t n_exp = !bootstrap.has_value() ? 1:
bootstrap.value().find("n_samples")->second;
// register metrics number of dimensions
std::unordered_map<std::string, std::vector<std::size_t>> dim;
dim["BS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
dim["BSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
dim["BS_CRD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3};
dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3};
dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr};
dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp};
dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["ROCSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
// generate masks from conditions if provided
auto gen_msk = [&]()
......@@ -360,108 +341,86 @@ namespace evalhyd
b_exp.push_back(xt::keep(all));
}
// instantiate probabilist evaluator
probabilist::Evaluator<XD2, XD4, XB4> evaluator(
q_obs_, q_prd_, q_thr_, events,
t_msk_.size() > 0 ? t_msk_: (m_cdt.size() > 0 ? c_msk : t_msk_),
b_exp
);
// initialise data structure for outputs
std::vector<xt::xarray<double>> r;
for (const auto& metric : metrics)
{
r.emplace_back(xt::zeros<double>(dim[metric]));
}
// compute variables one site at a time to minimise memory imprint
for (std::size_t s = 0; s < n_sit; s++)
for ( const auto& metric : metrics )
{
// compute variables one lead time at a time to minimise memory imprint
for (std::size_t l = 0; l < n_ltm; l++)
if ( metric == "BS" )
{
// instantiate probabilist evaluator
const auto q_obs_v = xt::view(q_obs_, s, xt::all());
const auto q_prd_v = xt::view(q_prd_, s, l, xt::all(), xt::all());
const auto q_thr_v = xt::view(q_thr_, s, xt::all());
const auto t_msk_v =
t_msk_.size() > 0 ?
xt::view(t_msk_, s, l, xt::all(), xt::all()) :
(m_cdt.size() > 0 ?
xt::view(c_msk, s, l, xt::all(), xt::all()) :
xt::view(t_msk_, s, l, xt::all(), xt::all()));
probabilist::Evaluator<XD2, XD4, XB4> evaluator(
q_obs_v, q_prd_v, q_thr_v, events, t_msk_v, b_exp
r.emplace_back(
uncertainty::summarise(evaluator.get_BS(), summary)
);
}
else if ( metric == "BSS" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_BSS(), summary)
);
}
else if ( metric == "BS_CRD" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_BS_CRD(), summary)
);
}
else if ( metric == "BS_LBD" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_BS_LBD(), summary)
);
}
else if ( metric == "QS" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_QS(), summary)
);
}
else if ( metric == "CRPS" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_CRPS(), summary)
);
}
else if ( metric == "POD" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_POD(), summary)
);
}
else if ( metric == "POFD" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_POFD(), summary)
);
}
else if ( metric == "FAR" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_FAR(), summary)
);
}
else if ( metric == "CSI" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_CSI(), summary)
);
}
else if ( metric == "ROCSS" )
{
r.emplace_back(
uncertainty::summarise(evaluator.get_ROCSS(), summary)
);
// retrieve or compute requested metrics
for (std::size_t m = 0; m < metrics.size(); m++)
{
const auto& metric = metrics[m];
if ( metric == "BS" )
{
// (sites, lead times, subsets, samples, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_BS(), summary);
}
else if ( metric == "BSS" )
{
// (sites, lead times, subsets, samples, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_BSS(), summary);
}
else if ( metric == "BS_CRD" )
{
// (sites, lead times, subsets, samples, thresholds, components)
xt::view(r[m], s, l, xt::all(), xt::all(), xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_BS_CRD(), summary);
}
else if ( metric == "BS_LBD" )
{
// (sites, lead times, subsets, samples, thresholds, components)
xt::view(r[m], s, l, xt::all(), xt::all(), xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_BS_LBD(), summary);
}
else if ( metric == "QS" )
{
// (sites, lead times, subsets, samples, quantiles)
xt::view(r[m], s, l, xt::all(), xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_QS(), summary);
}
else if ( metric == "CRPS" )
{
// (sites, lead times, subsets, samples)
xt::view(r[m], s, l, xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_CRPS(), summary);
}
else if ( metric == "POD" )
{
// (sites, lead times, subsets, samples, levels, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_POD(), summary);
}
else if ( metric == "POFD" )
{
// (sites, lead times, subsets, samples, levels, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_POFD(), summary);
}
else if ( metric == "FAR" )
{
// (sites, lead times, subsets, samples, levels, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_FAR(), summary);
}
else if ( metric == "CSI" )
{
// (sites, lead times, subsets, samples, levels, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_CSI(), summary);
}
else if ( metric == "ROCSS" )
{
// (sites, lead times, subsets, samples, thresholds)
xt::view(r[m], s, l, xt::all(), xt::all()) =
uncertainty::summarise(evaluator.get_ROCSS(), summary);
}
}
}
}
return r;
}
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment