From 81692118734d65bb58e2024af4dbc240da2c8cb5 Mon Sep 17 00:00:00 2001 From: Thibault Hallouin <thibault.hallouin@inrae.fr> Date: Mon, 2 Jan 2023 14:56:05 +0100 Subject: [PATCH] replace enable_if in template for rank check by `xt::get_rank` GNU GCC was complaining about a missing type definition ("no type name 'type' in 'struct::enable_if<false, int>'") for the false case (as explained here https://stackoverflow.com/a/12625334). So given that there are two/three false conditions for `evald`/`evalp`, it seemed better to drop these rank checks from the templates, and use the function `xt::get_rank` instead. --- include/evalhyd/evald.hpp | 19 +++++++++++++++---- include/evalhyd/evalp.hpp | 26 +++++++++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/include/evalhyd/evald.hpp b/include/evalhyd/evald.hpp index c7e7c14..c10ed21 100644 --- a/include/evalhyd/evald.hpp +++ b/include/evalhyd/evald.hpp @@ -3,7 +3,6 @@ #include <unordered_map> #include <vector> -#include <type_traits> #include <xtensor/xexpression.hpp> #include <xtensor/xtensor.hpp> @@ -145,9 +144,7 @@ namespace evalhyd /// evalhyd::evald(obs, prd, {"NSE"}, "none", 1, -9, msk); /// /// \endrst - template <class D2, class B2, - std::enable_if_t<xt::has_rank_t<D2, 2>::value, int> = 0, - std::enable_if_t<xt::has_rank_t<B2, 2>::value, int> = 0> + template <class D2, class B2> std::vector<xt::xarray<double>> evald( const xt::xexpression<D2>& q_obs, const xt::xexpression<D2>& q_prd, @@ -162,6 +159,20 @@ namespace evalhyd const std::vector<std::string>& dts = {} ) { + // check ranks of tensors + if (xt::get_rank<D2>::value != 2) + { + throw std::runtime_error( + "observations and/or predictions are not two-dimensional" + ); + } + if (xt::get_rank<B2>::value != 2) + { + throw std::runtime_error( + "temporal masks are not two-dimensional" + ); + } + // retrieve real types of the expressions const D2& q_obs_ = q_obs.derived_cast(); const D2& q_prd_ = q_prd.derived_cast(); diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 350b923..0a48ae3 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -3,7 +3,6 @@ #include <unordered_map> #include <vector> -#include <type_traits> #include <xtensor/xexpression.hpp> #include <xtensor/xtensor.hpp> @@ -123,10 +122,7 @@ namespace evalhyd /// evalhyd::evalp(obs, prd, {"CRPS"}); /// /// \endrst - template <class D2, class D4, class B4, - std::enable_if_t<xt::has_rank_t<D2, 2>::value, int> = 0, - std::enable_if_t<xt::has_rank_t<D4, 4>::value, int> = 0, - std::enable_if_t<xt::has_rank_t<B4, 4>::value, int> = 0> + template <class D2, class D4, class B4> std::vector<xt::xarray<double>> evalp( const xt::xexpression<D2>& q_obs, const xt::xexpression<D4>& q_prd, @@ -139,6 +135,26 @@ namespace evalhyd const std::vector<std::string>& dts = {} ) { + // check ranks of tensors + if (xt::get_rank<D2>::value != 2) + { + throw std::runtime_error( + "observations and/or thresholds are not two-dimensional" + ); + } + if (xt::get_rank<D4>::value != 4) + { + throw std::runtime_error( + "predictions are not four-dimensional" + ); + } + if (xt::get_rank<B4>::value != 4) + { + throw std::runtime_error( + "temporal masks are not four-dimensional" + ); + } + // retrieve real types of the expressions const D2& q_obs_ = q_obs.derived_cast(); const D4& q_prd_ = q_prd.derived_cast(); -- GitLab