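// NOTE(review): the lines that stood here were web-viewer residue
// ("An error occurred while loading the file. Please try again." and the
// blame entry "Thibault Hallouin authored 75793d3d"), not part of the source.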
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_PROBABILIST_CONTINGENCY_HPP
#define EVALHYD_PROBABILIST_CONTINGENCY_HPP
#include <cmath>
#include <cstddef>
#include <vector>

#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xmasked_view.hpp>
#include <xtensor/xmath.hpp>
// NOTE ------------------------------------------------------------------------
// All equations in metrics below are following notations from
// "Wilks, D. S. (2011). Statistical methods in the atmospheric sciences.
// Amsterdam; Boston: Elsevier Academic Press. ISBN: 9780123850225".
// In particular, pp. 302-303, 332-333.
// -----------------------------------------------------------------------------
namespace evalhyd
{
namespace probabilist
{
namespace elements
{
// Contingency table:
//
//                  OBS
//                Y     N
//             +-----+-----+      a: hits
//           Y |  a  |  b  |      b: false alarms
//     PRD     +-----+-----+      c: misses
//           N |  c  |  d  |      d: correct rejections
//             +-----+-----+
//
/// Derive the alerts raised by the forecast for every alert level.
///
/// An alert level is the minimum number of ensemble members that must
/// forecast the event for an alert to be raised; levels 0, 1, ..., n_mbr
/// are all evaluated.
///
/// \param sum_f_k
///     Number of forecast members exceeding threshold(s).
///     shape: (sites, lead times, thresholds, time)
/// \param n_mbr
///     Number of ensemble members.
/// \return
///     Alerts based on forecast (1 where the alert is raised, 0 otherwise).
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_a_k(
        const xt::xtensor<double, 4>& sum_f_k,
        std::size_t n_mbr
)
{
    // candidate alert levels: 0, 1, ..., n_mbr
    auto levels = xt::arange<double>(static_cast<double>(n_mbr + 1));

    // insert a 'levels' axis in the member counts
    auto n_exceeding = xt::view(
            sum_f_k,
            xt::all(), xt::all(), xt::newaxis(), xt::all(), xt::all()
    );

    // lay the levels out along that same axis so that both broadcast
    auto n_required = xt::view(
            levels,
            xt::newaxis(), xt::newaxis(), xt::all(),
            xt::newaxis(), xt::newaxis()
    );

    // an alert is raised when enough members forecast the event
    return n_exceeding >= n_required;
}
/// Determine hits ('a' in contingency table).
///
/// \param o_k
/// Observed event outcome.
/// shape: (sites, thresholds, time)
/// \param a_k
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
/// Alerts based on forecast.
/// shape: (sites, lead times, levels, thresholds, time)
/// \return
/// Hits.
/// shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_ct_a(
const xt::xtensor<double, 3>& o_k,
const xt::xtensor<double, 5>& a_k
)
{
return xt::equal(xt::view(o_k, xt::all(), xt::newaxis(),
xt::newaxis(), xt::all(), xt::all()),
1.)
&& xt::equal(a_k, 1.);
}
/// Determine false alarms ('b' in contingency table).
///
/// A false alarm is a time step where an alert was raised although the
/// event was not observed.
///
/// \param o_k
///     Observed event outcome.
///     shape: (sites, thresholds, time)
/// \param a_k
///     Alerts based on forecast.
///     shape: (sites, lead times, levels, thresholds, time)
/// \return
///     False alarms (1 where a false alarm occurred, 0 otherwise).
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_ct_b(
        const xt::xtensor<double, 3>& o_k,
        const xt::xtensor<double, 5>& a_k
)
{
    // observations broadcast over the 'lead times' and 'levels' axes
    auto obs = xt::view(
            o_k,
            xt::all(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all()
    );

    auto event_absent = xt::equal(obs, 0.);
    auto alert_raised = xt::equal(a_k, 1.);

    return event_absent && alert_raised;
}
/// Determine misses ('c' in contingency table).
///
/// A miss is a time step where the event was observed although no alert
/// was raised.
///
/// \param o_k
///     Observed event outcome.
///     shape: (sites, thresholds, time)
/// \param a_k
///     Alerts based on forecast.
///     shape: (sites, lead times, levels, thresholds, time)
/// \return
///     Misses (1 where a miss occurred, 0 otherwise).
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_ct_c(
        const xt::xtensor<double, 3>& o_k,
        const xt::xtensor<double, 5>& a_k
)
{
    // observations broadcast over the 'lead times' and 'levels' axes
    auto obs = xt::view(
            o_k,
            xt::all(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all()
    );

    auto event_observed = xt::equal(obs, 1.);
    auto alert_not_raised = xt::equal(a_k, 0.);

    return event_observed && alert_not_raised;
}
/// Determine correct rejections ('d' in contingency table).
///
/// \param o_k
/// Observed event outcome.
/// shape: (sites, thresholds, time)
/// \param a_k
/// Alerts based on forecast.
/// shape: (sites, lead times, levels, thresholds, time)
/// \return
/// Correct rejections.
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
/// shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_ct_d(
const xt::xtensor<double, 3>& o_k,
const xt::xtensor<double, 5>& a_k
)
{
return xt::equal(xt::view(o_k, xt::all(), xt::newaxis(),
xt::newaxis(), xt::all(), xt::all()),
0.)
&& xt::equal(a_k, 0.);
}
}
namespace intermediate
{
/// Compute the probability of detection for each time step,
/// i.e. $a / (a + c)$.
///
/// Where neither a hit nor a miss occurred the ratio is 0/0, i.e. NaN.
///
/// \param ct_a
///     Hits.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param ct_c
///     Misses.
///     shape: (sites, lead times, levels, thresholds, time)
/// \return
///     Probability of detection for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_pod(
        const xt::xtensor<double, 5>& ct_a,
        const xt::xtensor<double, 5>& ct_c
)
{
    // time steps where the event was observed (hit or miss)
    auto n_observed = ct_a + ct_c;

    return ct_a / n_observed;
}
/// Compute the probability of false detection for each time step,
/// i.e. $b / (b + d)$.
///
/// Where neither a false alarm nor a correct rejection occurred the
/// ratio is 0/0, i.e. NaN.
///
/// \param ct_b
///     False alarms.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param ct_d
///     Correct rejections.
///     shape: (sites, lead times, levels, thresholds, time)
/// \return
///     Probability of false detection for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_pofd(
        const xt::xtensor<double, 5>& ct_b,
        const xt::xtensor<double, 5>& ct_d
)
{
    // time steps where the event was not observed
    // (false alarm or correct rejection)
    auto n_not_observed = ct_b + ct_d;

    return ct_b / n_not_observed;
}
/// Compute the false alarm ratio for each time step,
/// i.e. $b / (a + b)$.
///
/// Where no alert was raised (neither a hit nor a false alarm) the
/// ratio is 0/0, i.e. NaN.
///
/// \param ct_a
///     Hits.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param ct_b
///     False alarms.
///     shape: (sites, lead times, levels, thresholds, time)
/// \return
///     False alarm ratio for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_far(
        const xt::xtensor<double, 5>& ct_a,
        const xt::xtensor<double, 5>& ct_b
)
{
    return ct_b / (ct_a + ct_b);
}
/// Compute the critical success index for each time step,
/// i.e. $a / (a + b + c)$.
///
/// Where only a correct rejection occurred (no hit, false alarm, or
/// miss) the ratio is 0/0, i.e. NaN.
///
/// \param ct_a
///     Hits.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param ct_b
///     False alarms.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param ct_c
///     Misses.
///     shape: (sites, lead times, levels, thresholds, time)
/// \return
///     Critical success index for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
inline xt::xtensor<double, 5> calc_csi(
        const xt::xtensor<double, 5>& ct_a,
        const xt::xtensor<double, 5>& ct_b,
        const xt::xtensor<double, 5>& ct_c
)
{
    // time steps where the event was observed and/or an alert was raised
    auto n_relevant = ct_a + ct_b + ct_c;

    return ct_a / n_relevant;
}
}
namespace metrics
{
namespace detail
{
/// Aggregate a per-time-step contingency metric into one value per
/// temporal subset and bootstrap sample (shared implementation used by
/// the POD, POFD, FAR, and CSI metrics).
///
/// For each temporal subset, time steps outside the subset are replaced
/// by NaN; for each bootstrap sample, the selected time steps are then
/// averaged while ignoring NaN values.
///
/// \param metric
///     Metric for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \param t_msk
///     Temporal subsets of the whole record.
///     shape: (sites, lead times, subsets, time)
/// \param b_exp
///     Bootstrap samples.
///     shape: (samples, time slice)
/// \param n_sit
///     Number of sites.
/// \param n_ldt
///     Number of lead times.
/// \param n_thr
///     Number of thresholds.
/// \param n_mbr
///     Number of ensemble members.
/// \param n_msk
///     Number of temporal subsets.
/// \param n_exp
///     Number of bootstrap samples.
/// \return
///     Aggregated metric, NaN where the threshold is NaN.
///     shape: (sites, lead times, subsets, samples, levels, thresholds)
template <class XD2>
inline xt::xtensor<double, 6> calc_METRIC_from_metric(
        const xt::xtensor<double, 5>& metric,
        const XD2& q_thr,
        const xt::xtensor<bool, 4>& t_msk,
        const std::vector<xt::xkeep_slice<int>>& b_exp,
        std::size_t n_sit,
        std::size_t n_ldt,
        std::size_t n_thr,
        std::size_t n_mbr,
        std::size_t n_msk,
        std::size_t n_exp
)
{
    // initialise output variable
    // (there are n_mbr + 1 alert levels: 0, 1, ..., n_mbr)
    xt::xtensor<double, 6> METRIC =
            xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp,
                               n_mbr + 1, n_thr});

    // compute variable one mask at a time to minimise memory imprint
    for (std::size_t m = 0; m < n_msk; m++)
    {
        // apply the mask
        // (using NaN workaround until reducers work on masked_view)
        // (the mask is broadcast over the 'levels' and 'thresholds' axes)
        auto metric_masked = xt::where(
                xt::view(t_msk, xt::all(), xt::all(), m,
                         xt::newaxis(), xt::newaxis(),
                         xt::all()),
                metric,
                NAN
        );

        // compute variable one sample at a time
        for (std::size_t e = 0; e < n_exp; e++)
        {
            // apply the bootstrap sampling (selection along the time axis)
            auto metric_masked_sampled =
                    xt::view(metric_masked, xt::all(), xt::all(),
                             xt::all(), xt::all(), b_exp[e]);

            // calculate the mean over the time steps
            // (NaN values, i.e. masked or undefined time steps, are ignored)
            xt::view(METRIC, xt::all(), xt::all(), m, e,
                     xt::all(), xt::all()) =
                    xt::nanmean(metric_masked_sampled, -1);
        }
    }

    // assign NaN where thresholds were not provided (i.e. set as NaN)
    xt::masked_view(
            METRIC,
            xt::isnan(xt::view(q_thr, xt::all(), xt::newaxis(),
                               xt::newaxis(), xt::newaxis(),
                               xt::newaxis(), xt::all()))
    ) = NAN;

    return METRIC;
}
}
/// Compute the probability of detection (POD),
/// also known as 'hit rate'.
///
/// The per-time-step values are aggregated for each temporal subset and
/// bootstrap sample by detail::calc_METRIC_from_metric.
///
/// \param pod
///     Probability of detection for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \param t_msk
///     Temporal subsets of the whole record.
///     shape: (sites, lead times, subsets, time)
/// \param b_exp
///     Bootstrap samples.
///     shape: (samples, time slice)
/// \param n_sit
///     Number of sites.
/// \param n_ldt
///     Number of lead times.
/// \param n_thr
///     Number of thresholds.
/// \param n_mbr
///     Number of ensemble members.
/// \param n_msk
///     Number of temporal subsets.
/// \param n_exp
///     Number of bootstrap samples.
/// \return
///     Probabilities of detection.
///     shape: (sites, lead times, subsets, samples, levels, thresholds)
template <class XD2>
inline xt::xtensor<double, 6> calc_POD(
        const xt::xtensor<double, 5>& pod,
        const XD2& q_thr,
        const xt::xtensor<bool, 4>& t_msk,
        const std::vector<xt::xkeep_slice<int>>& b_exp,
        std::size_t n_sit,
        std::size_t n_ldt,
        std::size_t n_thr,
        std::size_t n_mbr,
        std::size_t n_msk,
        std::size_t n_exp
)
{
    // delegate to the aggregation shared by all contingency metrics
    return detail::calc_METRIC_from_metric<XD2>(
            pod,
            q_thr,
            t_msk,
            b_exp,
            n_sit,
            n_ldt,
            n_thr,
            n_mbr,
            n_msk,
            n_exp
    );
}
/// Compute the probability of detection (POFD),
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
/// also known as 'false alarm rate'.
///
/// \param pofd
/// Probability of false detection for each time step.
/// shape: (sites, lead times, levels, thresholds, time)
/// \param q_thr
/// Streamflow exceedance threshold(s).
/// shape: (sites, thresholds)
/// \param t_msk
/// Temporal subsets of the whole record.
/// shape: (sites, lead times, subsets, time)
/// \param b_exp
/// Boostrap samples.
/// shape: (samples, time slice)
/// \param n_sit
/// Number of sites.
/// \param n_ldt
/// Number of lead times.
/// \param n_thr
/// Number of thresholds.
/// \param n_mbr
/// Number of ensemble members.
/// \param n_msk
/// Number of temporal subsets.
/// \param n_exp
/// Number of bootstrap samples.
/// \return
/// Probabilities of false detection.
/// shape: (sites, lead times, subsets, samples, levels, thresholds)
template <class XD2>
inline xt::xtensor<double, 6> calc_POFD(
const xt::xtensor<double, 5>& pofd,
const XD2& q_thr,
const xt::xtensor<bool, 4>& t_msk,
const std::vector<xt::xkeep_slice<int>>& b_exp,
std::size_t n_sit,
std::size_t n_ldt,
std::size_t n_thr,
std::size_t n_mbr,
std::size_t n_msk,
std::size_t n_exp
)
{
return detail::calc_METRIC_from_metric(
pofd, q_thr, t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
/// Compute the false alarm ratio (FAR).
///
/// \param far
/// False alarm ratio for each time step.
/// shape: (sites, lead times, levels, thresholds, time)
/// \param q_thr
/// Streamflow exceedance threshold(s).
/// shape: (sites, thresholds)
/// \param t_msk
/// Temporal subsets of the whole record.
/// shape: (sites, lead times, subsets, time)
/// \param b_exp
/// Boostrap samples.
/// shape: (samples, time slice)
/// \param n_sit
/// Number of sites.
/// \param n_ldt
/// Number of lead times.
/// \param n_thr
/// Number of thresholds.
/// \param n_mbr
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
/// Number of ensemble members.
/// \param n_msk
/// Number of temporal subsets.
/// \param n_exp
/// Number of bootstrap samples.
/// \return
/// False alarm ratios.
/// shape: (sites, lead times, subsets, samples, levels, thresholds)
template <class XD2>
inline xt::xtensor<double, 6> calc_FAR(
const xt::xtensor<double, 5>& far,
const XD2& q_thr,
const xt::xtensor<bool, 4>& t_msk,
const std::vector<xt::xkeep_slice<int>>& b_exp,
std::size_t n_sit,
std::size_t n_ldt,
std::size_t n_thr,
std::size_t n_mbr,
std::size_t n_msk,
std::size_t n_exp
)
{
return detail::calc_METRIC_from_metric(
far, q_thr, t_msk, b_exp,
n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
);
}
/// Compute the critical success index (CSI).
///
/// The per-time-step values are aggregated for each temporal subset and
/// bootstrap sample by detail::calc_METRIC_from_metric.
///
/// \param csi
///     Critical success index for each time step.
///     shape: (sites, lead times, levels, thresholds, time)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \param t_msk
///     Temporal subsets of the whole record.
///     shape: (sites, lead times, subsets, time)
/// \param b_exp
///     Bootstrap samples.
///     shape: (samples, time slice)
/// \param n_sit
///     Number of sites.
/// \param n_ldt
///     Number of lead times.
/// \param n_thr
///     Number of thresholds.
/// \param n_mbr
///     Number of ensemble members.
/// \param n_msk
///     Number of temporal subsets.
/// \param n_exp
///     Number of bootstrap samples.
/// \return
///     Critical success indices.
///     shape: (sites, lead times, subsets, samples, levels, thresholds)
template <class XD2>
inline xt::xtensor<double, 6> calc_CSI(
        const xt::xtensor<double, 5>& csi,
        const XD2& q_thr,
        const xt::xtensor<bool, 4>& t_msk,
        const std::vector<xt::xkeep_slice<int>>& b_exp,
        std::size_t n_sit,
        std::size_t n_ldt,
        std::size_t n_thr,
        std::size_t n_mbr,
        std::size_t n_msk,
        std::size_t n_exp
)
{
    return detail::calc_METRIC_from_metric(
            csi, q_thr, t_msk, b_exp,
            n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp
    );
}
/// Compute the relative operating characteristic skill score (ROCSS).
///
/// The area under the ROC curve (POD against POFD across alert levels)
/// is obtained by trapezoidal integration along the 'levels' axis and
/// then rescaled into a skill score.
///
/// \param POD
///     Probabilities of detection.
///     shape: (sites, lead times, subsets, samples, levels, thresholds)
/// \param POFD
///     Probabilities of false detection.
///     shape: (sites, lead times, subsets, samples, levels, thresholds)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \return
///     ROC skill scores, NaN where the threshold is NaN.
///     shape: (sites, lead times, subsets, samples, thresholds)
template <class XD2>
inline xt::xtensor<double, 5> calc_ROCSS(
        const xt::xtensor<double, 6>& POD,
        const xt::xtensor<double, 6>& POFD,
        const XD2& q_thr
)
{
    // area under the ROC curve, i.e. xt::trapz(y, x, axis=4)
    // (the sign is flipped because POD/POFD values are in decreasing
    //  order along the 'levels' axis)
    auto area = - xt::trapz(POD, POFD, 4);

    // ROC skill score
    // $SS_{ROC} = \frac{A - A_{random}}{A_{perfect} - A_{random}}$
    // $SS_{ROC} = \frac{A - 0.5}{1. - 0.5} = 2A - 1$
    auto skill = xt::eval(2. * area - 1.);

    // thresholds broadcast over the 'lead times', 'subsets', and
    // 'samples' axes
    auto thr = xt::view(
            q_thr,
            xt::all(), xt::newaxis(), xt::newaxis(), xt::newaxis(),
            xt::all()
    );

    // assign NaN where thresholds were not provided (i.e. set as NaN)
    xt::masked_view(skill, xt::isnan(thr)) = NAN;

    return skill;
}
}
}
}
#endif //EVALHYD_PROBABILIST_CONTINGENCY_HPP