Commit ad3bdc97 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

fix dimension mismatches in contingency table-based metrics

1 merge request!3release v0.1.0
Pipeline #43305 passed with stage
in 2 minutes and 40 seconds
Showing with 19 additions and 20 deletions
+19 -20
...@@ -45,7 +45,7 @@ namespace evalhyd ...@@ -45,7 +45,7 @@ namespace evalhyd
// \return // \return
// Alerts based on forecast. // Alerts based on forecast.
// shape: (levels, thresholds, time) // shape: (levels, thresholds, time)
inline xt::xtensor<double, 2> calc_a_k( inline xt::xtensor<double, 3> calc_a_k(
const xt::xtensor<double, 2>& sum_f_k, const xt::xtensor<double, 2>& sum_f_k,
std::size_t n_mbr std::size_t n_mbr
) )
...@@ -73,7 +73,7 @@ namespace evalhyd ...@@ -73,7 +73,7 @@ namespace evalhyd
// shape: (levels, thresholds, time) // shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_a( inline xt::xtensor<double, 3> calc_ct_a(
const xt::xtensor<double, 2>& o_k, const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k const xt::xtensor<double, 3>& a_k
) )
{ {
return xt::equal(o_k, 1.) && xt::equal(a_k, 1.); return xt::equal(o_k, 1.) && xt::equal(a_k, 1.);
...@@ -92,7 +92,7 @@ namespace evalhyd ...@@ -92,7 +92,7 @@ namespace evalhyd
// shape: (levels, thresholds, time) // shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_b( inline xt::xtensor<double, 3> calc_ct_b(
const xt::xtensor<double, 2>& o_k, const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k const xt::xtensor<double, 3>& a_k
) )
{ {
return xt::equal(o_k, 0.) && xt::equal(a_k, 1.); return xt::equal(o_k, 0.) && xt::equal(a_k, 1.);
...@@ -111,7 +111,7 @@ namespace evalhyd ...@@ -111,7 +111,7 @@ namespace evalhyd
// shape: (levels, thresholds, time) // shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_c( inline xt::xtensor<double, 3> calc_ct_c(
const xt::xtensor<double, 2>& o_k, const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k const xt::xtensor<double, 3>& a_k
) )
{ {
return xt::equal(o_k, 1.) && xt::equal(a_k, 0.); return xt::equal(o_k, 1.) && xt::equal(a_k, 0.);
...@@ -130,7 +130,7 @@ namespace evalhyd ...@@ -130,7 +130,7 @@ namespace evalhyd
// shape: (levels, thresholds, time) // shape: (levels, thresholds, time)
inline xt::xtensor<double, 3> calc_ct_d( inline xt::xtensor<double, 3> calc_ct_d(
const xt::xtensor<double, 2>& o_k, const xt::xtensor<double, 2>& o_k,
const xt::xtensor<double, 2>& a_k const xt::xtensor<double, 3>& a_k
) )
{ {
return xt::equal(o_k, 0.) && xt::equal(a_k, 0.); return xt::equal(o_k, 0.) && xt::equal(a_k, 0.);
...@@ -241,9 +241,8 @@ namespace evalhyd ...@@ -241,9 +241,8 @@ namespace evalhyd
) )
{ {
// initialise output variable // initialise output variable
// shape: (subsets, thresholds)
xt::xtensor<double, 4> METRIC = xt::xtensor<double, 4> METRIC =
xt::zeros<double>({n_msk, n_exp, n_mbr, n_thr}); xt::zeros<double>({n_msk, n_exp, n_mbr + 1, n_thr});
// compute variable one mask at a time to minimise memory imprint // compute variable one mask at a time to minimise memory imprint
for (std::size_t m = 0; m < n_msk; m++) for (std::size_t m = 0; m < n_msk; m++)
...@@ -475,7 +474,7 @@ namespace evalhyd ...@@ -475,7 +474,7 @@ namespace evalhyd
// \return // \return
// Critical success indices. // Critical success indices.
// shape: (subsets, samples, levels, thresholds) // shape: (subsets, samples, levels, thresholds)
inline xt::xtensor<double, 4> calc_ROCSS( inline xt::xtensor<double, 3> calc_ROCSS(
const xt::xtensor<double, 4>& POD, const xt::xtensor<double, 4>& POD,
const xt::xtensor<double, 4>& POFD const xt::xtensor<double, 4>& POFD
) )
......
...@@ -173,7 +173,7 @@ namespace evalhyd ...@@ -173,7 +173,7 @@ namespace evalhyd
return a_k.value(); return a_k.value();
}; };
xt::xtensor<double, 2> get_ct_a() xt::xtensor<double, 3> get_ct_a()
{ {
if (!ct_a.has_value()) if (!ct_a.has_value())
{ {
...@@ -184,7 +184,7 @@ namespace evalhyd ...@@ -184,7 +184,7 @@ namespace evalhyd
return ct_a.value(); return ct_a.value();
}; };
xt::xtensor<double, 2> get_ct_b() xt::xtensor<double, 3> get_ct_b()
{ {
if (!ct_b.has_value()) if (!ct_b.has_value())
{ {
...@@ -195,7 +195,7 @@ namespace evalhyd ...@@ -195,7 +195,7 @@ namespace evalhyd
return ct_b.value(); return ct_b.value();
}; };
xt::xtensor<double, 2> get_ct_c() xt::xtensor<double, 3> get_ct_c()
{ {
if (!ct_c.has_value()) if (!ct_c.has_value())
{ {
...@@ -206,7 +206,7 @@ namespace evalhyd ...@@ -206,7 +206,7 @@ namespace evalhyd
return ct_c.value(); return ct_c.value();
}; };
xt::xtensor<double, 2> get_ct_d() xt::xtensor<double, 3> get_ct_d()
{ {
if (!ct_d.has_value()) if (!ct_d.has_value())
{ {
...@@ -251,7 +251,7 @@ namespace evalhyd ...@@ -251,7 +251,7 @@ namespace evalhyd
return crps.value(); return crps.value();
}; };
xt::xtensor<double, 4> get_pod() xt::xtensor<double, 3> get_pod()
{ {
if (!pod.has_value()) if (!pod.has_value())
{ {
...@@ -262,7 +262,7 @@ namespace evalhyd ...@@ -262,7 +262,7 @@ namespace evalhyd
return pod.value(); return pod.value();
}; };
xt::xtensor<double, 4> get_pofd() xt::xtensor<double, 3> get_pofd()
{ {
if (!pofd.has_value()) if (!pofd.has_value())
{ {
...@@ -273,7 +273,7 @@ namespace evalhyd ...@@ -273,7 +273,7 @@ namespace evalhyd
return pofd.value(); return pofd.value();
}; };
xt::xtensor<double, 4> get_far() xt::xtensor<double, 3> get_far()
{ {
if (!far.has_value()) if (!far.has_value())
{ {
...@@ -284,7 +284,7 @@ namespace evalhyd ...@@ -284,7 +284,7 @@ namespace evalhyd
return far.value(); return far.value();
}; };
xt::xtensor<double, 4> get_csi() xt::xtensor<double, 3> get_csi()
{ {
if (!csi.has_value()) if (!csi.has_value())
{ {
......
...@@ -288,10 +288,10 @@ namespace evalhyd ...@@ -288,10 +288,10 @@ namespace evalhyd
dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3}; dim["BS_LBD"] = {n_sit, n_ltm, n_msk, n_exp, n_thr, 3};
dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr}; dim["QS"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr};
dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp}; dim["CRPS"] = {n_sit, n_ltm, n_msk, n_exp};
dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; dim["POD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; dim["POFD"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; dim["FAR"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr, n_thr}; dim["CSI"] = {n_sit, n_ltm, n_msk, n_exp, n_mbr + 1, n_thr};
dim["ROCSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr}; dim["ROCSS"] = {n_sit, n_ltm, n_msk, n_exp, n_thr};
// generate masks from conditions if provided // generate masks from conditions if provided
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment