From 30d27e161b1384d7923f70be2678d1f1212ebe5e Mon Sep 17 00:00:00 2001 From: Thibault Hallouin <thibault.hallouin@inrae.fr> Date: Mon, 30 Jan 2023 11:15:16 +0100 Subject: [PATCH] turn m_cdt into xexpression to avoid copies --- include/evalhyd/evald.hpp | 18 ++++++++++-------- include/evalhyd/evalp.hpp | 18 ++++++++++-------- tests/test_determinist.cpp | 2 +- tests/test_probabilist.cpp | 4 ++-- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/include/evalhyd/evald.hpp b/include/evalhyd/evald.hpp index 5fb72f9..2fbcd95 100644 --- a/include/evalhyd/evald.hpp +++ b/include/evalhyd/evald.hpp @@ -158,7 +158,8 @@ namespace evalhyd /// evalhyd::evald(obs, prd, {"NSE"}, "none", 1, -9, msk); /// /// \endrst - template <class XD2, class XB3 = xt::xtensor<bool, 3>> + template <class XD2, class XB3 = xt::xtensor<bool, 3>, + class XS2 = xt::xtensor<std::array<char, 32>, 2>> std::vector<xt::xarray<double>> evald( const xt::xexpression<XD2>& q_obs, const xt::xexpression<XD2>& q_prd, @@ -170,7 +171,7 @@ namespace evalhyd xtl::xoptional<double, bool> epsilon = xtl::missing<double>(), const xt::xexpression<XB3>& t_msk = XB3({}), - const xt::xtensor<std::array<char, 32>, 2>& m_cdt = {}, + const xt::xexpression<XS2>& m_cdt = XS2({}), xtl::xoptional<const std::unordered_map<std::string, int>, bool> bootstrap = xtl::missing<const std::unordered_map<std::string, int>>(), const std::vector<std::string>& dts = {}, @@ -197,6 +198,7 @@ namespace evalhyd const XD2& q_prd_ = q_prd.derived_cast(); const XB3& t_msk_ = t_msk.derived_cast(); + const XS2& m_cdt_ = m_cdt.derived_cast(); // check that the metrics to be computed are valid utils::check_metrics( @@ -262,9 +264,9 @@ namespace evalhyd } } - if (m_cdt.size() > 0) + if (m_cdt_.size() > 0) { - if (q_prd_.shape(0) != m_cdt.shape(0)) + if (q_prd_.shape(0) != m_cdt_.shape(0)) { throw std::runtime_error( "predictions and masking conditions feature different " @@ -279,10 +281,10 @@ namespace evalhyd // generate masks from conditions if provided auto gen_msk = [&]() { - if ((t_msk_.size() < 1) && (m_cdt.size() > 0)) + if ((t_msk_.size() < 1) && (m_cdt_.size() > 0)) { std::size_t n_srs = q_prd_.shape(0); - std::size_t n_msk = m_cdt.shape(1); + std::size_t n_msk = m_cdt_.shape(1); XB3 c_msk = xt::zeros<bool>({n_srs, n_msk, n_tim}); @@ -292,7 +294,7 @@ namespace evalhyd { xt::view(c_msk, s, m) = masks::generate_mask_from_conditions( - xt::view(m_cdt, s, m), + xt::view(m_cdt_, s, m), xt::view(q_obs_, 0), xt::view(q_prd_, s, xt::newaxis()) ); @@ -418,7 +420,7 @@ namespace evalhyd // instantiate determinist evaluator determinist::Evaluator<XD2, XB3> evaluator( obs, prd, - t_msk_.size() > 0 ? t_msk_: (m_cdt.size() > 0 ? c_msk : t_msk_), + t_msk_.size() > 0 ? t_msk_: (m_cdt_.size() > 0 ? c_msk : t_msk_), exp ); diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 4c49e42..b00a33c 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -157,7 +157,8 @@ namespace evalhyd /// evalhyd::evalp(obs, prd, {"CRPS"}); /// /// \endrst - template <class XD2, class XD4, class XB4 = xt::xtensor<bool, 4>> + template <class XD2, class XD4, class XB4 = xt::xtensor<bool, 4>, + class XS2 = xt::xtensor<std::array<char, 32>, 2>> std::vector<xt::xarray<double>> evalp( const xt::xexpression<XD2>& q_obs, const xt::xexpression<XD4>& q_prd, @@ -167,7 +168,7 @@ namespace evalhyd xtl::missing<const std::string>(), const std::vector<double>& c_lvl = {}, const xt::xexpression<XB4>& t_msk = XB4({}), - const xt::xtensor<std::array<char, 32>, 2>& m_cdt = {}, + const xt::xexpression<XS2>& m_cdt = XS2({}), xtl::xoptional<const std::unordered_map<std::string, int>, bool> bootstrap = xtl::missing<const std::unordered_map<std::string, int>>(), const std::vector<std::string>& dts = {}, @@ -201,6 +202,7 @@ namespace evalhyd const XD2& q_thr_ = q_thr.derived_cast(); const XB4& t_msk_ = t_msk.derived_cast(); + const XS2& m_cdt_ = m_cdt.derived_cast(); // adapt vector to tensor const xt::xtensor<double, 1> c_lvl_ = xt::adapt(c_lvl); @@ -297,9 +299,9 @@ namespace evalhyd } } - if (m_cdt.size() > 0) + if (m_cdt_.size() > 0) { - if (q_obs_.shape(0) != m_cdt.shape(0)) + if (q_obs_.shape(0) != m_cdt_.shape(0)) { throw std::runtime_error( "observations and masking conditions feature different " @@ -314,11 +316,11 @@ namespace evalhyd // generate masks from conditions if provided auto gen_msk = [&]() { - if ((t_msk_.size() < 1) && (m_cdt.size() > 0)) + if ((t_msk_.size() < 1) && (m_cdt_.size() > 0)) { std::size_t n_sit = q_prd_.shape(0); std::size_t n_ltm = q_prd_.shape(1); - std::size_t n_msk = m_cdt.shape(1); + std::size_t n_msk = m_cdt_.shape(1); XB4 c_msk = xt::zeros<bool>({n_sit, n_ltm, n_msk, n_tim}); @@ -330,7 +332,7 @@ namespace evalhyd { xt::view(c_msk, s, l, m) = masks::generate_mask_from_conditions( - xt::view(m_cdt, s, m), + xt::view(m_cdt_, s, m), xt::view(q_obs_, s), xt::view(q_prd_, s, l) ); @@ -380,7 +382,7 @@ namespace evalhyd // instantiate determinist evaluator probabilist::Evaluator<XD2, XD4, XB4> evaluator( q_obs_, q_prd_, q_thr_, c_lvl_, events, - t_msk_.size() > 0 ? t_msk_: (m_cdt.size() > 0 ? c_msk : t_msk_), + t_msk_.size() > 0 ? t_msk_: (m_cdt_.size() > 0 ? c_msk : t_msk_), b_exp, random_seed ); diff --git a/tests/test_determinist.cpp b/tests/test_determinist.cpp index 560c713..0ddcd18 100644 --- a/tests/test_determinist.cpp +++ b/tests/test_determinist.cpp @@ -376,7 +376,7 @@ TEST(DeterministTests, TestBootstrap) {}, // exponent {}, // epsilon xt::xtensor<bool, 3>({}), // t_msk - {}, // m_cdt + xt::xtensor<std::array<char, 32>, 2>({}), // m_cdt bootstrap, datetimes ); diff --git a/tests/test_probabilist.cpp b/tests/test_probabilist.cpp index f0449b7..e0d8dd4 100644 --- a/tests/test_probabilist.cpp +++ b/tests/test_probabilist.cpp @@ -426,7 +426,7 @@ TEST(ProbabilistTests, TestRanks) "high", // events {}, // c_lvl xt::xtensor<bool, 4>({}), // t_msk - {}, // m_cdt + xt::xtensor<std::array<char, 32>, 2>({}), // m_cdt xtl::missing<const std::unordered_map<std::string, int>>(), // bootstrap {}, // dts 7 // seed @@ -916,7 +916,7 @@ TEST(ProbabilistTests, TestBootstrap) "high", // events confidence_levels, xt::xtensor<bool, 4>({}), // t_msk - {}, // m_cdt + xt::xtensor<std::array<char, 32>, 2>({}), // m_cdt bootstrap, datetimes ); -- GitLab