// Thibault Hallouin authored b556eeea
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_PROBABILIST_BRIER_HPP
#define EVALHYD_PROBABILIST_BRIER_HPP
#include <limits>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xindex_view.hpp>
#include <xtensor/xmasked_view.hpp>
#include <xtensor/xsort.hpp>
#include <xtensor/xmath.hpp>
// NOTE ------------------------------------------------------------------------
// All equations in metrics below are following notations from
// "Wilks, D. S. (2011). Statistical methods in the atmospheric sciences.
// Amsterdam; Boston: Elsevier Academic Press. ISBN: 9780123850225".
// In particular, pp. 302-303, 332-333.
// -----------------------------------------------------------------------------
namespace evalhyd
{
namespace probabilist
{
namespace elements
{
/// Determine observed realisation of threshold(s) exceedance.
///
/// \param q_obs
///     Streamflow observations.
///     shape: (sites, time)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \param is_high_flow_event
///     Whether events correspond to being above the threshold(s)
///     (otherwise, events correspond to being below the threshold(s)).
/// \return
///     Event observed outcome (1 if the event occurred, 0 otherwise).
///     Note that comparisons involving NaN (missing observation or
///     missing threshold) evaluate to false and therefore yield 0.
///     shape: (sites, thresholds, time)
template<class XD2a, class XD2b>
inline xt::xtensor<double, 3> calc_o_k(
    const XD2a& q_obs,
    const XD2b& q_thr,
    bool is_high_flow_event
)
{
    // insert axes so that observations and thresholds broadcast
    // against each other into shape (sites, thresholds, time)
    auto obs = xt::view(q_obs, xt::all(), xt::newaxis(), xt::all());
    auto thr = xt::view(q_thr, xt::all(), xt::all(), xt::newaxis());

    if (is_high_flow_event)
    {
        // observations at or above threshold(s)
        return obs >= thr;
    }
    // observations at or below threshold(s)
    return obs <= thr;
}

/// Determine mean observed realisation of threshold(s) exceedance.
///
/// \param o_k
///     Event observed outcome.
///     shape: (sites, thresholds, time)
/// \param t_msk
///     Temporal subsets of the whole record.
///     shape: (sites, lead times, subsets, time)
/// \param b_exp
///     Bootstrap samples.
///     shape: (samples, time slice)
/// \param n_sit
///     Number of sites.
/// \param n_ldt
///     Number of lead times.
/// \param n_thr
///     Number of thresholds.
/// \param n_msk
///     Number of temporal subsets.
/// \param n_exp
///     Number of bootstrap samples.
/// \return
///     Mean event observed outcome, computed over the time steps
///     retained by each temporal subset and each bootstrap sample.
///     shape: (sites, lead times, subsets, samples, thresholds)
inline xt::xtensor<double, 5> calc_bar_o(
    const xt::xtensor<double, 3>& o_k,
    const xt::xtensor<bool, 4>& t_msk,
    const std::vector<xt::xkeep_slice<int>>& b_exp,
    std::size_t n_sit,
    std::size_t n_ldt,
    std::size_t n_thr,
    std::size_t n_msk,
    std::size_t n_exp
)
{
    // initialise output variable
    xt::xtensor<double, 5> bar_o =
        xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp, n_thr});

    // apply the mask: time steps outside a subset are set to NaN
    // (using NaN workaround until reducers work on masked_view)
    // shape: (sites, lead times, subsets, thresholds, time)
    auto o_k_masked = xt::where(
        xt::view(t_msk, xt::all(), xt::all(),
                 xt::all(), xt::newaxis(), xt::all()),
        xt::view(o_k, xt::all(), xt::newaxis(),
                 xt::newaxis(), xt::all(), xt::all()),
        NAN
    );

