diff --git a/include/evalhyd/detail/probabilist/intervals.hpp b/include/evalhyd/detail/probabilist/intervals.hpp index 21316f2657fe5658865370bad66df97de1c00019..f5ed84657d3f6a1d1da2af728ac02d5f2bcd5f39 100644 --- a/include/evalhyd/detail/probabilist/intervals.hpp +++ b/include/evalhyd/detail/probabilist/intervals.hpp @@ -66,7 +66,9 @@ namespace evalhyd { for (std::size_t i = 0; i < n_itv; i++) { - auto res = xt::where(xt::equal(xt::view(quantiles, i), q_lvl)); + auto a = xt::broadcast(xt::view(quantiles, i), std::vector<std::size_t>({q_lvl.size(), 2})); + auto b = xt::broadcast(q_lvl, std::vector<std::size_t>({2, q_lvl.size()})); + auto res = xt::where(xt::equal(a, xt::transpose(b)); if (res.size() != 2) { throw std::runtime_error( @@ -76,9 +78,9 @@ namespace evalhyd } else { xt::view(itv_bnds, xt::all(), xt::all(), i, 0, xt::all()) = - xt::view(q_prd, res[0]); + xt::view(q_prd, std::min(res[0][1], res[1][1])); xt::view(itv_bnds, xt::all(), xt::all(), i, 1, xt::all()) = - xt::view(q_prd, res[1]); + xt::view(q_prd, std::max(res[0][1], res[1][1])); } } }