// evalp.cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <xtensor/xtensor.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xview.hpp>

#include "evalhyd/evalp.hpp"
#include "utils.hpp"
#include "masks.hpp"
#include "maths.hpp"
#include "uncertainty.hpp"
#include "probabilist/evaluator.hpp"
namespace eh = evalhyd;
namespace evalhyd
{
    /// Evaluate probabilistic streamflow predictions against observations.
    ///
    /// Validates the inputs, builds the evaluation masks and the bootstrap
    /// experiment, then computes the requested metrics one site and one
    /// lead time at a time (to limit the memory footprint).
    ///
    /// @param q_obs      Observed streamflow, indexed (sites, time).
    /// @param q_prd      Predicted streamflow, indexed
    ///                   (sites, lead times, members, time).
    /// @param metrics    Names of the metrics to compute, among
    ///                   "BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS".
    /// @param q_thr      Streamflow thresholds, indexed (sites, thresholds);
    ///                   may be empty (presumably only allowed when no
    ///                   threshold-based metric is requested, as checked by
    ///                   `check_optionals` -- confirm).
    /// @param t_msk      Temporal masks, indexed
    ///                   (sites, lead times, subsets, time); may be empty.
    /// @param m_cdt      Masking conditions, indexed (sites, subsets); only
    ///                   used when `t_msk` is empty; may be empty.
    /// @param bootstrap  Bootstrap parameters, read through the keys
    ///                   "n_samples", "len_sample" and "summary"; a value of
    ///                   -9 for "n_samples" or "len_sample" disables
    ///                   bootstrapping.
    /// @param dts        Datetimes, one per time step; required when
    ///                   bootstrapping.
    /// @return One array per metric, in the order of `metrics`, shaped as
    ///         registered in `dim` below.
    /// @throws std::runtime_error if input dimensions are incompatible, or
    ///         if bootstrapping is requested without datetimes.
    /// @throws std::out_of_range if `bootstrap` lacks one of the keys read
    ///         here (unless `check_bootstrap` rejected it beforehand).
    std::vector<xt::xarray<double>> evalp(
            const xt::xtensor<double, 2>& q_obs,
            const xt::xtensor<double, 4>& q_prd,
            const std::vector<std::string>& metrics,
            const xt::xtensor<double, 2>& q_thr,
            const xt::xtensor<bool, 4>& t_msk,
            const xt::xtensor<std::array<char, 32>, 2>& m_cdt,
            const std::unordered_map<std::string, int>& bootstrap,
            const std::vector<std::string>& dts
    )
    {
        // check that the metrics to be computed are valid
        eh::utils::check_metrics(
                metrics,
                {"BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS"}
        );

        // check that optional parameters are given as arguments
        eh::utils::evalp::check_optionals(metrics, q_thr);
        eh::utils::check_bootstrap(bootstrap);

        // check that data dimensions are compatible
        // > time
        if (q_obs.shape(1) != q_prd.shape(3))
            throw std::runtime_error(
                    "observations and predictions feature different "
                    "temporal lengths"
            );
        if ((t_msk.size() > 0) && (q_obs.shape(1) != t_msk.shape(3)))
            throw std::runtime_error(
                    "observations and masks feature different "
                    "temporal lengths"
            );
        if (!dts.empty() && (q_obs.shape(1) != dts.size()))
            throw std::runtime_error(
                    "observations and datetimes feature different "
                    "temporal lengths"
            );
        // > leadtimes
        if ((t_msk.size() > 0) && (q_prd.shape(1) != t_msk.shape(1)))
            throw std::runtime_error(
                    "predictions and temporal masks feature different "
                    "numbers of lead times"
            );
        // > sites
        if (q_obs.shape(0) != q_prd.shape(0))
            throw std::runtime_error(
                    "observations and predictions feature different "
                    "numbers of sites"
            );
        if ((t_msk.size() > 0) && (q_obs.shape(0) != t_msk.shape(0)))
            throw std::runtime_error(
                    "observations and temporal masks feature different "
                    "numbers of sites"
            );
        if ((m_cdt.size() > 0) && (q_obs.shape(0) != m_cdt.shape(0)))
            throw std::runtime_error(
                    "observations and masking conditions feature different "
                    "numbers of sites"
            );
        // thresholds are viewed per site below, so a mismatch would read
        // out of bounds
        if ((q_thr.size() > 0) && (q_obs.shape(0) != q_thr.shape(0)))
            throw std::runtime_error(
                    "observations and thresholds feature different "
                    "numbers of sites"
            );

        // retrieve dimensions
        const std::size_t n_sit = q_prd.shape(0);
        const std::size_t n_ltm = q_prd.shape(1);
        const std::size_t n_mbr = q_prd.shape(2);
        const std::size_t n_tim = q_prd.shape(3);
        const std::size_t n_thr = q_thr.shape(1);
        // temporal masks take precedence over masking conditions
        const std::size_t n_msk =
                t_msk.size() > 0 ? t_msk.shape(2) :
                (m_cdt.size() > 0 ? m_cdt.shape(1) : 1);

        // retrieve bootstrap parameters (-9 standing for "not set");
        // `at` turns a missing key into an exception rather than UB
        const int n_samples = bootstrap.at("n_samples");
        const int len_sample = bootstrap.at("len_sample");
        const int summary = bootstrap.at("summary");
        // samples are only drawn when both parameters are set, so the
        // number of experiments must follow the very same rule, otherwise
        // outputs would be sized for samples that are never drawn
        const bool bootstrapping = (n_samples != -9) && (len_sample != -9);
        const std::size_t n_exp =
                bootstrapping ? static_cast<std::size_t>(n_samples) : 1;

        // register metrics number of dimensions
        std::unordered_map<std::string, std::vector<std::size_t>> dim;
        dim["BS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
        dim["BSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
        dim["BS_CRD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3};
        dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3};
        dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr};
        dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp};

        // declare maps for memoisation purposes
        std::unordered_map<std::string, std::vector<std::string>> elt;
        std::unordered_map<std::string, std::vector<std::string>> dep;

        // register potentially recurring computation elements across metrics
        elt["bs"] = {"o_k", "y_k"};
        elt["BSS"] = {"o_k", "bar_o"};
        elt["BS_CRD"] = {"o_k", "bar_o", "y_k"};
        elt["BS_LBD"] = {"o_k", "y_k"};
        elt["qs"] = {"q_qnt"};

        // register nested metrics (i.e. metric dependent on another metric)
        dep["BS"] = {"bs"};
        dep["BSS"] = {"bs"};
        dep["QS"] = {"qs"};
        dep["CRPS"] = {"qs", "crps"};

        // determine required elt/dep to be pre-computed
        std::vector<std::string> req_elt;
        std::vector<std::string> req_dep;

        eh::utils::find_requirements(metrics, elt, dep, req_elt, req_dep);

        // generate masks from conditions, but only when they will actually
        // be read, i.e. when no temporal masks are given: otherwise `n_msk`
        // follows `t_msk` and iterating over it would index `m_cdt` out of
        // bounds (and the masks would be discarded anyway)
        const xt::xtensor<bool, 4> c_msk = [&]()
        {
            xt::xtensor<bool, 4> msk;
            if ((t_msk.size() == 0) && (m_cdt.size() > 0))
            {
                msk = xt::zeros<bool>({n_sit, n_ltm, n_msk, n_tim});
                for (std::size_t s = 0; s < n_sit; s++)
                    for (std::size_t l = 0; l < n_ltm; l++)
                        for (std::size_t m = 0; m < n_msk; m++)
                            xt::view(msk, s, l, m) =
                                    eh::masks::generate_mask_from_conditions(
                                            xt::view(m_cdt, s, m),
                                            xt::view(q_obs, s),
                                            xt::view(q_prd, s, l)
                                    );
            }
            return msk;
        }();

        // generate bootstrap experiment if requested
        std::vector<xt::xkeep_slice<int>> b_exp;
        if (bootstrapping)
        {
            if (dts.empty())
                throw std::runtime_error(
                        "bootstrap requested but datetimes not provided"
                );

            b_exp = eh::uncertainty::bootstrap(
                    dts, n_samples, len_sample
            );
        }
        else
        {
            // if no bootstrap requested, generate one sample
            // containing all the time indices once
            xt::xtensor<int, 1> all = xt::arange(n_tim);
            b_exp.push_back(xt::keep(all));
        }

        // initialise data structure for outputs
        std::vector<xt::xarray<double>> r;
        r.reserve(metrics.size());
        for (const auto& metric : metrics)
            r.emplace_back(xt::zeros<double>(dim.at(metric)));

        // whether `name` was already computed as a dependency, in which
        // case it is not recomputed; NOTE(review): `dep` above only lists
        // lower-case names, so presumably this never matches an upper-case
        // metric name today -- confirm intent
        const auto precomputed = [&req_dep](const std::string& name)
        {
            return std::find(req_dep.begin(), req_dep.end(), name)
                   != req_dep.end();
        };

        // compute variables one site at a time to minimise memory imprint
        for (std::size_t s = 0; s < n_sit; s++)
        {
            // compute variables one lead time at a time to minimise
            // memory imprint
            for (std::size_t l = 0; l < n_ltm; l++)
            {
                // instantiate probabilist evaluator
                const auto q_obs_v = xt::view(q_obs, s, xt::all());
                const auto q_prd_v =
                        xt::view(q_prd, s, l, xt::all(), xt::all());
                const auto q_thr_v = xt::view(q_thr, s, xt::all());
                // temporal masks take precedence over masking conditions;
                // when neither is given, the (empty) temporal masks are
                // handed over unchanged
                const auto t_msk_v =
                        t_msk.size() > 0 ?
                        xt::view(t_msk, s, l, xt::all(), xt::all()) :
                        (m_cdt.size() > 0 ?
                         xt::view(c_msk, s, l, xt::all(), xt::all()) :
                         xt::view(t_msk, s, l, xt::all(), xt::all()));

                eh::probabilist::Evaluator evaluator(
                        q_obs_v, q_prd_v, q_thr_v, t_msk_v, b_exp
                );

                // pre-compute required elt
                for (const auto& element : req_elt)
                {
                    if ( element == "o_k" )
                        evaluator.calc_o_k();
                    else if ( element == "bar_o" )
                        evaluator.calc_bar_o();
                    else if ( element == "y_k" )
                        evaluator.calc_y_k();
                    else if ( element == "q_qnt" )
                        evaluator.calc_q_qnt();
                }

                // pre-compute required dep
                for (const auto& dependency : req_dep)
                {
                    if ( dependency == "bs" )
                        evaluator.calc_bs();
                    else if ( dependency == "qs" )
                        evaluator.calc_qs();
                    else if ( dependency == "crps" )
                        evaluator.calc_crps();
                }

                // retrieve or compute requested metrics
                for (std::size_t m = 0; m < metrics.size(); m++)
                {
                    const auto& metric = metrics[m];

                    if ( metric == "BS" )
                    {
                        if (!precomputed(metric))
                            evaluator.calc_BS();
                        // (sites, lead times, subsets, samples, thresholds)
                        xt::view(r[m], s, l, xt::all(), xt::all(), xt::all()) =
                                eh::uncertainty::summarise(evaluator.BS, summary);
                    }
                    else if ( metric == "BSS" )
                    {
                        if (!precomputed(metric))
                            evaluator.calc_BSS();
                        // (sites, lead times, subsets, samples, thresholds)
                        xt::view(r[m], s, l, xt::all(), xt::all(), xt::all()) =
                                eh::uncertainty::summarise(evaluator.BSS, summary);
                    }
                    else if ( metric == "BS_CRD" )
                    {
                        if (!precomputed(metric))
                            evaluator.calc_BS_CRD();
                        // (sites, lead times, subsets, samples, thresholds,
                        //  components)
                        xt::view(r[m], s, l, xt::all(), xt::all(), xt::all(),
                                 xt::all()) =
                                eh::uncertainty::summarise(evaluator.BS_CRD, summary);
                    }
                    else if ( metric == "BS_LBD" )
                    {
                        if (!precomputed(metric))
                            evaluator.calc_BS_LBD();
                        // (sites, lead times, subsets, samples, thresholds,
                        //  components)
                        xt::view(r[m], s, l, xt::all(), xt::all(), xt::all(),
                                 xt::all()) =
                                eh::uncertainty::summarise(evaluator.BS_LBD, summary);
                    }
                    else if ( metric == "QS" )
                    {
                        if (!precomputed(metric))
                            evaluator.calc_QS();
                        // (sites, lead times, subsets, samples, quantiles)
                        xt::view(r[m], s, l, xt::all(), xt::all(), xt::all()) =
                                eh::uncertainty::summarise(evaluator.QS, summary);
                    }
                    else if ( metric == "CRPS" )
                    {
                        if (!precomputed(metric))
                            evaluator.calc_CRPS();
                        // (sites, lead times, subsets, samples)
                        xt::view(r[m], s, l, xt::all(), xt::all()) =
                                eh::uncertainty::summarise(evaluator.CRPS, summary);
                    }
                }
            }
        }

        return r;
    }
}