From 64eca8fabc91cf7a7286dc8b5915847cd5c2a468 Mon Sep 17 00:00:00 2001 From: Thibault Hallouin <thibault.hallouin@inrae.fr> Date: Thu, 2 Feb 2023 14:34:18 +0100 Subject: [PATCH] fix bug for standard deviation taking axis for degree of freedom in `xt::stddev``, when passing a scalar, it is passed to *ddof* instead of *axes*, so to compute the standard deviation on a given axis, the axis must be passed in curly braces --- include/evalhyd/detail/uncertainty.hpp | 8 ++++---- tests/test_determinist.cpp | 4 ++-- tests/test_probabilist.cpp | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/evalhyd/detail/uncertainty.hpp b/include/evalhyd/detail/uncertainty.hpp index d3b086b..36b459e 100644 --- a/include/evalhyd/detail/uncertainty.hpp +++ b/include/evalhyd/detail/uncertainty.hpp @@ -211,10 +211,10 @@ namespace evalhyd // compute mean xt::view(v, xt::all(), xt::all(), 0) = - xt::mean(values, 2); + xt::mean(values, {2}); // compute standard deviation xt::view(v, xt::all(), xt::all(), 1) = - xt::stddev(values, 2); + xt::stddev(values, {2}); return v; } @@ -282,10 +282,10 @@ namespace evalhyd // compute mean xt::view(v, xt::all(), xt::all(), xt::all(), 0) = - xt::mean(values, axis); + xt::mean(values, {axis}); // compute standard deviation xt::view(v, xt::all(), xt::all(), xt::all(), 1) = - xt::stddev(values, axis); + xt::stddev(values, {axis}); return v; } diff --git a/tests/test_determinist.cpp b/tests/test_determinist.cpp index c37e93f..85b33fc 100644 --- a/tests/test_determinist.cpp +++ b/tests/test_determinist.cpp @@ -472,14 +472,14 @@ TEST(DeterministTests, TestBootstrapSummary) // mean EXPECT_TRUE( xt::all(xt::isclose( - xt::mean(metrics_raw[m], 2), + xt::mean(metrics_raw[m], {2}), xt::view(metrics_mas[m], xt::all(), xt::all(), 0) )) ) << "Failure for (" << all_metrics_d[m] << ") on mean"; // standard deviation EXPECT_TRUE( xt::all(xt::isclose( - xt::stddev(metrics_raw[m], 2), + xt::stddev(metrics_raw[m], {2}), xt::view(metrics_mas[m], xt::all(), xt::all(), 1) )) ) << "Failure for (" << all_metrics_d[m] << ") on standard deviation"; diff --git a/tests/test_probabilist.cpp b/tests/test_probabilist.cpp index 3196100..96f3546 100644 --- a/tests/test_probabilist.cpp +++ b/tests/test_probabilist.cpp @@ -1045,14 +1045,14 @@ TEST(ProbabilistTests, TestBootstrapSummary) // mean EXPECT_TRUE( xt::all(xt::isclose( - xt::mean(metrics_raw[m], 3), + xt::mean(metrics_raw[m], {3}), xt::view(metrics_mas[m], xt::all(), xt::all(), xt::all(), 0) )) ) << "Failure for (" << all_metrics_p[m] << ") on mean"; // standard deviation EXPECT_TRUE( xt::all(xt::isclose( - xt::stddev(metrics_raw[m], 3), + xt::stddev(metrics_raw[m], {3}), xt::view(metrics_mas[m], xt::all(), xt::all(), xt::all(), 1) )) ) << "Failure for (" << all_metrics_p[m] << ") on standard deviation"; -- GitLab