    // compute variable one sample at a time
    for (std::size_t e = 0; e < n_exp; e++)
    {
        // apply the bootstrap sampling along the time axis
        auto o_k_masked_sampled =
            xt::view(o_k_masked, xt::all(), xt::all(),
                     xt::all(), xt::all(), b_exp[e]);

        // compute mean "climatology" relative frequency of the event
        // (NaN-masked time steps are ignored by the reduction)
        // $\bar{o} = \frac{1}{n} \sum_{k=1}^{n} o_k$
        xt::view(bar_o, xt::all(), xt::all(), xt::all(), e, xt::all()) =
            xt::nanmean(o_k_masked_sampled, -1);
    }

    return bar_o;
}

/// Determine number of forecast members exceeding threshold(s).
///
/// \param q_prd
///     Streamflow predictions.
///     shape: (sites, lead times, members, time)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \param is_high_flow_event
///     Whether events correspond to being above the threshold(s)
///     (otherwise, events correspond to being below the threshold(s)).
/// \return
///     Number of forecast members exceeding threshold(s).
///     shape: (sites, lead times, thresholds, time)
template<class XD4, class XD2>
inline xt::xtensor<double, 4> calc_sum_f_k(
    const XD4& q_prd,
    const XD2& q_thr,
    bool is_high_flow_event
)
{
    // insert axes so that members and thresholds broadcast against
    // each other into shape (sites, lead times, thresholds, members, time)
    auto prd = xt::view(q_prd, xt::all(), xt::all(),
                        xt::newaxis(), xt::all(), xt::all());
    auto thr = xt::view(q_thr, xt::all(), xt::newaxis(),
                        xt::all(), xt::newaxis(), xt::newaxis());

    if (is_high_flow_event)
    {
        // count members at or above threshold(s) (members are axis 3)
        return xt::sum(prd >= thr, 3);
    }
    // count members at or below threshold(s) (members are axis 3)
    return xt::sum(prd <= thr, 3);
}

/// Determine forecast probability of threshold(s) exceedance to occur.
///
/// \param sum_f_k
///     Number of forecast members exceeding threshold(s).
///     shape: (sites, lead times, thresholds, time)
/// \param n_mbr
///     Number of ensemble members (assumed to be non-zero).
/// \return
///     Event probability forecast.
///     shape: (sites, lead times, thresholds, time)
inline xt::xtensor<double, 4> calc_y_k(
    const xt::xtensor<double, 4>& sum_f_k,
    std::size_t n_mbr
)
{
    // determine probability of threshold(s) exceedance
    // /!\ probability calculation dividing by n (the number of
    //     members), not n+1 (the number of ranks) like in other metrics
    return sum_f_k / n_mbr;
}
}
namespace intermediate
{
/// Compute the Brier score for each time step.
///
/// \param o_k
/// Observed event outcome.
/// shape: (sites, thresholds, time)
/// \param y_k
/// Event probability forecast.
/// shape: (sites, lead times, thresholds, time)
/// \return
/// Brier score for each threshold for each time step.
/// shape: (sites, lead times, thresholds, time)
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
inline xt::xtensor<double, 4> calc_bs(
const xt::xtensor<double, 3>& o_k,
const xt::xtensor<double, 4>& y_k
)
{
// return computed Brier score(s)
// $bs = (o_k - y_k)^2$
return xt::square(
xt::view(o_k, xt::all(), xt::newaxis(),
xt::all(), xt::all())
- y_k
);
}
}
namespace metrics
{
/// Compute the Brier score (BS).
///
/// \param bs
///     Brier score for each threshold for each time step.
///     shape: (sites, lead times, thresholds, time)
/// \param q_thr
///     Streamflow exceedance threshold(s).
///     shape: (sites, thresholds)
/// \param t_msk
///     Temporal subsets of the whole record.
///     shape: (sites, lead times, subsets, time)
/// \param b_exp
///     Bootstrap samples.
///     shape: (samples, time slice)
/// \param n_sit
///     Number of sites.
/// \param n_ldt
///     Number of lead times.
/// \param n_thr
///     Number of thresholds.
/// \param n_msk
///     Number of temporal subsets.
/// \param n_exp
///     Number of bootstrap samples.
/// \return
///     Brier score for each subset and for each threshold; NaN where
///     the corresponding threshold was not provided (i.e. set as NaN).
///     shape: (sites, lead times, subsets, samples, thresholds)
template <class XD2>
inline xt::xtensor<double, 5> calc_BS(
    const xt::xtensor<double, 4>& bs,
    const XD2& q_thr,
    const xt::xtensor<bool, 4>& t_msk,
    const std::vector<xt::xkeep_slice<int>>& b_exp,
    std::size_t n_sit,
    std::size_t n_ldt,
    std::size_t n_thr,
    std::size_t n_msk,
    std::size_t n_exp
)
{
    // initialise output variable
    xt::xtensor<double, 5> BS =
        xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp, n_thr});

    // compute variable one mask at a time to minimise memory footprint
    for (std::size_t m = 0; m < n_msk; m++)
    {
        // apply the mask: time steps outside subset m are set to NaN
        // (using NaN workaround until reducers work on masked_view)
        // shape: (sites, lead times, thresholds, time)
        auto bs_masked = xt::where(
            xt::view(t_msk, xt::all(), xt::all(), m,
                     xt::newaxis(), xt::all()),
            bs,
            NAN
        );

        // compute variable one sample at a time
        for (std::size_t e = 0; e < n_exp; e++)
        {
            // apply the bootstrap sampling along the time axis
            auto bs_masked_sampled =
                xt::view(bs_masked, xt::all(), xt::all(),
                         xt::all(), b_exp[e]);

            // calculate the mean over the time steps
            // (NaN-masked time steps are ignored by the reduction)
            // $BS = \frac{1}{n} \sum_{k=1}^{n} (o_k - y_k)^2$
            xt::view(BS, xt::all(), xt::all(), m, e, xt::all()) =
                xt::nanmean(bs_masked_sampled, -1);
        }
    }

    // assign NaN where thresholds were not provided (i.e. set as NaN)
    xt::masked_view(
        BS,
        xt::isnan(xt::view(q_thr, xt::all(), xt::newaxis(),
                           xt::newaxis(), xt::newaxis(),
                           xt::all()))
    ) = NAN;

    return BS;
}
/// Compute the X and Y axes of the reliability diagram
/// (`y_i`, the forecast probability; `bar_o_i`, the observed frequency;)
/// as well as the frequencies of the sampling histogram
/// (`N_i`, the number of forecasts of given probability `y_i`)'.
///
/// \param q_thr
/// Streamflow exceedance threshold(s).
/// shape: (sites, thresholds)
/// \param o_k
/// Observed event outcome.
/// shape: (sites, thresholds, time)
/// \param y_k
/// Event probability forecast.
/// shape: (sites, lead times, thresholds, time)
/// \param t_msk
/// Temporal subsets of the whole record.
/// shape: (sites, lead times, subsets, time)
/// \param b_exp
/// Boostrap samples.
/// shape: (samples, time slice)
/// \param n_sit
/// Number of sites.
/// \param n_ldt
/// Number of lead times.
/// \param n_thr
/// Number of thresholds.
/// \param n_mbr
/// Number of ensemble members.
/// \param n_msk
/// Number of temporal subsets.
/// \param n_exp
/// Number of bootstrap samples.
/// \return
/// X and Y axes of the reliability diagram, and ordinates
/// (i.e. frequencies) of the sampling histogram, in this order.
/// shape: (sites, lead times, subsets, samples, thresholds, bins, axes)
template <class XD2>
inline xt::xtensor<double, 7> calc_REL_DIAG(
const XD2& q_thr,
const xt::xtensor<double, 3>& o_k,
const xt::xtensor<double, 4>& y_k,