brier.hpp 46.93 KiB
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_PROBABILIST_BRIER_HPP
#define EVALHYD_PROBABILIST_BRIER_HPP
#include <limits>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xindex_view.hpp>
#include <xtensor/xmasked_view.hpp>
#include <xtensor/xsort.hpp>
#include <xtensor/xmath.hpp>
// NOTE ------------------------------------------------------------------------
// All equations in the metrics below follow the notations from
// "Wilks, D. S. (2011). Statistical methods in the atmospheric sciences.
// Amsterdam; Boston: Elsevier Academic Press. ISBN: 9780123850225".
// In particular, pp. 302-303, 332-333.
// -----------------------------------------------------------------------------
namespace evalhyd
    namespace probabilist
        namespace elements
            /// Determine observed realisation of threshold(s) exceedance.
            ///
            /// \param q_obs
            ///     Streamflow observations.
            ///     shape: (sites, time)
            /// \param q_thr
            ///     Streamflow exceedance threshold(s).
            ///     shape: (sites, thresholds)
            /// \param is_high_flow_event
            ///     Whether events correspond to being above the threshold(s)
            ///     (`>=`); if false, events correspond to being below the
            ///     threshold(s) (`<=`).
            /// \return
            ///     Event observed outcome (1 if the event occurred, 0 otherwise).
            ///     Note that comparisons involving NaN evaluate to false, so a
            ///     NaN observation or a NaN threshold yields 0 rather than NaN;
            ///     NOTE(review): presumably such time steps are excluded by the
            ///     caller's temporal masks and NaN thresholds masked downstream
            ///     -- confirm against callers.
            ///     shape: (sites, thresholds, time)
            template<class XD2a, class XD2b>
            inline xt::xtensor<double, 3> calc_o_k(
                    const XD2a& q_obs,
                    const XD2b& q_thr,
                    bool is_high_flow_event
            )
            {
                // observations are broadcast as (sites, 1, time) against
                // thresholds broadcast as (sites, thresholds, 1)
                if (is_high_flow_event)
                {
                    // observations above threshold(s)
                    return xt::view(q_obs, xt::all(), xt::newaxis(), xt::all())
                           >= xt::view(q_thr, xt::all(), xt::all(), xt::newaxis());
                }
                else
                {
                    // observations below threshold(s)
                    return xt::view(q_obs, xt::all(), xt::newaxis(), xt::all())
                           <= xt::view(q_thr, xt::all(), xt::all(), xt::newaxis());
                }
            }
            /// Determine mean observed realisation of threshold(s) exceedance.
            ///
            /// \param o_k
            ///     Event observed outcome.
            ///     shape: (sites, thresholds, time)
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
/// \param t_msk /// Temporal subsets of the whole record. /// shape: (sites, lead times, subsets, time) /// \param b_exp /// Boostrap samples. /// shape: (samples, time slice) /// \param n_sit /// Number of sites. /// \param n_ldt /// Number of lead times. /// \param n_thr /// Number of thresholds. /// \param n_msk /// Number of temporal subsets. /// \param n_exp /// Number of bootstrap samples. /// \return /// Mean event observed outcome. /// shape: (sites, lead times, subsets, samples, thresholds) inline xt::xtensor<double, 5> calc_bar_o( const xt::xtensor<double, 3>& o_k, const xt::xtensor<bool, 4>& t_msk, const std::vector<xt::xkeep_slice<int>>& b_exp, std::size_t n_sit, std::size_t n_ldt, std::size_t n_thr, std::size_t n_msk, std::size_t n_exp ) { // initialise output variable xt::xtensor<double, 5> bar_o = xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp, n_thr}); // apply the mask // (using NaN workaround until reducers work on masked_view) auto o_k_masked = xt::where( xt::view(t_msk, xt::all(), xt::all(), xt::all(), xt::newaxis(), xt::all()), xt::view(o_k, xt::all(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all()), NAN ); // compute variable one sample at a time for (std::size_t e = 0; e < n_exp; e++) { // apply the bootstrap sampling auto o_k_masked_sampled = xt::view(o_k_masked, xt::all(), xt::all(), xt::all(), xt::all(), b_exp[e]); // compute mean "climatology" relative frequency of the event // $\bar{o} = \frac{1}{n} \sum_{k=1}^{n} o_k$ xt::view(bar_o, xt::all(), xt::all(), xt::all(), e, xt::all()) = xt::nanmean(o_k_masked_sampled, -1); } return bar_o; } /// Determine number of forecast members exceeding threshold(s) /// /// \param q_prd /// Streamflow predictions. /// shape: (sites, lead times, members, time) /// \param q_thr /// Streamflow exceedance threshold(s). /// shape: (sites, thresholds) /// \param is_high_flow_event
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
/// Whether events correspond to being above the threshold(s). /// \return /// Number of forecast members exceeding threshold(s). /// shape: (sites, lead times, thresholds, time) template<class XD4, class XD2> inline xt::xtensor<double, 4> calc_sum_f_k( const XD4& q_prd, const XD2& q_thr, bool is_high_flow_event ) { if (is_high_flow_event) { // determine if members are above threshold(s) auto f_k = xt::view(q_prd, xt::all(), xt::all(), xt::newaxis(), xt::all(), xt::all()) >= xt::view(q_thr, xt::all(), xt::newaxis(), xt::all(), xt::newaxis(), xt::newaxis()); // calculate how many members are above threshold(s) return xt::sum(f_k, 3); } else { // determine if members are below threshold(s) auto f_k = xt::view(q_prd, xt::all(), xt::all(), xt::newaxis(), xt::all(), xt::all()) <= xt::view(q_thr, xt::all(), xt::newaxis(), xt::all(), xt::newaxis(), xt::newaxis()); // calculate how many members are below threshold(s) return xt::sum(f_k, 3); } } /// Determine forecast probability of threshold(s) exceedance to occur. /// /// \param sum_f_k /// Number of forecast members exceeding threshold(s). /// shape: (sites, lead times, thresholds, time) /// \param n_mbr /// Number of ensemble members. /// \return /// Event probability forecast. /// shape: (sites, lead times, thresholds, time) inline xt::xtensor<double, 4> calc_y_k( const xt::xtensor<double, 4>& sum_f_k, std::size_t n_mbr ) { // determine probability of threshold(s) exceedance // /!\ probability calculation dividing by n (the number of // members), not n+1 (the number of ranks) like in other metrics return sum_f_k / n_mbr; } } namespace intermediate { /// Compute the Brier score for each time step. /// /// \param o_k /// Observed event outcome. /// shape: (sites, thresholds, time) /// \param y_k /// Event probability forecast. /// shape: (sites, lead times, thresholds, time) /// \return /// Brier score for each threshold for each time step. /// shape: (sites, lead times, thresholds, time)
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
inline xt::xtensor<double, 4> calc_bs( const xt::xtensor<double, 3>& o_k, const xt::xtensor<double, 4>& y_k ) { // return computed Brier score(s) // $bs = (o_k - y_k)^2$ return xt::square( xt::view(o_k, xt::all(), xt::newaxis(), xt::all(), xt::all()) - y_k ); } } namespace metrics { /// Compute the Brier score (BS). /// /// \param bs /// Brier score for each threshold for each time step. /// shape: (sites, lead times, thresholds, time) /// \param q_thr /// Streamflow exceedance threshold(s). /// shape: (sites, thresholds) /// \param t_msk /// Temporal subsets of the whole record. /// shape: (sites, lead times, subsets, time) /// \param b_exp /// Boostrap samples. /// shape: (samples, time slice) /// \param n_sit /// Number of sites. /// \param n_ldt /// Number of lead times. /// \param n_thr /// Number of thresholds. /// \param n_msk /// Number of temporal subsets. /// \param n_exp /// Number of bootstrap samples. /// \return /// Brier score for each subset and for each threshold. /// shape: (sites, lead times, subsets, samples, thresholds) template <class XD2> inline xt::xtensor<double, 5> calc_BS( const xt::xtensor<double, 4>& bs, const XD2& q_thr, const xt::xtensor<bool, 4>& t_msk, const std::vector<xt::xkeep_slice<int>>& b_exp, std::size_t n_sit, std::size_t n_ldt, std::size_t n_thr, std::size_t n_msk, std::size_t n_exp ) { // initialise output variable xt::xtensor<double, 5> BS = xt::zeros<double>({n_sit, n_ldt, n_msk, n_exp, n_thr}); // compute variable one mask at a time to minimise memory imprint for (std::size_t m = 0; m < n_msk; m++) { // apply the mask // (using NaN workaround until reducers work on masked_view) auto bs_masked = xt::where( xt::view(t_msk, xt::all(), xt::all(), m, xt::newaxis(), xt::all()), bs,
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
NAN ); // compute variable one sample at a time for (std::size_t e = 0; e < n_exp; e++) { // apply the bootstrap sampling auto bs_masked_sampled = xt::view(bs_masked, xt::all(), xt::all(), xt::all(), b_exp[e]); // calculate the mean over the time steps // $BS = \frac{1}{n} \sum_{k=1}^{n} (o_k - y_k)^2$ xt::view(BS, xt::all(), xt::all(), m, e, xt::all()) = xt::nanmean(bs_masked_sampled, -1); } } // assign NaN where thresholds were not provided (i.e. set as NaN) xt::masked_view( BS, xt::isnan(xt::view(q_thr, xt::all(), xt::newaxis(), xt::newaxis(), xt::newaxis(), xt::all())) ) = NAN; return BS; } /// Compute the X and Y axes of the reliability diagram /// (`y_i`, the forecast probability; `bar_o_i`, the observed frequency;) /// as well as the frequencies of the sampling histogram /// (`N_i`, the number of forecasts of given probability `y_i`)'. /// /// \param q_thr /// Streamflow exceedance threshold(s). /// shape: (sites, thresholds) /// \param o_k /// Observed event outcome. /// shape: (sites, thresholds, time) /// \param y_k /// Event probability forecast. /// shape: (sites, lead times, thresholds, time) /// \param t_msk /// Temporal subsets of the whole record. /// shape: (sites, lead times, subsets, time) /// \param b_exp /// Boostrap samples. /// shape: (samples, time slice) /// \param n_sit /// Number of sites. /// \param n_ldt /// Number of lead times. /// \param n_thr /// Number of thresholds. /// \param n_mbr /// Number of ensemble members. /// \param n_msk /// Number of temporal subsets. /// \param n_exp /// Number of bootstrap samples. /// \return /// X and Y axes of the reliability diagram, and ordinates /// (i.e. frequencies) of the sampling histogram, in this order. /// shape: (sites, lead times, subsets, samples, thresholds, bins, axes) template <class XD2> inline xt::xtensor<double, 7> calc_REL_DIAG( const XD2& q_thr, const xt::xtensor<double, 3>& o_k, const xt::xtensor<double, 4>& y_k,