Commit 81692118 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

replace enable_if in template for rank check by `xt::get_rank`

GNU GCC was complaining about a missing type definition ("no type name
'type' in 'struct::enable_if<false, int>'") for the false case
(as explained here https://stackoverflow.com/a/12625334). So given that
there are two/three false conditions for `evald`/`evalp`, it seemed
better to drop these rank checks from the templates, and use the
function `xt::get_rank` instead.
1 merge request!3release v0.1.0
Pipeline #42877 passed with stage
in 2 minutes and 7 seconds
Showing with 36 additions and 9 deletions
+36 -9
......@@ -3,7 +3,6 @@
#include <unordered_map>
#include <vector>
#include <type_traits>
#include <xtensor/xexpression.hpp>
#include <xtensor/xtensor.hpp>
......@@ -145,9 +144,7 @@ namespace evalhyd
/// evalhyd::evald(obs, prd, {"NSE"}, "none", 1, -9, msk);
///
/// \endrst
template <class D2, class B2,
std::enable_if_t<xt::has_rank_t<D2, 2>::value, int> = 0,
std::enable_if_t<xt::has_rank_t<B2, 2>::value, int> = 0>
template <class D2, class B2>
std::vector<xt::xarray<double>> evald(
const xt::xexpression<D2>& q_obs,
const xt::xexpression<D2>& q_prd,
......@@ -162,6 +159,20 @@ namespace evalhyd
const std::vector<std::string>& dts = {}
)
{
// check ranks of tensors
if (xt::get_rank<D2>::value != 2)
{
throw std::runtime_error(
"observations and/or predictions are not two-dimensional"
);
}
if (xt::get_rank<B2>::value != 2)
{
throw std::runtime_error(
"temporal masks are not two-dimensional"
);
}
// retrieve real types of the expressions
const D2& q_obs_ = q_obs.derived_cast();
const D2& q_prd_ = q_prd.derived_cast();
......
......@@ -3,7 +3,6 @@
#include <unordered_map>
#include <vector>
#include <type_traits>
#include <xtensor/xexpression.hpp>
#include <xtensor/xtensor.hpp>
......@@ -123,10 +122,7 @@ namespace evalhyd
/// evalhyd::evalp(obs, prd, {"CRPS"});
///
/// \endrst
template <class D2, class D4, class B4,
std::enable_if_t<xt::has_rank_t<D2, 2>::value, int> = 0,
std::enable_if_t<xt::has_rank_t<D4, 4>::value, int> = 0,
std::enable_if_t<xt::has_rank_t<B4, 4>::value, int> = 0>
template <class D2, class D4, class B4>
std::vector<xt::xarray<double>> evalp(
const xt::xexpression<D2>& q_obs,
const xt::xexpression<D4>& q_prd,
......@@ -139,6 +135,26 @@ namespace evalhyd
const std::vector<std::string>& dts = {}
)
{
// check ranks of tensors
if (xt::get_rank<D2>::value != 2)
{
throw std::runtime_error(
"observations and/or thresholds are not two-dimensional"
);
}
if (xt::get_rank<D4>::value != 4)
{
throw std::runtime_error(
"predictions are not four-dimensional"
);
}
if (xt::get_rank<B4>::value != 4)
{
throw std::runtime_error(
"temporal masks are not four-dimensional"
);
}
// retrieve real types of the expressions
const D2& q_obs_ = q_obs.derived_cast();
const D4& q_prd_ = q_prd.derived_cast();
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment