// Commit 7adb66dd, authored by Thibault Hallouin:
// Until now, the probabilist Evaluator was not able to compute multi-site and/or multi-lead-time metrics, because it was given one site and one lead time at a time. This was done to spare some memory, but it was not logical to load all sites and all lead times for the input data and then only process a small chunk at a time. There are no multi-site/multi-lead-time metrics implemented yet, but the Evaluator is ready for them now. This implementation makes full use of the broadcasting power of xtensor.
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_PROBABILIST_EVALUATOR_HPP
#define EVALHYD_PROBABILIST_EVALUATOR_HPP
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>
#include <xtl/xoptional.hpp>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include "brier.hpp"
#include "quantiles.hpp"
#include "contingency.hpp"
namespace evalhyd
{
namespace probabilist
{
template <class XD2, class XD4, class XB4>
class Evaluator
{
private:
// members for input data
const XD2& q_obs;
const XD4& q_prd;
// members for optional input data
const XD2& _q_thr;
xtl::xoptional<const std::string, bool> _events;
XB4 t_msk;
const std::vector<xt::xkeep_slice<int>>& b_exp;
// members for dimensions
std::size_t n_sit;
std::size_t n_ldt;
std::size_t n_tim;
std::size_t n_msk;
std::size_t n_mbr;
std::size_t n_thr;
std::size_t n_exp;
// members for computational elements
// > Brier-based
xtl::xoptional<xt::xtensor<double, 3>, bool> o_k;
xtl::xoptional<xt::xtensor<double, 5>, bool> bar_o;
xtl::xoptional<xt::xtensor<double, 4>, bool> sum_f_k;
xtl::xoptional<xt::xtensor<double, 4>, bool> y_k;
// > Quantiles-based
xtl::xoptional<xt::xtensor<double, 4>, bool> q_qnt;
// > Contingency table-based
xtl::xoptional<xt::xtensor<double, 5>, bool> a_k;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_a;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_b;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_c;
xtl::xoptional<xt::xtensor<double, 5>, bool> ct_d;
// members for intermediate evaluation metrics
// (i.e. before the reduction along the temporal axis)
// > Brier-based
xtl::xoptional<xt::xtensor<double, 4>, bool> bs;
// > Quantiles-based
xtl::xoptional<xt::xtensor<double, 4>, bool> qs;
xtl::xoptional<xt::xtensor<double, 3>, bool> crps;
// > Contingency table-based
xtl::xoptional<xt::xtensor<double, 5>, bool> pod;
xtl::xoptional<xt::xtensor<double, 5>, bool> pofd;
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
xtl::xoptional<xt::xtensor<double, 5>, bool> far;
xtl::xoptional<xt::xtensor<double, 5>, bool> csi;
// members for evaluation metrics
// > Brier-based
xtl::xoptional<xt::xtensor<double, 5>, bool> BS;
xtl::xoptional<xt::xtensor<double, 6>, bool> BS_CRD;
xtl::xoptional<xt::xtensor<double, 6>, bool> BS_LBD;
xtl::xoptional<xt::xtensor<double, 5>, bool> BSS;
// > Quantiles-based
xtl::xoptional<xt::xtensor<double, 5>, bool> QS;
xtl::xoptional<xt::xtensor<double, 4>, bool> CRPS;
// > Contingency table-based
xtl::xoptional<xt::xtensor<double, 6>, bool> POD;
xtl::xoptional<xt::xtensor<double, 6>, bool> POFD;
xtl::xoptional<xt::xtensor<double, 6>, bool> FAR;
xtl::xoptional<xt::xtensor<double, 6>, bool> CSI;
xtl::xoptional<xt::xtensor<double, 5>, bool> ROCSS;
// methods to get optional parameters
auto get_q_thr()
{
if (_q_thr.size() < 1)
{
throw std::runtime_error(
"threshold-based metric requested, "
"but *q_thr* not provided"
);
}
else{
return _q_thr;
}
}
bool is_high_flow_event()
{
if (_events.has_value())
{
if (_events.value() == "high")
{
return true;
}
else if (_events.value() == "low")
{
return false;
}
else
{
throw std::runtime_error(
"invalid value for *events*: " + _events.value()
);
}
}
else
{
throw std::runtime_error(
"threshold-based metric requested, "
"but *events* not provided"
);
}
}
// methods to compute elements
xt::xtensor<double, 3> get_o_k()
{
if (!o_k.has_value())
{
o_k = elements::calc_o_k(
q_obs, get_q_thr(), is_high_flow_event()
);
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
}
return o_k.value();
};
xt::xtensor<double, 5> get_bar_o()
{
if (!bar_o.has_value())
{
bar_o = elements::calc_bar_o(
get_o_k(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return bar_o.value();
};
xt::xtensor<double, 4> get_sum_f_k()
{
if (!sum_f_k.has_value())
{
sum_f_k = elements::calc_sum_f_k(
q_prd, get_q_thr(), is_high_flow_event()
);
}
return sum_f_k.value();
};
xt::xtensor<double, 4> get_y_k()
{
if (!y_k.has_value())
{
y_k = elements::calc_y_k(
get_sum_f_k(), n_mbr
);
}
return y_k.value();
};
xt::xtensor<double, 4> get_q_qnt()
{
if (!q_qnt.has_value())
{
q_qnt = elements::calc_q_qnt(
q_prd
);
}
return q_qnt.value();
};
xt::xtensor<double, 5> get_a_k()
{
if (!a_k.has_value())
{
a_k = elements::calc_a_k(
get_sum_f_k(), n_mbr
);
}
return a_k.value();
};
xt::xtensor<double, 5> get_ct_a()
{
if (!ct_a.has_value())
{
ct_a = elements::calc_ct_a(
get_o_k(), get_a_k()
);
}
return ct_a.value();
};
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
xt::xtensor<double, 5> get_ct_b()
{
if (!ct_b.has_value())
{
ct_b = elements::calc_ct_b(
get_o_k(), get_a_k()
);
}
return ct_b.value();
};
xt::xtensor<double, 5> get_ct_c()
{
if (!ct_c.has_value())
{
ct_c = elements::calc_ct_c(
get_o_k(), get_a_k()
);
}
return ct_c.value();
};
xt::xtensor<double, 5> get_ct_d()
{
if (!ct_d.has_value())
{
ct_d = elements::calc_ct_d(
get_o_k(), get_a_k()
);
}
return ct_d.value();
};
// methods to compute intermediate metrics
xt::xtensor<double, 4> get_bs()
{
if (!bs.has_value())
{
bs = intermediate::calc_bs(
get_o_k(), get_y_k()
);
}
return bs.value();
};
xt::xtensor<double, 4> get_qs()
{
if (!qs.has_value())
{
qs = intermediate::calc_qs(
q_obs, get_q_qnt(), n_mbr
);
}
return qs.value();
};;
xt::xtensor<double, 3> get_crps()
{
if (!crps.has_value())
{
crps = intermediate::calc_crps(
get_qs(), n_mbr
);
}
return crps.value();
};
xt::xtensor<double, 5> get_pod()
{
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
if (!pod.has_value())
{
pod = intermediate::calc_pod(
get_ct_a(), get_ct_c()
);
}
return pod.value();
};
xt::xtensor<double, 5> get_pofd()
{
if (!pofd.has_value())
{
pofd = intermediate::calc_pofd(
get_ct_b(), get_ct_d()
);
}
return pofd.value();
};
xt::xtensor<double, 5> get_far()
{
if (!far.has_value())
{
far = intermediate::calc_far(
get_ct_a(), get_ct_b()
);
}
return far.value();
};
xt::xtensor<double, 5> get_csi()
{
if (!csi.has_value())
{
csi = intermediate::calc_csi(
get_ct_a(), get_ct_b(), get_ct_c()
);
}
return csi.value();
};
public:
// constructor method
Evaluator(const XD2& obs,
const XD4& prd,
const XD2& thr,
xtl::xoptional<const std::string&, bool> events,
const XB4& msk,
const std::vector<xt::xkeep_slice<int>>& exp) :
q_obs{obs}, q_prd{prd},
_q_thr{thr}, _events{events}, t_msk(msk), b_exp(exp)
{
// initialise a mask if none provided
// (corresponding to no temporal subset)
if (msk.size() < 1)
{
t_msk = xt::ones<bool>(
{q_prd.shape(0), q_prd.shape(1),
std::size_t {1}, q_prd.shape(3)}
);
}
// determine size for recurring dimensions
n_sit = q_prd.shape(0);
n_ldt = q_prd.shape(1);
n_mbr = q_prd.shape(2);
n_tim = q_prd.shape(3);
n_msk = t_msk.shape(2);
n_thr = _q_thr.shape(1);
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
n_exp = b_exp.size();
// drop time steps where observations and/or predictions are NaN
for (std::size_t s = 0; s < n_sit; s++)
{
for (std::size_t l = 0; l < n_ldt; l++)
{
auto obs_nan =
xt::isnan(xt::view(q_obs, s));
auto prd_nan =
xt::isnan(xt::view(q_prd, s, l));
auto sum_nan =
xt::eval(xt::sum(prd_nan, -1));
if (xt::amin(sum_nan) != xt::amax(sum_nan))
{
throw std::runtime_error(
"predictions across members feature "
"non-equal lengths"
);
}
auto msk_nan =
xt::where(obs_nan || xt::row(prd_nan, 0))[0];
xt::view(t_msk, s, l, xt::all(), xt::keep(msk_nan)) =
false;
}
}
};
// methods to compute metrics
xt::xtensor<double, 5> get_BS()
{
if (!BS.has_value())
{
BS = metrics::calc_BS(
get_bs(), get_q_thr(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return BS.value();
};
xt::xtensor<double, 6> get_BS_CRD()
{
if (!BS_CRD.has_value())
{
BS_CRD = metrics::calc_BS_CRD(
get_q_thr(), get_o_k(), get_y_k(), get_bar_o(),
t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return BS_CRD.value();
};
xt::xtensor<double, 6> get_BS_LBD()
{
if (!BS_LBD.has_value())
{
BS_LBD = metrics::calc_BS_LBD(
get_q_thr(), get_o_k(), get_y_k(),
t_msk, b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return BS_LBD.value();
};
xt::xtensor<double, 5> get_BSS()
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
{
if (!BSS.has_value())
{
BSS = metrics::calc_BSS(
get_bs(), get_q_thr(), get_o_k(), get_bar_o(), t_msk,
b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp
);
}
return BSS.value();
};
xt::xtensor<double, 5> get_QS()
{
if (!QS.has_value())
{
QS = metrics::calc_QS(
get_qs(), t_msk, b_exp,
n_sit, n_ldt, n_mbr, n_msk, n_exp
);
}
return QS.value();
};
xt::xtensor<double, 4> get_CRPS()
{
if (!CRPS.has_value())
{
CRPS = metrics::calc_CRPS(
get_crps(), t_msk, b_exp,
n_sit, n_ldt, n_msk, n_exp
);
}
return CRPS.value();
};
xt::xtensor<double, 6> get_POD()
{
if (!POD.has_value())
{
POD = metrics::calc_POD(
get_pod(), get_q_thr(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return POD.value();
};
xt::xtensor<double, 6> get_POFD()
{
if (!POFD.has_value())
{
POFD = metrics::calc_POFD(
get_pofd(), get_q_thr(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return POFD.value();
};
xt::xtensor<double, 6> get_FAR()
{
if (!FAR.has_value())
{
FAR = metrics::calc_FAR(
get_far(), get_q_thr(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return FAR.value();
};
491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
xt::xtensor<double, 6> get_CSI()
{
if (!CSI.has_value())
{
CSI = metrics::calc_CSI(
get_csi(), get_q_thr(), t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
return CSI.value();
};
xt::xtensor<double, 5> get_ROCSS()
{
if (!ROCSS.has_value())
{
ROCSS = metrics::calc_ROCSS(
get_POD(), get_POFD(), get_q_thr()
);
}
return ROCSS.value();
};
};
}
}
#endif //EVALHYD_PROBABILIST_EVALUATOR_HPP