Commit ad3bdc97 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

fix dimension mismatches in contingency table-based metrics

1 merge request!3release v0.1.0
Pipeline #43305 passed with stage
in 2 minutes and 40 seconds
Showing with 19 additions and 20 deletions
+19 -20
......@@ -45,7 +45,7 @@ namespace evalhyd
// \return
// Alerts based on forecast.
// shape: (levels, thresholds, time)
inline xt::xtensor<double, 2> calc_a_k(
inline xt::xtensor<double, 3> calc_a_k(
const xt::xtensor<double, 2>& sum_f_k,
std::size_t n_mbr
)
......@@ -73,7 +73,7 @@ namespace evalhyd
// shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_a(
const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k
const xt::xtensor<double, 3>& a_k
)
{
return xt::equal(o_k, 1.) && xt::equal(a_k, 1.);
......@@ -92,7 +92,7 @@ namespace evalhyd
// shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_b(
const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k
const xt::xtensor<double, 3>& a_k
)
{
return xt::equal(o_k, 0.) && xt::equal(a_k, 1.);
......@@ -111,7 +111,7 @@ namespace evalhyd
// shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_c(
const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k
const xt::xtensor<double, 3>& a_k
)
{
return xt::equal(o_k, 1.) && xt::equal(a_k, 0.);
......@@ -130,7 +130,7 @@ namespace evalhyd
// shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_d(
const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k
const xt::xtensor<double, 3>& a_k
)
{
return xt::equal(o_k, 0.) && xt::equal(a_k, 0.);
......@@ -241,9 +241,8 @@ namespace evalhyd
)
{
// initialise output variable
// shape: (subsets, thresholds)
xt::xtensor<double, 4> METRIC =
xt::zeros<double>({n_msk, n_exp, n_mbr, n_thr});
xt::zeros<double>({n_msk, n_exp, n_mbr + 1, n_thr});
// compute variable one mask at a time to minimise memory imprint
for (std::size_t m = 0; m < n_msk; m++)
......@@ -475,7 +474,7 @@ namespace evalhyd
// \return
// Critical success indices.
// shape: (subsets, samples, levels, thresholds)
inline xt::xtensor<double, 4> calc_ROCSS(
inline xt::xtensor<double, 3> calc_ROCSS(
const xt::xtensor<double, 4>& POD,
const xt::xtensor<double, 4>& POFD
)
......
......@@ -173,7 +173,7 @@ namespace evalhyd
return a_k.value();
};
xt::xtensor<double, 2> get_ct_a()
xt::xtensor<double, 3> get_ct_a()
{
if (!ct_a.has_value())
{
......@@ -184,7 +184,7 @@ namespace evalhyd
return ct_a.value();
};
xt::xtensor<double, 2> get_ct_b()
xt::xtensor<double, 3> get_ct_b()
{
if (!ct_b.has_value())
{
......@@ -195,7 +195,7 @@ namespace evalhyd
return ct_b.value();
};
xt::xtensor<double, 2> get_ct_c()
xt::xtensor<double, 3> get_ct_c()
{
if (!ct_c.has_value())
{
......@@ -206,7 +206,7 @@ namespace evalhyd
return ct_c.value();
};
xt::xtensor<double, 2> get_ct_d()
xt::xtensor<double, 3> get_ct_d()
{
if (!ct_d.has_value())
{
......@@ -251,7 +251,7 @@ namespace evalhyd
return crps.value();
};
xt::xtensor<double, 4> get_pod()
xt::xtensor<double, 3> get_pod()
{
if (!pod.has_value())
{
......@@ -262,7 +262,7 @@ namespace evalhyd
return pod.value();
};
xt::xtensor<double, 4> get_pofd()
xt::xtensor<double, 3> get_pofd()
{
if (!pofd.has_value())
{
......@@ -273,7 +273,7 @@ namespace evalhyd
return pofd.value();
};
xt::xtensor<double, 4> get_far()
xt::xtensor<double, 3> get_far()
{
if (!far.has_value())
{
......@@ -284,7 +284,7 @@ namespace evalhyd
return far.value();
};
xt::xtensor<double, 4> get_csi()
xt::xtensor<double, 3> get_csi()
{
if (!csi.has_value())
{
......
......@@ -288,10 +288,10 @@ namespace evalhyd
dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3};
dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr};
dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp};
dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr};
dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr};
dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr};
dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr};
dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["ROCSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
// generate masks from conditions if provided
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment