diff --git a/include/evalhyd/evald.hpp b/include/evalhyd/evald.hpp index c7e7c149e2c2cd8568b514b89d330439a41859fc..c10ed21d3541082de24707253736468e87275875 100644 --- a/include/evalhyd/evald.hpp +++ b/include/evalhyd/evald.hpp @@ -3,7 +3,6 @@ #include <unordered_map> #include <vector> -#include <type_traits> #include <xtensor/xexpression.hpp> #include <xtensor/xtensor.hpp> @@ -145,9 +144,7 @@ namespace evalhyd /// evalhyd::evald(obs, prd, {"NSE"}, "none", 1, -9, msk); /// /// \endrst - template <class D2, class B2, - std::enable_if_t<xt::has_rank_t<D2, 2>::value, int> = 0, - std::enable_if_t<xt::has_rank_t<B2, 2>::value, int> = 0> + template <class D2, class B2> std::vector<xt::xarray<double>> evald( const xt::xexpression<D2>& q_obs, const xt::xexpression<D2>& q_prd, @@ -162,6 +159,20 @@ namespace evalhyd const std::vector<std::string>& dts = {} ) { + // check ranks of tensors + if (xt::get_rank<D2>::value != 2) + { + throw std::runtime_error( + "observations and/or predictions are not two-dimensional" + ); + } + if (xt::get_rank<B2>::value != 2) + { + throw std::runtime_error( + "temporal masks are not two-dimensional" + ); + } + // retrieve real types of the expressions const D2& q_obs_ = q_obs.derived_cast(); const D2& q_prd_ = q_prd.derived_cast(); diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 350b923f70fb59a47cf2056e411e351ad8c69a7c..0a48ae3550c670b5c527ee58d8b9c6c9dc0bb371 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -3,7 +3,6 @@ #include <unordered_map> #include <vector> -#include <type_traits> #include <xtensor/xexpression.hpp> #include <xtensor/xtensor.hpp> @@ -123,10 +122,7 @@ namespace evalhyd /// evalhyd::evalp(obs, prd, {"CRPS"}); /// /// \endrst - template <class D2, class D4, class B4, - std::enable_if_t<xt::has_rank_t<D2, 2>::value, int> = 0, - std::enable_if_t<xt::has_rank_t<D4, 4>::value, int> = 0, - std::enable_if_t<xt::has_rank_t<B4, 4>::value, int> = 0> + template <class D2, class D4, class B4> std::vector<xt::xarray<double>> evalp( const xt::xexpression<D2>& q_obs, const xt::xexpression<D4>& q_prd, @@ -139,6 +135,26 @@ namespace evalhyd const std::vector<std::string>& dts = {} ) { + // check ranks of tensors + if (xt::get_rank<D2>::value != 2) + { + throw std::runtime_error( + "observations and/or thresholds are not two-dimensional" + ); + } + if (xt::get_rank<D4>::value != 4) + { + throw std::runtime_error( + "predictions are not four-dimensional" + ); + } + if (xt::get_rank<B4>::value != 4) + { + throw std::runtime_error( + "temporal masks are not four-dimensional" + ); + } + // retrieve real types of the expressions const D2& q_obs_ = q_obs.derived_cast(); const D4& q_prd_ = q_prd.derived_cast();