// NOTE: text captured from a repository web viewer (dkuhlman authored e45e59ee).
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_DETERMINIST_EVALUATOR_HPP
#define EVALHYD_DETERMINIST_EVALUATOR_HPP
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>
#include <xtl/xoptional.hpp>
#include <xtensor/xexpression.hpp>
#include <xtensor/xtensor.hpp>
#include "diagnostics.hpp"
#include "errors.hpp"
#include "efficiencies.hpp"
#include "events.hpp"
namespace evalhyd
{
    namespace determinist
    {
        /// Lazily computes and caches deterministic evaluation elements,
        /// metrics, and diagnostics for a set of observations/predictions.
        ///
        /// Each `get_*` method computes its result on first call, stores it
        /// in the matching optional member, and returns the stored value on
        /// every subsequent call.
        ///
        /// The evaluator keeps *references* to the observations, the
        /// predictions, the thresholds, and the bootstrap samples handed to
        /// its constructor: these objects must outlive the evaluator.
        ///
        /// \tparam XD2  Type of the two-dimensional data containers
        ///              (observations, predictions, thresholds).
        /// \tparam XB3  Type of the three-dimensional boolean mask container.
        template <class XD2, class XB3>
        class Evaluator
        {
        private:
            // members for input data (held by reference, not owned)
            const XD2& q_obs;
            const XD2& q_prd;
            // members for optional input data
            const XD2& _q_thr;
            xtl::xoptional<const std::string, bool> _events;
            xt::xtensor<bool, 3> t_msk;
            const std::vector<xt::xkeep_slice<int>>& b_exp;

            // members for dimensions (all set in the constructor body)
            std::size_t n_tim;
            std::size_t n_msk;
            std::size_t n_srs;
            std::size_t n_thr;
            std::size_t n_exp;

            // members for computational elements
            // (left without value until the matching getter is first called)
            xtl::xoptional<xt::xtensor<double, 3>, bool> t_counts;
            xtl::xoptional<xt::xtensor<double, 4>, bool> mean_obs;
            xtl::xoptional<xt::xtensor<double, 4>, bool> mean_prd;
            xtl::xoptional<xt::xtensor<double, 2>, bool> err;
            xtl::xoptional<xt::xtensor<double, 2>, bool> abs_err;
            xtl::xoptional<xt::xtensor<double, 2>, bool> quad_err;
            xtl::xoptional<xt::xtensor<double, 4>, bool> err_obs;
            xtl::xoptional<xt::xtensor<double, 4>, bool> quad_err_obs;
            xtl::xoptional<xt::xtensor<double, 4>, bool> err_prd;
            xtl::xoptional<xt::xtensor<double, 4>, bool> quad_err_prd;
            xtl::xoptional<xt::xtensor<double, 3>, bool> r_pearson;
            xtl::xoptional<xt::xtensor<double, 3>, bool> r_spearman;
            xtl::xoptional<xt::xtensor<double, 3>, bool> alpha;
            xtl::xoptional<xt::xtensor<double, 3>, bool> gamma;
            xtl::xoptional<xt::xtensor<double, 3>, bool> alpha_np;
            xtl::xoptional<xt::xtensor<double, 3>, bool> bias;
            xtl::xoptional<xt::xtensor<double, 3>, bool> obs_event;
            xtl::xoptional<xt::xtensor<double, 3>, bool> prd_event;
            xtl::xoptional<xt::xtensor<double, 3>, bool> ct_a;
            xtl::xoptional<xt::xtensor<double, 3>, bool> ct_b;
            xtl::xoptional<xt::xtensor<double, 3>, bool> ct_c;
            xtl::xoptional<xt::xtensor<double, 3>, bool> ct_d;

            // members for evaluation metrics
            // (left without value until the matching getter is first called)
            xtl::xoptional<xt::xtensor<double, 3>, bool> MAE;
            xtl::xoptional<xt::xtensor<double, 3>, bool> MARE;
            xtl::xoptional<xt::xtensor<double, 3>, bool> MSE;
            xtl::xoptional<xt::xtensor<double, 3>, bool> RMSE;
            xtl::xoptional<xt::xtensor<double, 3>, bool> NSE;
            xtl::xoptional<xt::xtensor<double, 3>, bool> KGE;
            xtl::xoptional<xt::xtensor<double, 4>, bool> KGE_D;
            xtl::xoptional<xt::xtensor<double, 3>, bool> KGEPRIME;
            xtl::xoptional<xt::xtensor<double, 4>, bool> KGEPRIME_D;
            xtl::xoptional<xt::xtensor<double, 3>, bool> KGENP;
            xtl::xoptional<xt::xtensor<double, 4>, bool> KGENP_D;
            xtl::xoptional<xt::xtensor<double, 5>, bool> CONT_TBL;

