// evaluator.hpp
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_PROBABILIST_EVALUATOR_HPP
#define EVALHYD_PROBABILIST_EVALUATOR_HPP
#include <stdexcept>
#include <string>
#include <vector>

#include <xtl/xoptional.hpp>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xmath.hpp>
#include <xtensor/xoperation.hpp>

#include "brier.hpp"
#include "quantiles.hpp"
#include "contingency.hpp"
namespace evalhyd
    namespace probabilist
        template <class XD2, class XD4, class XB4>
        class Evaluator
        private:
            // members for input data
            const XD2& q_obs;
            const XD4& q_prd;
            // members for optional input data
            const XD2& _q_thr;
            xtl::xoptional<const std::string, bool> _events;
            XB4 t_msk;
            const std::vector<xt::xkeep_slice<int>>& b_exp;
            // members for dimensions
            std::size_t n_sit;
            std::size_t n_ldt;
            std::size_t n_tim;
            std::size_t n_msk;
            std::size_t n_mbr;
            std::size_t n_thr;
            std::size_t n_exp;
            // members for computational elements
            // > Brier-based
            xtl::xoptional<xt::xtensor<double, 3>, bool> o_k;
            xtl::xoptional<xt::xtensor<double, 5>, bool> bar_o;
            xtl::xoptional<xt::xtensor<double, 4>, bool> sum_f_k;
            xtl::xoptional<xt::xtensor<double, 4>, bool> y_k;
            // > Quantiles-based
            xtl::xoptional<xt::xtensor<double, 4>, bool> q_qnt;
            // > Contingency table-based
            xtl::xoptional<xt::xtensor<double, 5>, bool> a_k;
            xtl::xoptional<xt::xtensor<double, 5>, bool> ct_a;
            xtl::xoptional<xt::xtensor<double, 5>, bool> ct_b;
            xtl::xoptional<xt::xtensor<double, 5>, bool> ct_c;
            xtl::xoptional<xt::xtensor<double, 5>, bool> ct_d;
            // members for intermediate evaluation metrics
            // (i.e. before the reduction along the temporal axis)
            // > Brier-based
            xtl::xoptional<xt::xtensor<double, 4>, bool> bs;
            // > Quantiles-based
            xtl::xoptional<xt::xtensor<double, 4>, bool> qs;
            xtl::xoptional<xt::xtensor<double, 3>, bool> crps;
            // > Contingency table-based
            xtl::xoptional<xt::xtensor<double, 5>, bool> pod;
            xtl::xoptional<xt::xtensor<double, 5>, bool> pofd;
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
xtl::xoptional<xt::xtensor<double, 5>, bool> far; xtl::xoptional<xt::xtensor<double, 5>, bool> csi; // members for evaluation metrics // > Brier-based xtl::xoptional<xt::xtensor<double, 5>, bool> BS; xtl::xoptional<xt::xtensor<double, 6>, bool> BS_CRD; xtl::xoptional<xt::xtensor<double, 6>, bool> BS_LBD; xtl::xoptional<xt::xtensor<double, 5>, bool> BSS; // > Quantiles-based xtl::xoptional<xt::xtensor<double, 5>, bool> QS; xtl::xoptional<xt::xtensor<double, 4>, bool> CRPS; // > Contingency table-based xtl::xoptional<xt::xtensor<double, 6>, bool> POD; xtl::xoptional<xt::xtensor<double, 6>, bool> POFD; xtl::xoptional<xt::xtensor<double, 6>, bool> FAR; xtl::xoptional<xt::xtensor<double, 6>, bool> CSI; xtl::xoptional<xt::xtensor<double, 5>, bool> ROCSS; // methods to get optional parameters auto get_q_thr() { if (_q_thr.size() < 1) { throw std::runtime_error( "threshold-based metric requested, " "but *q_thr* not provided" ); } else{ return _q_thr; } } bool is_high_flow_event() { if (_events.has_value()) { if (_events.value() == "high") { return true; } else if (_events.value() == "low") { return false; } else { throw std::runtime_error( "invalid value for *events*: " + _events.value() ); } } else { throw std::runtime_error( "threshold-based metric requested, " "but *events* not provided" ); } } // methods to compute elements xt::xtensor<double, 3> get_o_k() { if (!o_k.has_value()) { o_k = elements::calc_o_k( q_obs, get_q_thr(), is_high_flow_event() );
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
} return o_k.value(); }; xt::xtensor<double, 5> get_bar_o() { if (!bar_o.has_value()) { bar_o = elements::calc_bar_o( get_o_k(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp ); } return bar_o.value(); }; xt::xtensor<double, 4> get_sum_f_k() { if (!sum_f_k.has_value()) { sum_f_k = elements::calc_sum_f_k( q_prd, get_q_thr(), is_high_flow_event() ); } return sum_f_k.value(); }; xt::xtensor<double, 4> get_y_k() { if (!y_k.has_value()) { y_k = elements::calc_y_k( get_sum_f_k(), n_mbr ); } return y_k.value(); }; xt::xtensor<double, 4> get_q_qnt() { if (!q_qnt.has_value()) { q_qnt = elements::calc_q_qnt( q_prd ); } return q_qnt.value(); }; xt::xtensor<double, 5> get_a_k() { if (!a_k.has_value()) { a_k = elements::calc_a_k( get_sum_f_k(), n_mbr ); } return a_k.value(); }; xt::xtensor<double, 5> get_ct_a() { if (!ct_a.has_value()) { ct_a = elements::calc_ct_a( get_o_k(), get_a_k() ); } return ct_a.value(); };
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
xt::xtensor<double, 5> get_ct_b() { if (!ct_b.has_value()) { ct_b = elements::calc_ct_b( get_o_k(), get_a_k() ); } return ct_b.value(); }; xt::xtensor<double, 5> get_ct_c() { if (!ct_c.has_value()) { ct_c = elements::calc_ct_c( get_o_k(), get_a_k() ); } return ct_c.value(); }; xt::xtensor<double, 5> get_ct_d() { if (!ct_d.has_value()) { ct_d = elements::calc_ct_d( get_o_k(), get_a_k() ); } return ct_d.value(); }; // methods to compute intermediate metrics xt::xtensor<double, 4> get_bs() { if (!bs.has_value()) { bs = intermediate::calc_bs( get_o_k(), get_y_k() ); } return bs.value(); }; xt::xtensor<double, 4> get_qs() { if (!qs.has_value()) { qs = intermediate::calc_qs( q_obs, get_q_qnt(), n_mbr ); } return qs.value(); };; xt::xtensor<double, 3> get_crps() { if (!crps.has_value()) { crps = intermediate::calc_crps( get_qs(), n_mbr ); } return crps.value(); }; xt::xtensor<double, 5> get_pod() {
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
if (!pod.has_value()) { pod = intermediate::calc_pod( get_ct_a(), get_ct_c() ); } return pod.value(); }; xt::xtensor<double, 5> get_pofd() { if (!pofd.has_value()) { pofd = intermediate::calc_pofd( get_ct_b(), get_ct_d() ); } return pofd.value(); }; xt::xtensor<double, 5> get_far() { if (!far.has_value()) { far = intermediate::calc_far( get_ct_a(), get_ct_b() ); } return far.value(); }; xt::xtensor<double, 5> get_csi() { if (!csi.has_value()) { csi = intermediate::calc_csi( get_ct_a(), get_ct_b(), get_ct_c() ); } return csi.value(); }; public: // constructor method Evaluator(const XD2& obs, const XD4& prd, const XD2& thr, xtl::xoptional<const std::string&, bool> events, const XB4& msk, const std::vector<xt::xkeep_slice<int>>& exp) : q_obs{obs}, q_prd{prd}, _q_thr{thr}, _events{events}, t_msk(msk), b_exp(exp) { // initialise a mask if none provided // (corresponding to no temporal subset) if (msk.size() < 1) { t_msk = xt::ones<bool>( {q_prd.shape(0), q_prd.shape(1), std::size_t {1}, q_prd.shape(3)} ); } // determine size for recurring dimensions n_sit = q_prd.shape(0); n_ldt = q_prd.shape(1); n_mbr = q_prd.shape(2); n_tim = q_prd.shape(3); n_msk = t_msk.shape(2); n_thr = _q_thr.shape(1);
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
n_exp = b_exp.size(); // drop time steps where observations and/or predictions are NaN for (std::size_t s = 0; s < n_sit; s++) { for (std::size_t l = 0; l < n_ldt; l++) { auto obs_nan = xt::isnan(xt::view(q_obs, s)); auto prd_nan = xt::isnan(xt::view(q_prd, s, l)); auto sum_nan = xt::eval(xt::sum(prd_nan, -1)); if (xt::amin(sum_nan) != xt::amax(sum_nan)) { throw std::runtime_error( "predictions across members feature " "non-equal lengths" ); } auto msk_nan = xt::where(obs_nan || xt::row(prd_nan, 0))[0]; xt::view(t_msk, s, l, xt::all(), xt::keep(msk_nan)) = false; } } }; // methods to compute metrics xt::xtensor<double, 5> get_BS() { if (!BS.has_value()) { BS = metrics::calc_BS( get_bs(), get_q_thr(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp ); } return BS.value(); }; xt::xtensor<double, 6> get_BS_CRD() { if (!BS_CRD.has_value()) { BS_CRD = metrics::calc_BS_CRD( get_q_thr(), get_o_k(), get_y_k(), get_bar_o(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp ); } return BS_CRD.value(); }; xt::xtensor<double, 6> get_BS_LBD() { if (!BS_LBD.has_value()) { BS_LBD = metrics::calc_BS_LBD( get_q_thr(), get_o_k(), get_y_k(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp ); } return BS_LBD.value(); }; xt::xtensor<double, 5> get_BSS()
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
{ if (!BSS.has_value()) { BSS = metrics::calc_BSS( get_bs(), get_q_thr(), get_o_k(), get_bar_o(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_msk, n_exp ); } return BSS.value(); }; xt::xtensor<double, 5> get_QS() { if (!QS.has_value()) { QS = metrics::calc_QS( get_qs(), t_msk, b_exp, n_sit, n_ldt, n_mbr, n_msk, n_exp ); } return QS.value(); }; xt::xtensor<double, 4> get_CRPS() { if (!CRPS.has_value()) { CRPS = metrics::calc_CRPS( get_crps(), t_msk, b_exp, n_sit, n_ldt, n_msk, n_exp ); } return CRPS.value(); }; xt::xtensor<double, 6> get_POD() { if (!POD.has_value()) { POD = metrics::calc_POD( get_pod(), get_q_thr(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp ); } return POD.value(); }; xt::xtensor<double, 6> get_POFD() { if (!POFD.has_value()) { POFD = metrics::calc_POFD( get_pofd(), get_q_thr(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp ); } return POFD.value(); }; xt::xtensor<double, 6> get_FAR() { if (!FAR.has_value()) { FAR = metrics::calc_FAR( get_far(), get_q_thr(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp ); } return FAR.value(); };
491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
xt::xtensor<double, 6> get_CSI() { if (!CSI.has_value()) { CSI = metrics::calc_CSI( get_csi(), get_q_thr(), t_msk, b_exp, n_sit, n_ldt, n_thr, n_mbr, n_msk, n_exp ); } return CSI.value(); }; xt::xtensor<double, 5> get_ROCSS() { if (!ROCSS.has_value()) { ROCSS = metrics::calc_ROCSS( get_POD(), get_POFD(), get_q_thr() ); } return ROCSS.value(); }; }; } } #endif //EVALHYD_PROBABILIST_EVALUATOR_HPP