From e8e5b332026b49f824d16ec638914172e0de8673 Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Mon, 23 Jan 2023 15:10:30 +0100
Subject: [PATCH] fix broadcasting problems with multi-sites/multi-leadtimes

---
 include/evalhyd/detail/probabilist/ranks.hpp | 34 +++++++++++++-------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/include/evalhyd/detail/probabilist/ranks.hpp b/include/evalhyd/detail/probabilist/ranks.hpp
index 376a80c..a03f4d0 100644
--- a/include/evalhyd/detail/probabilist/ranks.hpp
+++ b/include/evalhyd/detail/probabilist/ranks.hpp
@@ -107,7 +107,8 @@ namespace evalhyd
                 return ranks;
             }
 
-            /// Compute the number of observations for all possible rank values.
+            /// Compute the number of observations in each interval of the
+            /// rank diagram.
             ///
             /// \param r_k
             ///     Ranks of streamflow observations.
@@ -129,7 +130,7 @@ namespace evalhyd
             /// \param n_exp
             ///     Number of bootstrap samples.
             /// \return
-            ///     Tallies of streamflow observations for all possible ranks.
+            ///     Tallies of streamflow observations in each rank interval.
             ///     shape: (sites, lead times, subsets, samples, ranks)
             inline xt::xtensor<double, 5> calc_o_j(
                     const xt::xtensor<double, 3>& r_k,
@@ -235,7 +236,9 @@ namespace evalhyd
                                          m, b_exp[e]);
 
                         // calculate length of subset
-                        auto l = xt::sum(t_msk_sampled, -1);
+                        auto l = xt::eval(
+                                xt::sum(t_msk_sampled, -1, xt::keep_dims)
+                        );
 
                         // compute the rank diagram
                         xt::view(REL_DIAG, xt::all(), xt::all(), m, e, xt::all()) =
@@ -300,20 +303,29 @@ namespace evalhyd
                                          m, b_exp[e]);
 
                         // calculate length of subset
-                        auto l = xt::sum(t_msk_sampled, -1);
+                        auto l = xt::eval(
+                                xt::sum(t_msk_sampled, -1, xt::keep_dims)
+                        );
 
                         // compute the Delta score
                         // \Delta = \sum_{k=1}^{N+1} (r_k - \frac{M}{N+1})^2
+                        auto delta =  xt::nansum(
+                                xt::square(
+                                        xt::view(o_j, xt::all(), xt::all(), m, e, xt::all())
+                                        - (l / (n_mbr + 1))
+                                ),
+                                -1
+                        );
+
                         // \Delta_o = \frac{MN}{N+1}
+                        auto delta_o = (
+                                xt::view(l, xt::all(), xt::all(), 0)
+                                * n_mbr / (n_mbr + 1)
+                        );
+
                         // \delta = $\frac{\Delta}{\Delta_o}
                         xt::view(DS, xt::all(), xt::all(), m, e) =
-                                xt::nansum(
-                                        xt::square(
-                                                xt::view(o_j, xt::all(), xt::all(), m, e, xt::all())
-                                                - (l / (n_mbr + 1))
-                                        ),
-                                        -1
-                                ) / (l * n_mbr / (n_mbr + 1));
+                               delta / delta_o;
                     }
                 }
 
-- 
GitLab