            // methods to get optional parameters

            /// Return the thresholds, by reference to avoid copying them on
            /// every call.
            ///
            /// \throws std::runtime_error if the thresholds container is empty.
            const XD2& get_q_thr() const
            {
                if (_q_thr.size() < 1)
                {
                    throw std::runtime_error(
                            "threshold-based metric requested, "
                            "but *q_thr* not provided"
                    );
                }
                return _q_thr;
            }

            /// Tell whether *events* is "high" (true) or "low" (false).
            ///
            /// \throws std::runtime_error if *events* was not provided or
            ///         holds any other value.
            bool is_high_flow_event() const
            {
                if (!_events.has_value())
                {
                    throw std::runtime_error(
                            "threshold-based metric requested, "
                            "but *events* not provided"
                    );
                }

                if (_events.value() == "high")
                {
                    return true;
                }
                if (_events.value() == "low")
                {
                    return false;
                }

                throw std::runtime_error(
                        "invalid value for *events*: " + _events.value()
                );
            }

            // methods to compute elements
            // (each one fills its cache member on first use only)

            /// Cached result of elements::calc_t_counts.
            xt::xtensor<double, 3> get_t_counts()
            {
                if (!t_counts.has_value())
                {
                    t_counts = elements::calc_t_counts(
                            t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return t_counts.value();
            }

            /// Cached result of elements::calc_mean_obs.
            xt::xtensor<double, 4> get_mean_obs()
            {
                if (!mean_obs.has_value())
                {
                    mean_obs = elements::calc_mean_obs(
                            q_obs, t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return mean_obs.value();
            }

            /// Cached result of elements::calc_mean_prd.
            xt::xtensor<double, 4> get_mean_prd()
            {
                if (!mean_prd.has_value())
                {
                    mean_prd = elements::calc_mean_prd(
                            q_prd, t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return mean_prd.value();
            }

            /// Cached result of elements::calc_err on the observations and
            /// the predictions.
            xt::xtensor<double, 2> get_err()
            {
                if (!err.has_value())
                {
                    err = elements::calc_err(
                            q_obs, q_prd
                    );
                }
                return err.value();
            }

            /// Cached result of elements::calc_abs_err, derived from get_err().
            xt::xtensor<double, 2> get_abs_err()
            {
                if (!abs_err.has_value())
                {
                    abs_err = elements::calc_abs_err(
                            get_err()
                    );
                }
                return abs_err.value();
            }

            /// Cached result of elements::calc_quad_err, derived from
            /// get_err().
            xt::xtensor<double, 2> get_quad_err()
            {
                if (!quad_err.has_value())
                {
                    quad_err = elements::calc_quad_err(
                            get_err()
                    );
                }
                return quad_err.value();
            }

            /// Cached result of elements::calc_err_obs, derived from the
            /// observations and get_mean_obs().
            xt::xtensor<double, 4> get_err_obs()
            {
                if (!err_obs.has_value())
                {
                    err_obs = elements::calc_err_obs(
                            q_obs, get_mean_obs(), t_msk,
                            n_srs, n_tim, n_msk, n_exp
                    );
                }
                return err_obs.value();
            }

            /// Cached result of elements::calc_quad_err_obs, derived from
            /// get_err_obs().
            xt::xtensor<double, 4> get_quad_err_obs()
            {
                if (!quad_err_obs.has_value())
                {
                    quad_err_obs = elements::calc_quad_err_obs(
                            get_err_obs()
                    );
                }
                return quad_err_obs.value();
            }

            /// Cached result of elements::calc_err_prd, derived from the
            /// predictions and get_mean_prd().
            xt::xtensor<double, 4> get_err_prd()
            {
                if (!err_prd.has_value())
                {
                    err_prd = elements::calc_err_prd(
                            q_prd, get_mean_prd(), t_msk,
                            n_srs, n_tim, n_msk, n_exp
                    );
                }
                return err_prd.value();
            }

            /// Cached result of elements::calc_quad_err_prd, derived from
            /// get_err_prd().
            xt::xtensor<double, 4> get_quad_err_prd()
            {
                if (!quad_err_prd.has_value())
                {
                    quad_err_prd = elements::calc_quad_err_prd(
                            get_err_prd()
                    );
                }
                return quad_err_prd.value();
            }

            /// Cached result of elements::calc_r_pearson, derived from the
            /// (quadratic) errors to the observed and predicted means.
            xt::xtensor<double, 3> get_r_pearson()
            {
                if (!r_pearson.has_value())
                {
                    r_pearson = elements::calc_r_pearson(
                            get_err_obs(), get_err_prd(),
                            get_quad_err_obs(), get_quad_err_prd(),
                            t_msk, b_exp,
                            n_srs, n_msk, n_exp
                    );
                }
                return r_pearson.value();
            }

            /// Cached result of elements::calc_r_spearman.
            xt::xtensor<double, 3> get_r_spearman()
            {
                if (!r_spearman.has_value())
                {
                    r_spearman = elements::calc_r_spearman(
                            q_obs, q_prd, t_msk, b_exp,
                            n_srs, n_msk, n_exp
                    );
                }
                return r_spearman.value();
            }

            /// Cached result of elements::calc_alpha.
            xt::xtensor<double, 3> get_alpha()
            {
                if (!alpha.has_value())
                {
                    alpha = elements::calc_alpha(
                            q_obs, q_prd, get_mean_obs(), get_mean_prd(),
                            t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return alpha.value();
            }

            /// Cached result of elements::calc_gamma, derived from the means
            /// and get_alpha().
            xt::xtensor<double, 3> get_gamma()
            {
                if (!gamma.has_value())
                {
                    gamma = elements::calc_gamma(
                            get_mean_obs(), get_mean_prd(), get_alpha()
                    );
                }
                return gamma.value();
            }

            /// Cached result of elements::calc_alpha_np.
            xt::xtensor<double, 3> get_alpha_np()
            {
                if (!alpha_np.has_value())
                {
                    alpha_np = elements::calc_alpha_np(
                            q_obs, q_prd, get_mean_obs(), get_mean_prd(),
                            t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return alpha_np.value();
            }

            /// Cached result of elements::calc_bias.
            xt::xtensor<double, 3> get_bias()
            {
                if (!bias.has_value())
                {
                    bias = elements::calc_bias(
                            q_obs, q_prd, t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return bias.value();
            }

            /// Cached result of elements::calc_obs_event.
            ///
            /// \throws std::runtime_error if *q_thr* or *events* is missing
            ///         or invalid (see get_q_thr / is_high_flow_event).
            xt::xtensor<double, 3> get_obs_event()
            {
                if (!obs_event.has_value())
                {
                    obs_event = elements::calc_obs_event(
                            q_obs, get_q_thr(), is_high_flow_event()
                    );
                }
                return obs_event.value();
            }

            /// Cached result of elements::calc_prd_event.
            ///
            /// \throws std::runtime_error if *q_thr* or *events* is missing
            ///         or invalid (see get_q_thr / is_high_flow_event).
            xt::xtensor<double, 3> get_prd_event()
            {
                if (!prd_event.has_value())
                {
                    prd_event = elements::calc_prd_event(
                            q_prd, get_q_thr(), is_high_flow_event()
                    );
                }
                return prd_event.value();
            }

            /// Cached result of elements::calc_ct_a, derived from the
            /// observed and predicted events.
            xt::xtensor<double, 3> get_ct_a()
            {
                if (!ct_a.has_value())
                {
                    ct_a = elements::calc_ct_a(
                            get_obs_event(), get_prd_event()
                    );
                }
                return ct_a.value();
            }

            /// Cached result of elements::calc_ct_b, derived from the
            /// observed and predicted events.
            xt::xtensor<double, 3> get_ct_b()
            {
                if (!ct_b.has_value())
                {
                    ct_b = elements::calc_ct_b(
                            get_obs_event(), get_prd_event()
                    );
                }
                return ct_b.value();
            }

            /// Cached result of elements::calc_ct_c, derived from the
            /// observed and predicted events.
            xt::xtensor<double, 3> get_ct_c()
            {
                if (!ct_c.has_value())
                {
                    ct_c = elements::calc_ct_c(
                            get_obs_event(), get_prd_event()
                    );
                }
                return ct_c.value();
            }

            /// Cached result of elements::calc_ct_d, derived from the
            /// observed and predicted events.
            xt::xtensor<double, 3> get_ct_d()
            {
                if (!ct_d.has_value())
                {
                    ct_d = elements::calc_ct_d(
                            get_obs_event(), get_prd_event()
                    );
                }
                return ct_d.value();
            }

        public:
            // constructor method

            /// Bind the input data and prepare the temporal mask.
            ///
            /// \param obs     Observations (kept by reference).
            /// \param prd     Predictions (kept by reference); its first and
            ///                second dimensions give the number of series
            ///                and of time steps, respectively.
            /// \param thr     Thresholds (kept by reference); may be empty if
            ///                no threshold-based metric is requested.
            /// \param events  Optional event type, "high" or "low" (copied).
            /// \param msk     Temporal mask; if empty, a single all-true
            ///                subset is used instead.
            /// \param exp     Bootstrap samples (kept by reference).
            Evaluator(const XD2& obs,
                      const XD2& prd,
                      const XD2& thr,
                      xtl::xoptional<const std::string&, bool> events,
                      const XB3& msk,
                      const std::vector<xt::xkeep_slice<int>>& exp) :
                    q_obs{obs}, q_prd{prd},
                    _q_thr{thr}, _events{events},
                    t_msk{msk}, b_exp{exp}
            {
                // initialise a mask if none provided
                // (corresponding to no temporal subset)
                if (msk.size() < 1)
                {
                    t_msk = xt::ones<bool>(
                            {q_prd.shape(0), std::size_t {1}, q_prd.shape(1)}
                    );
                }

                // determine size for recurring dimensions
                n_srs = q_prd.shape(0);
                n_tim = q_prd.shape(1);
                n_msk = t_msk.shape(1);
                n_thr = _q_thr.shape(1);
                n_exp = b_exp.size();

                // drop time steps where observations or predictions are NaN
                // NOTE(review): only the first row of the observations is
                // inspected for every series; presumably observations hold
                // a single series -- TODO confirm against callers
                for (std::size_t s = 0; s < n_srs; ++s)
                {
                    auto obs_nan = xt::isnan(xt::view(q_obs, 0));
                    auto prd_nan = xt::isnan(xt::view(q_prd, s));

                    // indices of the time steps to exclude for series *s*
                    auto msk_nan = xt::where(obs_nan || prd_nan)[0];

                    xt::view(t_msk, s, xt::all(), xt::keep(msk_nan)) = false;
                }
            }

            // methods to compute metrics
            // (each one fills its cache member on first use only)

            /// Cached result of metrics::calc_MAE, derived from get_abs_err().
            xt::xtensor<double, 3> get_MAE()
            {
                if (!MAE.has_value())
                {
                    MAE = metrics::calc_MAE(
                            get_abs_err(), t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return MAE.value();
            }

            /// Cached result of metrics::calc_MARE, derived from get_MAE()
            /// and get_mean_obs().
            xt::xtensor<double, 3> get_MARE()
            {
                if (!MARE.has_value())
                {
                    MARE = metrics::calc_MARE(
                            get_MAE(), get_mean_obs(), n_srs, n_msk, n_exp
                    );
                }
                return MARE.value();
            }

            /// Cached result of metrics::calc_MSE, derived from
            /// get_quad_err().
            xt::xtensor<double, 3> get_MSE()
            {
                if (!MSE.has_value())
                {
                    MSE = metrics::calc_MSE(
                            get_quad_err(), t_msk, b_exp, n_srs, n_msk, n_exp
                    );
                }
                return MSE.value();
            }

            /// Cached result of metrics::calc_RMSE, derived from get_MSE().
            xt::xtensor<double, 3> get_RMSE()
            {
                if (!RMSE.has_value())
                {
                    RMSE = metrics::calc_RMSE(
                            get_MSE()
                    );
                }
                return RMSE.value();
            }

            /// Cached result of metrics::calc_NSE, derived from
            /// get_quad_err() and get_quad_err_obs().
            xt::xtensor<double, 3> get_NSE()
            {
                if (!NSE.has_value())
                {
                    NSE = metrics::calc_NSE(
                            get_quad_err(), get_quad_err_obs(), t_msk, b_exp,
                            n_srs, n_msk, n_exp
                    );
                }
                return NSE.value();
            }

            /// Cached result of metrics::calc_KGE, derived from
            /// get_r_pearson(), get_alpha(), and get_bias().
            xt::xtensor<double, 3> get_KGE()
            {
                if (!KGE.has_value())
                {
                    KGE = metrics::calc_KGE(
                            get_r_pearson(), get_alpha(), get_bias(),
                            n_srs, n_msk, n_exp
                    );
                }
                return KGE.value();
            }

            /// Cached result of metrics::calc_KGE_D, derived from
            /// get_r_pearson(), get_alpha(), and get_bias().
            xt::xtensor<double, 4> get_KGE_D()
            {
                if (!KGE_D.has_value())
                {
                    KGE_D = metrics::calc_KGE_D(
                            get_r_pearson(), get_alpha(), get_bias(),
                            n_srs, n_msk, n_exp
                    );
                }
                return KGE_D.value();
            }

            /// Cached result of metrics::calc_KGEPRIME, derived from
            /// get_r_pearson(), get_gamma(), and get_bias().
            xt::xtensor<double, 3> get_KGEPRIME()
            {
                if (!KGEPRIME.has_value())
                {
                    KGEPRIME = metrics::calc_KGEPRIME(
                            get_r_pearson(), get_gamma(), get_bias(),
                            n_srs, n_msk, n_exp
                    );
                }
                return KGEPRIME.value();
            }

            /// Cached result of metrics::calc_KGEPRIME_D, derived from
            /// get_r_pearson(), get_gamma(), and get_bias().
            xt::xtensor<double, 4> get_KGEPRIME_D()
            {
                if (!KGEPRIME_D.has_value())
                {
                    KGEPRIME_D = metrics::calc_KGEPRIME_D(
                            get_r_pearson(), get_gamma(), get_bias(),
                            n_srs, n_msk, n_exp
                    );
                }
                return KGEPRIME_D.value();
            }

            /// Cached result of metrics::calc_KGENP, derived from
            /// get_r_spearman(), get_alpha_np(), and get_bias().
            xt::xtensor<double, 3> get_KGENP()
            {
                if (!KGENP.has_value())
                {
                    KGENP = metrics::calc_KGENP(
                            get_r_spearman(), get_alpha_np(), get_bias(),
                            n_srs, n_msk, n_exp
                    );
                }
                return KGENP.value();
            }

            /// Cached result of metrics::calc_KGENP_D, derived from
            /// get_r_spearman(), get_alpha_np(), and get_bias().
            xt::xtensor<double, 4> get_KGENP_D()
            {
                if (!KGENP_D.has_value())
                {
                    KGENP_D = metrics::calc_KGENP_D(
                            get_r_spearman(), get_alpha_np(), get_bias(),
                            n_srs, n_msk, n_exp
                    );
                }
                return KGENP_D.value();
            }

            /// Cached result of metrics::calc_CONT_TBL, derived from the
            /// thresholds and the four ct_* elements.
            ///
            /// \throws std::runtime_error if *q_thr* or *events* is missing
            ///         or invalid (see get_q_thr / is_high_flow_event).
            xt::xtensor<double, 5> get_CONT_TBL()
            {
                if (!CONT_TBL.has_value())
                {
                    CONT_TBL = metrics::calc_CONT_TBL(
                            get_q_thr(), get_ct_a(), get_ct_b(), get_ct_c(),
                            get_ct_d(), t_msk, b_exp,
                            n_srs, n_thr, n_msk, n_exp
                    );
                }
                return CONT_TBL.value();
            }

            // methods to compute diagnostics

            /// Return the completeness diagnostic, i.e. the cached result of
            /// get_t_counts().
            xt::xtensor<double, 3> get_completeness()
            {
                return get_t_counts();
            }
        };
    }
}
#endif //EVALHYD_DETERMINIST_EVALUATOR_HPP