Commit 30d27e16 authored by Thibault Hallouin's avatar Thibault Hallouin
Browse files

turn m_cdt into xexpression to avoid copies

1 merge request!3release v0.1.0
Pipeline #43835 passed with stage
in 3 minutes and 42 seconds
Showing with 23 additions and 19 deletions
+23 -19
......@@ -158,7 +158,8 @@ namespace evalhyd
/// evalhyd::evald(obs, prd, {"NSE"}, "none", 1, -9, msk);
///
/// \endrst
template <class XD2, class XB3 = xt::xtensor<bool, 3>>
template <class XD2, class XB3 = xt::xtensor<bool, 3>,
class XS2 = xt::xtensor<std::array<char, 32>, 2>>
std::vector<xt::xarray<double>> evald(
const xt::xexpression<XD2>& q_obs,
const xt::xexpression<XD2>& q_prd,
......@@ -170,7 +171,7 @@ namespace evalhyd
xtl::xoptional<double, bool> epsilon =
xtl::missing<double>(),
const xt::xexpression<XB3>& t_msk = XB3({}),
const xt::xtensor<std::array<char, 32>, 2>& m_cdt = {},
const xt::xexpression<XS2>& m_cdt = XS2({}),
xtl::xoptional<const std::unordered_map<std::string, int>, bool> bootstrap =
xtl::missing<const std::unordered_map<std::string, int>>(),
const std::vector<std::string>& dts = {},
......@@ -197,6 +198,7 @@ namespace evalhyd
const XD2& q_prd_ = q_prd.derived_cast();
const XB3& t_msk_ = t_msk.derived_cast();
const XS2& m_cdt_ = m_cdt.derived_cast();
// check that the metrics to be computed are valid
utils::check_metrics(
......@@ -262,9 +264,9 @@ namespace evalhyd
}
}
if (m_cdt.size() > 0)
if (m_cdt_.size() > 0)
{
if (q_prd_.shape(0) != m_cdt.shape(0))
if (q_prd_.shape(0) != m_cdt_.shape(0))
{
throw std::runtime_error(
"predictions and masking conditions feature different "
......@@ -279,10 +281,10 @@ namespace evalhyd
// generate masks from conditions if provided
auto gen_msk = [&]()
{
if ((t_msk_.size() < 1) && (m_cdt.size() > 0))
if ((t_msk_.size() < 1) && (m_cdt_.size() > 0))
{
std::size_t n_srs = q_prd_.shape(0);
std::size_t n_msk = m_cdt.shape(1);
std::size_t n_msk = m_cdt_.shape(1);
XB3 c_msk = xt::zeros<bool>({n_srs, n_msk, n_tim});
......@@ -292,7 +294,7 @@ namespace evalhyd
{
xt::view(c_msk, s, m) =
masks::generate_mask_from_conditions(
xt::view(m_cdt, s, m),
xt::view(m_cdt_, s, m),
xt::view(q_obs_, 0),
xt::view(q_prd_, s, xt::newaxis())
);
......@@ -418,7 +420,7 @@ namespace evalhyd
// instantiate determinist evaluator
determinist::Evaluator<XD2, XB3> evaluator(
obs, prd,
t_msk_.size() > 0 ? t_msk_: (m_cdt.size() > 0 ? c_msk : t_msk_),
t_msk_.size() > 0 ? t_msk_: (m_cdt_.size() > 0 ? c_msk : t_msk_),
exp
);
......
......@@ -157,7 +157,8 @@ namespace evalhyd
/// evalhyd::evalp(obs, prd, {"CRPS"});
///
/// \endrst
template <class XD2, class XD4, class XB4 = xt::xtensor<bool, 4>>
template <class XD2, class XD4, class XB4 = xt::xtensor<bool, 4>,
class XS2 = xt::xtensor<std::array<char, 32>, 2>>
std::vector<xt::xarray<double>> evalp(
const xt::xexpression<XD2>& q_obs,
const xt::xexpression<XD4>& q_prd,
......@@ -167,7 +168,7 @@ namespace evalhyd
xtl::missing<const std::string>(),
const std::vector<double>& c_lvl = {},
const xt::xexpression<XB4>& t_msk = XB4({}),
const xt::xtensor<std::array<char, 32>, 2>& m_cdt = {},
const xt::xexpression<XS2>& m_cdt = XS2({}),
xtl::xoptional<const std::unordered_map<std::string, int>, bool> bootstrap =
xtl::missing<const std::unordered_map<std::string, int>>(),
const std::vector<std::string>& dts = {},
......@@ -201,6 +202,7 @@ namespace evalhyd
const XD2& q_thr_ = q_thr.derived_cast();
const XB4& t_msk_ = t_msk.derived_cast();
const XS2& m_cdt_ = m_cdt.derived_cast();
// adapt vector to tensor
const xt::xtensor<double, 1> c_lvl_ = xt::adapt(c_lvl);
......@@ -297,9 +299,9 @@ namespace evalhyd
}
}
if (m_cdt.size() > 0)
if (m_cdt_.size() > 0)
{
if (q_obs_.shape(0) != m_cdt.shape(0))
if (q_obs_.shape(0) != m_cdt_.shape(0))
{
throw std::runtime_error(
"observations and masking conditions feature different "
......@@ -314,11 +316,11 @@ namespace evalhyd
// generate masks from conditions if provided
auto gen_msk = [&]()
{
if ((t_msk_.size() < 1) && (m_cdt.size() > 0))
if ((t_msk_.size() < 1) && (m_cdt_.size() > 0))
{
std::size_t n_sit = q_prd_.shape(0);
std::size_t n_ltm = q_prd_.shape(1);
std::size_t n_msk = m_cdt.shape(1);
std::size_t n_msk = m_cdt_.shape(1);
XB4 c_msk = xt::zeros<bool>({n_sit, n_ltm, n_msk, n_tim});
......@@ -330,7 +332,7 @@ namespace evalhyd
{
xt::view(c_msk, s, l, m) =
masks::generate_mask_from_conditions(
xt::view(m_cdt, s, m),
xt::view(m_cdt_, s, m),
xt::view(q_obs_, s),
xt::view(q_prd_, s, l)
);
......@@ -380,7 +382,7 @@ namespace evalhyd
// instantiate determinist evaluator
probabilist::Evaluator<XD2, XD4, XB4> evaluator(
q_obs_, q_prd_, q_thr_, c_lvl_, events,
t_msk_.size() > 0 ? t_msk_: (m_cdt.size() > 0 ? c_msk : t_msk_),
t_msk_.size() > 0 ? t_msk_: (m_cdt_.size() > 0 ? c_msk : t_msk_),
b_exp,
random_seed
);
......
......@@ -376,7 +376,7 @@ TEST(DeterministTests, TestBootstrap)
{}, // exponent
{}, // epsilon
xt::xtensor<bool, 3>({}), // t_msk
{}, // m_cdt
xt::xtensor<std::array<char, 32>, 2>({}), // m_cdt
bootstrap,
datetimes
);
......
......@@ -426,7 +426,7 @@ TEST(ProbabilistTests, TestRanks)
"high", // events
{}, // c_lvl
xt::xtensor<bool, 4>({}), // t_msk
{}, // m_cdt
xt::xtensor<std::array<char, 32>, 2>({}), // m_cdt
xtl::missing<const std::unordered_map<std::string, int>>(), // bootstrap
{}, // dts
7 // seed
......@@ -916,7 +916,7 @@ TEST(ProbabilistTests, TestBootstrap)
"high", // events
confidence_levels,
xt::xtensor<bool, 4>({}), // t_msk
{}, // m_cdt
xt::xtensor<std::array<char, 32>, 2>({}), // m_cdt
bootstrap,
datetimes
);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment