diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 54d855714e185a3202573bd161dfdd7a45b720fc..6c52fac81e46400fd108bdfd386a483b3f40948a 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -292,18 +292,19 @@ namespace evalhyd } // retrieve dimensions - std::size_t n_sit = q_prd_.shape(0); - std::size_t n_ltm = q_prd_.shape(1); std::size_t n_tim = q_prd_.shape(3); - std::size_t n_msk = t_msk_.size() > 0 ? t_msk_.shape(2) : - (m_cdt.size() > 0 ? m_cdt.shape(1) : 1); // generate masks from conditions if provided auto gen_msk = [&]() { - XB4 c_msk = xt::zeros<bool>({n_sit, n_ltm, n_msk, n_tim}); - if (m_cdt.size() > 0) + if ((t_msk_.size() < 1) && (m_cdt.size() > 0)) { + std::size_t n_sit = q_prd_.shape(0); + std::size_t n_ltm = q_prd_.shape(1); + std::size_t n_msk = m_cdt.shape(1); + + XB4 c_msk = xt::zeros<bool>({n_sit, n_ltm, n_msk, n_tim}); + for (std::size_t s = 0; s < n_sit; s++) { for (std::size_t l = 0; l < n_ltm; l++) @@ -319,9 +320,13 @@ namespace evalhyd } } } - } - return c_msk; + return c_msk; + } + else + { + return XB4({}); + } }; const XB4 c_msk = gen_msk();