Commit a5c3aad4 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

add multi-variate metric ES

the energy score (ES) is a multi-variate (multi-site)
generalisation of the CRPS
1 merge request!3release v0.1.0
Pipeline #46202 failed with stage
in 5 minutes and 24 seconds
Showing with 249 additions and 2 deletions
+249 -2
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "contingency.hpp" #include "contingency.hpp"
#include "ranks.hpp" #include "ranks.hpp"
#include "intervals.hpp" #include "intervals.hpp"
#include "multivariate.hpp"
namespace evalhyd namespace evalhyd
...@@ -93,6 +94,8 @@ namespace evalhyd ...@@ -93,6 +94,8 @@ namespace evalhyd
xtl::xoptional<xt::xtensor<double, 5>, bool> csi; xtl::xoptional<xt::xtensor<double, 5>, bool> csi;
// > Intervals-based // > Intervals-based
xtl::xoptional<xt::xtensor<double, 4>, bool> ws; xtl::xoptional<xt::xtensor<double, 4>, bool> ws;
// > Multi-variate
xtl::xoptional<xt::xtensor<double, 2>, bool> es;
// members for evaluation metrics // members for evaluation metrics
// > Brier-based // > Brier-based
...@@ -124,6 +127,8 @@ namespace evalhyd ...@@ -124,6 +127,8 @@ namespace evalhyd
xtl::xoptional<xt::xtensor<double, 5>, bool> AWI; xtl::xoptional<xt::xtensor<double, 5>, bool> AWI;
xtl::xoptional<xt::xtensor<double, 5>, bool> WS; xtl::xoptional<xt::xtensor<double, 5>, bool> WS;
xtl::xoptional<xt::xtensor<double, 5>, bool> WSS; xtl::xoptional<xt::xtensor<double, 5>, bool> WSS;
// > Multi-variate
xtl::xoptional<xt::xtensor<double, 4>, bool> ES;
// methods to get optional parameters // methods to get optional parameters
auto get_q_thr() auto get_q_thr()
...@@ -475,6 +480,17 @@ namespace evalhyd ...@@ -475,6 +480,17 @@ namespace evalhyd
return ws.value(); return ws.value();
}; };
// Return the energy score for each time step, computing it on first
// request and reusing the stored value on subsequent calls.
xt::xtensor<double, 2> get_es()
{
    // already computed: hand back the stored value
    if (es.has_value())
    {
        return es.value();
    }

    // first request: compute from the observations and predictions,
    // then store the result in the optional member
    es = intermediate::calc_es(q_obs, q_prd, n_ldt, n_mbr, n_tim);
    return es.value();
};
public: public:
// constructor method // constructor method
Evaluator(const XD2& obs, Evaluator(const XD2& obs,
...@@ -817,6 +833,17 @@ namespace evalhyd ...@@ -817,6 +833,17 @@ namespace evalhyd
return WSS.value(); return WSS.value();
}; };
// Return the energy score (ES) metric, computing it on first request
// from the per-time-step energy score and reusing the stored value on
// subsequent calls.
xt::xtensor<double, 4> get_ES()
{
    // already computed: hand back the stored value
    if (ES.has_value())
    {
        return ES.value();
    }

    // first request: aggregate the per-time-step values (obtained
    // through their own getter) using the temporal masks and the
    // bootstrap samples
    ES = metrics::calc_ES(get_es(), t_msk, b_exp, n_ldt, n_msk, n_exp);
    return ES.value();
};
// methods to compute diagnostics // methods to compute diagnostics
xt::xtensor<double, 4> get_completeness() xt::xtensor<double, 4> get_completeness()
{ {
......
// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.
#ifndef EVALHYD_PROBABILIST_MULTIVARIATE_HPP
#define EVALHYD_PROBABILIST_MULTIVARIATE_HPP
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xmath.hpp>
namespace evalhyd
{
namespace probabilist
{
namespace intermediate
{
/// Compute the energy score (ES) for each time step using its
/// formulation based on expectancies where the ensemble is used as
/// the random variable, i.e.
///
///     ES = 1/m * sum_j || x_j - x ||
///          - 1/(2 m^2) * sum_i sum_j || x_i - x_j ||
///
/// where the norm is the Euclidean norm taken across the first axis
/// (sites), x_j is the jth ensemble member, and x is the observation.
///
/// \param q_obs
///     Streamflow observations.
///     shape: (sites, time)
/// \param q_prd
///     Streamflow predictions.
///     shape: (sites, lead times, members, time)
/// \param n_ldt
///     Number of lead times.
/// \param n_mbr
///     Number of ensemble members.
/// \param n_tim
///     Number of time steps.
/// \return
///     ES for each time step.
///     shape: (lead times, time)
template <class XD2, class XD4>
inline xt::xtensor<double, 2> calc_es(
        const XD2& q_obs,
        const XD4& q_prd,
        std::size_t n_ldt,
        std::size_t n_mbr,
        std::size_t n_tim
)
{
    // notations below follow Gneiting et al. (2008)
    // https://doi.org/10.1007/s11749-008-0114-x

    // initialise internal variables
    // (accumulators for the two sums of norms)
    xt::xtensor<double, 2> es_xj_x =
            xt::zeros<double>({n_ldt, n_tim});
    xt::xtensor<double, 2> es_xi_xj =
            xt::zeros<double>({n_ldt, n_tim});

    for (std::size_t j = 0; j < n_mbr; j++)
    {
        // $\sum_{j=1}^{m} || x_j - x ||$
        es_xj_x += xt::sqrt(
                xt::sum(
                        xt::square(
                                // x_j is the jth member of q_prd
                                xt::view(q_prd, xt::all(), xt::all(),
                                         j, xt::all())
                                // x is q_obs
                                - xt::view(q_obs, xt::all(),
                                           xt::newaxis(), xt::all())
                        ),
                        0
                )
        );

        // $\sum_{j=1}^{m} \sum_{i<j} || x_i - x_j ||$
        // (the norm is symmetric in (i, j) and null for i == j, so
        // only the pairs with i < j are visited; the full double sum
        // is twice this accumulated value)
        for (std::size_t i = 0; i < j; i++)
        {
            es_xi_xj += xt::sqrt(
                    xt::sum(
                            xt::square(
                                    // x_i is the ith member of q_prd
                                    xt::view(q_prd, xt::all(),
                                             xt::all(), i, xt::all())
                                    // x_j is the jth member of q_prd
                                    - xt::view(q_prd, xt::all(),
                                               xt::all(), j, xt::all())
                            ),
                            0
                    )
            );
        }
    }

    // since es_xi_xj holds half of the full double sum, the original
    // factor 1 / (2 m^2) becomes 1 / m^2
    xt::xtensor<double, 2> es = (
            (1. / n_mbr * es_xj_x)
            - (1. / (n_mbr * n_mbr) * es_xi_xj)
    );

    return es;
}
}
namespace metrics
{
/// Compute the energy score (ES), a multi-site generalisation
/// of the continuous rank probability score.
///
/// The per-time-step values are averaged over time for each temporal
/// subset and each bootstrap sample, ignoring (via NaN) the time
/// steps where at least one site is masked.
///
/// \param es
///     ES for each time step.
///     shape: (lead times, time)
/// \param t_msk
///     Temporal subsets of the whole record.
///     shape: (sites, lead times, subsets, time)
/// \param b_exp
///     Bootstrap samples.
///     shape: (samples, time slice)
/// \param n_ldt
///     Number of lead times.
/// \param n_msk
///     Number of temporal subsets.
/// \param n_exp
///     Number of bootstrap samples.
/// \return
///     ES.
///     shape: (1, lead times, subsets, samples)
inline xt::xtensor<double, 4> calc_ES(
        const xt::xtensor<double, 2>& es,
        const xt::xtensor<bool, 4>& t_msk,
        const std::vector<xt::xkeep_slice<int>>& b_exp,
        std::size_t n_ldt,
        std::size_t n_msk,
        std::size_t n_exp
)
{
    // initialise output variable
    // (leading axis of size one stands for the collapsed sites axis)
    xt::xtensor<double, 4> ES =
            xt::zeros<double>({std::size_t {1}, n_ldt, n_msk, n_exp});

    // compute variable one mask at a time to minimise memory imprint
    for (std::size_t m = 0; m < n_msk; m++)
    {
        // determine the multi-site mask (i.e. only retain time
        // steps where no site is masked)
        auto msk = xt::prod(
                xt::view(t_msk, xt::all(), xt::all(), m, xt::all()),
                0
        );

        // apply the mask
        // (using NaN workaround until reducers work on masked_view)
        // the masked values are evaluated into a tensor once per
        // subset so that the reduction over sites and the masking are
        // not re-evaluated lazily for every bootstrap sample
        xt::xtensor<double, 2> es_masked = xt::where(msk, es, NAN);

        // compute variable one sample at a time
        for (std::size_t e = 0; e < n_exp; e++)
        {
            // apply the bootstrap sampling
            auto es_masked_sampled = xt::view(
                    es_masked, xt::all(), b_exp[e]
            );

            // calculate the mean over the time steps
            xt::view(ES, 0, xt::all(), m, e) =
                    xt::nanmean(es_masked_sampled, -1);
        }
    }

    return ES;
}
}
}
}
#endif //EVALHYD_PROBABILIST_MULTIVARIATE_HPP
...@@ -229,7 +229,8 @@ namespace evalhyd ...@@ -229,7 +229,8 @@ namespace evalhyd
"QS", "CRPS_FROM_QS", "QS", "CRPS_FROM_QS",
"POD", "POFD", "FAR", "CSI", "ROCSS", "POD", "POFD", "FAR", "CSI", "ROCSS",
"RANK_HIST", "DS", "AS", "RANK_HIST", "DS", "AS",
"CR", "AW", "AWN", "AWI", "WS", "WSS"} "CR", "AW", "AWN", "AWI", "WS", "WSS",
"ES"}
); );
if ( diagnostics.has_value() ) if ( diagnostics.has_value() )
...@@ -553,6 +554,12 @@ namespace evalhyd ...@@ -553,6 +554,12 @@ namespace evalhyd
uncertainty::summarise_p(evaluator.get_WSS(), summary) uncertainty::summarise_p(evaluator.get_WSS(), summary)
); );
} }
else if ( metric == "ES" )
{
r.emplace_back(
uncertainty::summarise_p(evaluator.get_ES(), summary)
);
}
} }
if ( diagnostics.has_value() ) if ( diagnostics.has_value() )
......
587.2254970444062
...@@ -35,7 +35,8 @@ std::vector<std::string> all_metrics_p = { ...@@ -35,7 +35,8 @@ std::vector<std::string> all_metrics_p = {
"QS", "CRPS_FROM_QS", "QS", "CRPS_FROM_QS",
"POD", "POFD", "FAR", "CSI", "ROCSS", "POD", "POFD", "FAR", "CSI", "ROCSS",
"RANK_HIST", "DS", "AS", "RANK_HIST", "DS", "AS",
"CR", "AW", "AWN", "AWI", "WS", "WSS" "CR", "AW", "AWN", "AWI", "WS", "WSS",
"ES"
}; };
std::tuple<xt::xtensor<double, 1>, xt::xtensor<double, 2>> load_data_p() std::tuple<xt::xtensor<double, 1>, xt::xtensor<double, 2>> load_data_p()
...@@ -285,6 +286,44 @@ TEST(ProbabilistTests, TestIntervals) ...@@ -285,6 +286,44 @@ TEST(ProbabilistTests, TestIntervals)
} }
} }
// Check the multi-variate metric(s) against the expected results, using
// a multi-site dataset built by repeating the loaded data five times
// along a new leading (sites) axis.
TEST(ProbabilistTests, TestMultiVariate)
{
    // read in data
    xt::xtensor<double, 1> observed;
    xt::xtensor<double, 2> predicted;
    std::tie(observed, predicted) = load_data_p();

    // read in expected results
    auto expected = load_expected_p();

    // metrics under test
    std::vector<std::string> metrics = {"ES"};

    // observations
    // shape: (sites [5], time [t])
    xt::xtensor<double, 2> obs = xt::repeat(
            xt::view(observed, xt::newaxis(), xt::all()), 5, 0
    );

    // predictions
    // shape: (sites [5], lead times [1], members [m], time [t])
    xt::xtensor<double, 4> prd = xt::repeat(
            xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()), 5, 0
    );

    // compute scores
    std::vector<xt::xarray<double>> results =
            evalhyd::evalp(obs, prd, metrics);

    // check results
    // (one result is compared per requested metric, in the same order)
    for (std::size_t k = 0; k < metrics.size(); k++)
    {
        EXPECT_TRUE(xt::all(xt::isclose(
                results[k], expected[metrics[k]], 1e-05, 1e-08, true
        ))) << "Failure for (" << metrics[k] << ")";
    }
}
TEST(ProbabilistTests, TestMasks) TEST(ProbabilistTests, TestMasks)
{ {
// read in data // read in data
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment