From 84121e1547d8d48ebd4b5a8c5f2b1de443bce370 Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Wed, 15 Feb 2023 09:19:33 +0100
Subject: [PATCH] fix bug with multi-dimensional inputs for KGENP KGENP_D

because `xt::sort` does not behave nicely with series containing NaN
values (e.g. it does not consistently put NaN values at the end/
beginning), so need to eliminate them before the sorting, which requires
to treat each series separately
---
 .../detail/determinist/efficiencies.hpp       | 111 ++++++++++--------
 .../evalhyd/detail/determinist/evaluator.hpp  |   2 +-
 2 files changed, 65 insertions(+), 48 deletions(-)

diff --git a/include/evalhyd/detail/determinist/efficiencies.hpp b/include/evalhyd/detail/determinist/efficiencies.hpp
index 93281ba..af83745 100644
--- a/include/evalhyd/detail/determinist/efficiencies.hpp
+++ b/include/evalhyd/detail/determinist/efficiencies.hpp
@@ -134,37 +134,42 @@ namespace evalhyd
                     // compute variable one sample at a time
                     for (std::size_t e = 0; e < n_exp; e++)
                     {
-                        auto prd = xt::view(prd_masked, xt::all(), b_exp[e]);
-                        auto obs = xt::view(obs_masked, xt::all(), b_exp[e]);
-
-                        auto prd_rank = xt::eval(xt::argsort(
-                                xt::argsort(xt::eval(prd), {1}),
-                                {1}
-                        ));
-                        auto obs_rank = xt::eval(xt::argsort(
-                                xt::argsort(xt::eval(obs), {1}),
-                                {1}
-                        ));
-
-                        auto mean_prd_rank =
-                                xt::eval(xt::nanmean(prd_rank, {1}, xt::keep_dims));
-                        auto mean_obs_rank =
-                                xt::eval(xt::nanmean(obs_rank, {1}, xt::keep_dims));
-
-                        auto prd_rank_err = xt::eval(prd_rank - mean_prd_rank);
-                        auto obs_rank_err = xt::eval(obs_rank - mean_obs_rank);
-
-                        auto r_num = xt::nansum(
-                                prd_rank_err * obs_rank_err,
-                                {1}
-                        );
-
-                        auto r_den = xt::sqrt(
-                                xt::nansum(xt::square(prd_rank_err), {1})
-                                * xt::nansum(xt::square(obs_rank_err), {1})
-                        );
-
-                        xt::view(r_spearman, m, e) = r_num / r_den;
+                        // compute one series at a time because xt::sort does not
+                        // consistently put NaN values at the end/beginning, so
+                        // need to eliminate them before the sorting
+                        for (std::size_t s = 0; s < n_srs; s++)
+                        {
+                            auto prd = xt::view(prd_masked, s, b_exp[e]);
+                            auto obs = xt::view(obs_masked, s, b_exp[e]);
+
+                            auto prd_filtered =
+                                    xt::filter(prd, !xt::isnan(prd));
+                            auto obs_filtered =
+                                    xt::filter(obs, !xt::isnan(obs));
+
+                            auto prd_sort = xt::argsort(xt::eval(prd_filtered));
+                            auto obs_sort = xt::argsort(xt::eval(obs_filtered));
+
+                            auto prd_rank = xt::eval(xt::argsort(prd_sort));
+                            auto obs_rank = xt::eval(xt::argsort(obs_sort));
+
+                            auto mean_prd_rank =
+                                    xt::eval(xt::nanmean(prd_rank));
+                            auto mean_obs_rank =
+                                    xt::eval(xt::nanmean(obs_rank));
+
+                            auto prd_rank_err = xt::eval(prd_rank - mean_prd_rank);
+                            auto obs_rank_err = xt::eval(obs_rank - mean_obs_rank);
+
+                            auto r_num = xt::nansum(prd_rank_err * obs_rank_err);
+
+                            auto r_den = xt::sqrt(
+                                    xt::nansum(xt::square(prd_rank_err))
+                                    * xt::nansum(xt::square(obs_rank_err))
+                            );
+
+                            xt::view(r_spearman, m, e, s) = r_num / r_den;
+                        }
                     }
                 }
 
@@ -308,8 +313,7 @@ namespace evalhyd
                     const std::vector<xt::xkeep_slice<int>>& b_exp,
                     std::size_t n_srs,
                     std::size_t n_msk,
-                    std::size_t n_exp,
-                    std::size_t n_tim
+                    std::size_t n_exp
             )
             {
                 // calculate error in spread of flow $alpha$
@@ -328,20 +332,33 @@ namespace evalhyd
                     // compute variable one sample at a time
                     for (std::size_t e = 0; e < n_exp; e++)
                     {
-                        auto prd = xt::view(prd_masked, xt::all(), b_exp[e]);
-                        auto obs = xt::view(obs_masked, xt::all(), b_exp[e]);
-
-                        auto prd_fdc = xt::sort(
-                                xt::eval(prd / (n_tim * xt::view(mean_prd, m, e))),
-                                {1}
-                        );
-                        auto obs_fdc = xt::sort(
-                                xt::eval(obs / (n_tim * xt::view(mean_obs, m, e))),
-                                {1}
-                        );
-
-                        xt::view(alpha_np, m, e) =
-                                1 - 0.5 * xt::nansum(xt::abs(prd_fdc - obs_fdc), {1});
+                        // compute one series at a time because xt::sort does not
+                        // consistently put NaN values at the end/beginning, so
+                        // need to eliminate them before the sorting
+                        for (std::size_t s = 0; s < n_srs; s++)
+                        {
+                            auto prd = xt::view(prd_masked, s, b_exp[e]);
+                            auto obs = xt::view(obs_masked, s, b_exp[e]);
+
+                            auto prd_filtered =
+                                    xt::filter(prd, !xt::isnan(prd));
+                            auto obs_filtered =
+                                    xt::filter(obs, !xt::isnan(obs));
+
+                            auto prd_fdc = xt::sort(
+                                    xt::eval(prd_filtered
+                                             / (prd_filtered.size()
+                                                * xt::view(mean_prd, m, e, s)))
+                            );
+                            auto obs_fdc = xt::sort(
+                                    xt::eval(obs_filtered
+                                             / (obs_filtered.size()
+                                                * xt::view(mean_obs, m, e, s)))
+                            );
+
+                            xt::view(alpha_np, m, e, s) =
+                                    1 - 0.5 * xt::nansum(xt::abs(prd_fdc - obs_fdc));
+                        }
                     }
                 }
 
diff --git a/include/evalhyd/detail/determinist/evaluator.hpp b/include/evalhyd/detail/determinist/evaluator.hpp
index fb38f5b..165a2c3 100644
--- a/include/evalhyd/detail/determinist/evaluator.hpp
+++ b/include/evalhyd/detail/determinist/evaluator.hpp
@@ -223,7 +223,7 @@ namespace evalhyd
                 {
                     alpha_np = elements::calc_alpha_np(
                             q_obs, q_prd, get_mean_obs(), get_mean_prd(),
-                            t_msk, b_exp, n_srs, n_msk, n_exp, n_tim
+                            t_msk, b_exp, n_srs, n_msk, n_exp
                     );
                 }
                 return alpha_np.value();
-- 
GitLab