From b13d2f21273e14b76a20194baa443d33393f4af7 Mon Sep 17 00:00:00 2001 From: Thibault Hallouin <thibault.hallouin@inrae.fr> Date: Wed, 31 Aug 2022 16:00:38 +0200 Subject: [PATCH] implement functionality to generate temporal masks from conditions This functionality is inherited from `evalhyd-cli`. It allows the user to provide conditions as strings to specify how to generate temporal subsets. Conditions can be based on observed streamflow values (e.g. q>800, q<=120) or on time indices (e.g. to select particular events). This functionality is made available both for determinist and probabilist evaluation, unlike in `evalhyd-cli` where it was only available for probabilist evaluation. This is documented in the docstrings, and new unit tests are written. --- include/evalhyd/evald.hpp | 66 +++++++- include/evalhyd/evalp.hpp | 56 ++++++- src/masks.hpp | 310 +++++++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 1 + tests/test_determinist.cpp | 79 ++++++++++ tests/test_probabilist.cpp | 85 ++++++++++ 6 files changed, 583 insertions(+), 14 deletions(-) create mode 100644 src/masks.hpp diff --git a/include/evalhyd/evald.hpp b/include/evalhyd/evald.hpp index aed3380..2254c91 100644 --- a/include/evalhyd/evald.hpp +++ b/include/evalhyd/evald.hpp @@ -9,6 +9,7 @@ #include <xtensor/xscalar.hpp> #include "../../src/utils.hpp" +#include "../../src/masks.hpp" #include "../../src/determinist/evaluator.hpp" namespace eh = evalhyd; @@ -97,6 +98,18 @@ namespace evalhyd /// of them. /// shape: ({... ,} time) /// + /// m_cdt: ``xt::xtensor<std::string, N>``, optional + /// Masking conditions to use to generate temporal subsets. Each + /// condition consists in a string and can be specified on observed + /// streamflow values or on time indices. If provided in combination + /// with *t_msk*, the latter takes precedence. If not provided and + /// neither is *t_msk*, no subset is performed and only one set of + /// metrics is returned corresponding to the whole time series. 
If + /// provided, only one condition per observed time series can be + /// provided, and as many sets of metrics are returned as there are + /// observed time series. + /// shape: ({... ,} 1) + /// /// :Returns: /// /// ``std::vector<xt::xarray<double>>`` @@ -144,16 +157,57 @@ namespace evalhyd const std::string& transform = "none", const double exponent = 1, double epsilon = -9, - const xt::xtensor<bool, N>& t_msk = {} + const xt::xtensor<bool, N>& t_msk = {}, + const xt::xtensor<std::string, N>& m_cdt = {} ) { // initialise a mask if none provided // (corresponding to no temporal subset) - xt::xtensor<bool, N> msk; - if (t_msk.size() < 1) - msk = xt::ones<bool>(q_obs.shape()); - else - msk = std::move(t_msk); + auto gen_msk = [&]() { + // initialise tensor for mask + xt::xtensor<bool, N> c_msk = xt::zeros<bool>(q_obs.shape()); + + // if t_msk provided, it takes priority + if (t_msk.size() > 0) + c_msk = std::move(t_msk); + // else if m_cdt provided, use them to generate t_msk + else if (m_cdt.size() > 0) + { + // flatten arrays to bypass n-dim considerations + // (possible because shapes are constrained to be the same) + if (m_cdt.shape(m_cdt.dimension() - 1) != 1) + throw std::runtime_error("length of last axis in masking conditions " + "must be equal to one"); + for (int a = 0; a < m_cdt.dimension() - 1; a++) + if (q_obs.shape(a) != m_cdt.shape(a)) + throw std::runtime_error("masking conditions and observations " + "feature incompatible shapes"); + + auto f_cdt = xt::flatten(m_cdt); + auto f_msk = xt::flatten(c_msk); + auto f_obs = xt::flatten(q_obs); + + // determine length of temporal axis + auto nt = q_obs.shape(q_obs.dimension() - 1); + + // generate mask gradually, one condition at a time + for (int i = 0; i < m_cdt.size(); i++) + { + xt::view(f_msk, xt::range(i*nt, (i+1)*nt)) = + evalhyd::masks::generate_mask_from_conditions( + xt::view(f_cdt, i), + xt::view(f_obs, xt::range(i*nt, (i+1)*nt)) + ); + } + } + // if neither t_msk nor m_cdt provided, generate 
dummy mask + else + c_msk = xt::ones<bool>(q_obs.shape()); + + return c_msk; + }; + + const xt::xtensor<bool, N> msk = gen_msk(); // check that observations, predictions, and masks dimensions are compatible if (q_obs.dimension() != q_prd.dimension()) diff --git a/include/evalhyd/evalp.hpp b/include/evalhyd/evalp.hpp index 3263d26..c832682 100644 --- a/include/evalhyd/evalp.hpp +++ b/include/evalhyd/evalp.hpp @@ -9,6 +9,7 @@ #include <xtensor/xview.hpp> #include "../../src/utils.hpp" +#include "../../src/masks.hpp" #include "../../src/probabilist/evaluator.h" namespace eh = evalhyd; @@ -45,12 +46,23 @@ namespace evalhyd /// t_msk: ``xt::xtensor<bool, 3>``, optional /// Mask(s) used to generate temporal subsets of the whole streamflow /// time series (where True/False is used for the time steps to - /// include/discard in a given subset). If not provided, no subset is - /// performed and only one set of metrics is returned corresponding - /// to the whole time series. If provided, as many sets of metrics are - /// returned as they are masks provided. + /// include/discard in a given subset). If not provided and neither + /// is *m_cdt*, no subset is performed and only one set of metrics is + /// returned corresponding to the whole time series. If provided, as + /// many sets of metrics are returned as there are masks provided. /// shape: (sites, subsets, time) /// + /// m_cdt: ``xt::xtensor<std::string, 2>``, optional + /// Masking conditions to use to generate temporal subsets. Each + /// condition consists in a string and can be specified on observed + /// streamflow values or on time indices. If provided in combination + /// with *t_msk*, the latter takes precedence. If not provided and + /// neither is *t_msk*, no subset is performed and only one set of + /// metrics is returned corresponding to the whole time series. If + /// provided, as many sets of metrics are returned as there are + /// conditions provided. 
+ /// shape: (sites, subsets) + /// /// :Returns: /// /// ``std::vector<xt::xarray<double>>`` @@ -90,7 +102,8 @@ namespace evalhyd const xt::xtensor<double, 4>& q_prd, const std::vector<std::string>& metrics, const xt::xtensor<double, 2>& q_thr = {}, - const xt::xtensor<bool, 3>& t_msk = {} + const xt::xtensor<bool, 3>& t_msk = {}, + const xt::xtensor<std::string, 2>& m_cdt = {} ) { // check that the metrics to be computed are valid @@ -124,7 +137,13 @@ namespace evalhyd if (t_msk.size() > 0) if (q_obs.shape(0) != t_msk.shape(0)) throw std::runtime_error( - "observations and masks feature different " + "observations and temporal masks feature different " + "numbers of sites" + ); + if (m_cdt.size() > 0) + if (q_obs.shape(0) != m_cdt.shape(0)) + throw std::runtime_error( + "observations and masking conditions feature different " "numbers of sites" ); @@ -132,8 +151,10 @@ namespace evalhyd std::size_t n_sit = q_prd.shape(0); std::size_t n_ltm = q_prd.shape(1); std::size_t n_mbr = q_prd.shape(2); + std::size_t n_tim = q_prd.shape(3); std::size_t n_thr = q_thr.shape(1); - std::size_t n_msk = t_msk.size() < 1 ? 1 : t_msk.shape(1); + std::size_t n_msk = t_msk.size() > 0 ? t_msk.shape(1) : + (m_cdt.size() > 0 ? 
m_cdt.shape(1) : 1); // register metrics number of dimensions std::unordered_map<std::string, std::vector<std::size_t>> dim; @@ -168,6 +189,20 @@ namespace evalhyd eh::utils::find_requirements(metrics, elt, dep, req_elt, req_dep); + // generate masks from conditions if provided + auto gen_msk = [&]() { + xt::xtensor<bool, 3> c_msk = xt::zeros<bool>({n_sit, n_msk, n_tim}); + if (m_cdt.size() > 0) + for (int s = 0; s < n_sit; s++) + for (int m = 0; m < n_msk; m++) + xt::view(c_msk, s, m) = + eh::masks::generate_mask_from_conditions( + xt::view(m_cdt, s, m), xt::view(q_obs, s) + ); + return c_msk; + }; + const xt::xtensor<bool, 3> c_msk = gen_msk(); + // initialise data structure for outputs std::vector<xt::xarray<double>> r; for (const auto& metric : metrics) @@ -182,7 +217,12 @@ namespace evalhyd const auto q_obs_v = xt::view(q_obs, s, xt::all()); const auto q_prd_v = xt::view(q_prd, s, l, xt::all(), xt::all()); const auto q_thr_v = xt::view(q_thr, s, xt::all()); - const auto t_msk_v = xt::view(t_msk, s, xt::all(), xt::all()); + const auto t_msk_v = + t_msk.size() > 0 ? + xt::view(t_msk, s, xt::all(), xt::all()) : + (m_cdt.size() > 0 ? + xt::view(c_msk, s, xt::all(), xt::all()) : + xt::view(t_msk, s, xt::all(), xt::all())); eh::probabilist::Evaluator evaluator( q_obs_v, q_prd_v, q_thr_v, t_msk_v diff --git a/src/masks.hpp b/src/masks.hpp new file mode 100644 index 0000000..c10572c --- /dev/null +++ b/src/masks.hpp @@ -0,0 +1,310 @@ +#ifndef EVALHYD_MASKS_HPP +#define EVALHYD_MASKS_HPP + +#include <map> +#include <set> +#include <vector> +#include <regex> + +#include <xtensor/xexpression.hpp> +#include <xtensor/xtensor.hpp> +#include <xtensor/xview.hpp> +#include <xtensor/xindex_view.hpp> + +typedef std::map<std::string, std::vector<std::vector<std::string>>> msk_tree; + +namespace evalhyd +{ + namespace masks + { + /// Function to parse a string containing masking conditions. 
+ inline msk_tree parse_masking_conditions(std::string msk_str) + { + msk_tree subset; + + // pattern supported to specify conditions to generate masks on streamflow + // e.g. q{>9.} q{<9} q{>=99.0} q{<=99} q{>9,<99} q{==9} q{!=9} + std::regex exp_q (R"(([q])\{((([><!=]?=?[0-9]+\.?[0-9]*),*)+)\})"); + + for (std::sregex_iterator i = + std::sregex_iterator(msk_str.begin(), msk_str.end(), exp_q); + i != std::sregex_iterator(); i++) + { + const std::smatch & mtc = *i; + + std::string var = mtc[1]; + std::string s = mtc[2]; + + // process masking conditions on streamflow + std::vector<std::vector<std::string>> conditions; + + // pattern supported to specify masking conditions based on streamflow + std::regex e (R"(([><!=]?=?)([0-9]+\.?[0-9]*))"); + + for (std::sregex_iterator j = + std::sregex_iterator(s.begin(), s.end(), e); + j != std::sregex_iterator(); j++) + { + const std::smatch & m = *j; + + // check that operator is provided and is supported + std::set<std::string> supported_op = + {"<", ">", "<=", ">=", "!=", "=="}; + if (supported_op.find(m[1]) != supported_op.end()) + conditions.push_back({m[1].str(), m[2].str()}); + else if (m[1].str().empty()) + throw std::runtime_error( + "missing operator for streamflow masking condition" + ); + else + throw std::runtime_error( + "invalid operator for streamflow masking " + "condition: " + m[1].str() + ); + } + + // check that a maximum of two conditions were provided + if (conditions.size() > 2) + throw std::runtime_error( + "no more than two streamflow masking conditions " + "can be provided" + ); + + subset[var] = conditions; + } + + // pattern supported to specify conditions to generate masks on time index + // e.g. 
t{0:10} t{0:10,20:30} t{0,1,2,3} t{0:10,30,40,50} t{:} + std::regex exp_t (R"(([t])\{(((([0-9]+|[:]):?[0-9]*),*)+)\})"); + + for (std::sregex_iterator i = + std::sregex_iterator(msk_str.begin(), msk_str.end(), exp_t); + i != std::sregex_iterator(); i++) + { + const std::smatch & mtc = *i; + + std::string var = mtc[1]; + std::string s = mtc[2]; + + // process masking conditions on time index + std::vector<std::vector<std::string>> condition; + + // pattern supported to specify masking conditions based on time index + std::regex e (R"(([0-9]+|[:]):?([0-9]*))"); + + for (std::sregex_iterator j = + std::sregex_iterator(s.begin(), s.end(), e); + j != std::sregex_iterator(); j++) + { + const std::smatch & m = *j; + + // check whether it is all indices, a range of indices, or an index + if (m[1] == ":") + // it is all indices (i.e. t{:}) so keep everything + condition.emplace_back(); + else if (m[2].str().empty()) + // it is an index (i.e. t{#}) + condition.push_back({m[1].str()}); + else + { + // it is a range of indices (i.e. 
t{#:#}) + // generate sequence of integer indices from range + std::vector<int> vi(std::stoi(m[2].str()) + - std::stoi(m[1].str())); + std::iota(vi.begin(), vi.end(), std::stoi(m[1].str())); + // convert to sequence of integer indices to string indices + std::vector<std::string> vs; + std::transform(std::begin(vi), std::end(vi), + std::back_inserter(vs), + [](int d) { return std::to_string(d); }); + + condition.push_back(vs); + } + } + + subset[var] = condition; + } + + return subset; + } + + /// Function to generate temporal mask based on masking conditions + inline xt::xtensor<bool, 1> generate_mask_from_conditions( + const std::string& msk_str, const xt::xtensor<double, 1>& q_obs + ) + { + // parse string to identify masking conditions + msk_tree subset = parse_masking_conditions(msk_str); + + // initialise a boolean expression for the masks + xt::xtensor<bool, 1> t_msk = xt::zeros<bool>(q_obs.shape()); + + // populate the masks given the conditions + for (const auto & var_cond : subset) + { + auto var = var_cond.first; + auto cond = var_cond.second; + + // condition on streamflow + if (var == "q") + { + // preprocess conditions to identify special cases + // within/without + bool within = false; + bool without = false; + + std::string opr1, opr2; + double val1, val2; + + if (cond.size() == 2) + { + opr1 = cond[0][0]; + val1 = std::stod(cond[0][1]); + opr2 = cond[1][0]; + val2 = std::stod(cond[1][1]); + + if ((opr1 == "<") or (opr1 == "<=")) + { + if ((opr2 == ">") or (opr2 == ">=")) + { + if (val2 > val1) + without = true; + else { within = true; } + } + } + else if ((opr1 == ">") or (opr1 == ">=")) + { + if ((opr2 == "<") or (opr2 == "<=")) + { + if (val2 > val1) + within = true; + else { without = true; } + } + } + } + + // process conditions, starting with special cases + // within/without + if (within) + { + if ((opr1 == "<") and (opr2 == ">")) + t_msk = xt::where((q_obs < val1) & (q_obs > val2), + 1, t_msk); + else if ((opr1 == "<=") and (opr2 == ">")) + 
t_msk = xt::where((q_obs <= val1) & (q_obs > val2), + 1, t_msk); + else if ((opr1 == "<") and (opr2 == ">=")) + t_msk = xt::where((q_obs < val1) & (q_obs >= val2), + 1, t_msk); + else if ((opr1 == "<=") and (opr2 == ">=")) + t_msk = xt::where((q_obs <= val1) & (q_obs >= val2), + 1, t_msk); + + if ((opr2 == "<") and (opr1 == ">")) + t_msk = xt::where((q_obs < val2) & (q_obs > val1), + 1, t_msk); + else if ((opr2 == "<=") and (opr1 == ">")) + t_msk = xt::where((q_obs <= val2) & (q_obs > val1), + 1, t_msk); + else if ((opr2 == "<") and (opr1 == ">=")) + t_msk = xt::where((q_obs < val2) & (q_obs >= val1), + 1, t_msk); + else if ((opr2 == "<=") and (opr1 == ">=")) + t_msk = xt::where((q_obs <= val2) & (q_obs >= val1), + 1, t_msk); + } + else if (without) + { + if ((opr1 == "<") and (opr2 == ">")) + t_msk = xt::where((q_obs < val1) | (q_obs > val2), + 1, t_msk); + else if ((opr1 == "<=") and (opr2 == ">")) + t_msk = xt::where((q_obs <= val1) | (q_obs > val2), + 1, t_msk); + else if ((opr1 == "<") and (opr2 == ">=")) + t_msk = xt::where((q_obs < val1) | (q_obs >= val2), + 1, t_msk); + else if ((opr1 == "<=") and (opr2 == ">=")) + t_msk = xt::where((q_obs <= val1) | (q_obs >= val2), + 1, t_msk); + + if ((opr2 == "<") and (opr1 == ">")) + t_msk = xt::where((q_obs < val2) | (q_obs > val1), + 1, t_msk); + else if ((opr2 == "<=") and (opr1 == ">")) + t_msk = xt::where((q_obs <= val2) | (q_obs > val1), + 1, t_msk); + else if ((opr2 == "<") and (opr1 == ">=")) + t_msk = xt::where((q_obs < val2) | (q_obs >= val1), + 1, t_msk); + else if ((opr2 == "<=") and (opr1 == ">=")) + t_msk = xt::where((q_obs <= val2) | (q_obs >= val1), + 1, t_msk); + } + else + { + for (const auto & opr_val : cond) + { + auto opr = opr_val[0]; + + // convert comparison value from string to double + double val = std::stod(opr_val[1]); + + // apply masking condition to given subset + if (opr == "<") + t_msk = xt::where( + q_obs < val, 1, t_msk + ); + else if (opr == ">") + t_msk = xt::where( + q_obs > val, 
1, t_msk + ); + else if (opr == "<=") + t_msk = xt::where( + q_obs <= val, 1, t_msk + ); + else if (opr == ">=") + t_msk = xt::where( + q_obs >= val, 1, t_msk + ); + else if (opr == "==") + t_msk = xt::where( + xt::equal(q_obs, val), 1, t_msk + ); + else if (opr == "!=") + t_msk = xt::where( + xt::not_equal(q_obs, val), 1, t_msk + ); + } + } + } + // condition on time index + else if (var == "t") + { + for (const auto & sequence : cond) + { + if (sequence.empty()) + // i.e. t{:} + xt::view(t_msk, xt::all()) = 1; + else + { + // convert string indices to integer indices + std::vector<int> vi; + std::transform(std::begin(sequence), + std::end(sequence), + std::back_inserter(vi), + [](const std::string& s) + { return std::stoi(s); }); + // apply masked indices to given subset + xt::index_view(t_msk, vi) = 1; + } + } + } + } + + return t_msk; + } + } +} + +#endif //EVALHYD_MASKS_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c7117fa..c2ffa17 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,6 +31,7 @@ add_executable( ../include/evalhyd/evald.hpp ../src/determinist/evaluator.hpp ../src/utils.hpp + ../src/masks.hpp test_probabilist.cpp ../include/evalhyd/evalp.hpp ../src/probabilist/evaluator.h diff --git a/tests/test_determinist.cpp b/tests/test_determinist.cpp index bc5bc45..0d4c43c 100644 --- a/tests/test_determinist.cpp +++ b/tests/test_determinist.cpp @@ -198,6 +198,85 @@ TEST(DeterministTests, TestMasks) } } +TEST(DeterministTests, TestMaskingConditions) +{ + std::vector<std::string> metrics = + {"RMSE", "NSE", "KGE", "KGEPRIME"}; + + // read in data + std::ifstream ifs; + ifs.open("./data/q_obs.csv"); + xt::xtensor<double, 2> observed = xt::load_csv<int>(ifs); + ifs.close(); + + ifs.open("./data/q_prd.csv"); + xt::xtensor<double, 2> predicted = xt::load_csv<double>(ifs); + ifs.close(); + + // generate dummy empty masks required to access next optional argument + xt::xtensor<bool, 2> masks; + + // conditions on streamflow 
________________________________________________ + + // compute scores using masking conditions on streamflow to subset whole record + xt::xtensor<std::string, 2> q_conditions = + {{"q{<2000,>3000}"}}; + + std::vector<xt::xarray<double>> metrics_q_conditioned = + evalhyd::evald<2>( + observed, predicted, metrics, + "none", 1, -9, masks, q_conditions + ); + + // compute scores using "NaN-ed" time indices where conditions on streamflow met + std::vector<xt::xarray<double>> metrics_q_preconditioned = + evalhyd::evald<2>( + xt::where((observed < 2000) | (observed > 3000), observed, NAN), + predicted, + metrics + ); + + // check results are identical + for (int m = 0; m < metrics.size(); m++) + { + EXPECT_TRUE( + xt::allclose( + metrics_q_conditioned[m], metrics_q_preconditioned[m] + ) + ) << "Failure for (" << metrics[m] << ")"; + } + + // conditions on temporal indices __________________________________________ + + // compute scores using masking conditions on time indices to subset whole record + xt::xtensor<std::string, 2> t_conditions = + {{"t{0,1,2,3,4,5:97,97,98,99}"}}; + + std::vector<xt::xarray<double>> metrics_t_conditioned = + evalhyd::evald<2>( + observed, predicted, metrics, + "none", 1, -9, masks, t_conditions + ); + + // compute scores on already subset time series + std::vector<xt::xarray<double>> metrics_t_subset = + evalhyd::evald<2>( + xt::view(observed, xt::all(), xt::range(0, 100)), + xt::view(predicted, xt::all(), xt::range(0, 100)), + metrics + ); + + // check results are identical + for (int m = 0; m < metrics.size(); m++) + { + EXPECT_TRUE( + xt::allclose( + metrics_t_conditioned[m], metrics_t_subset[m] + ) + ) << "Failure for (" << metrics[m] << ")"; + } +} + TEST(DeterministTests, TestMissingData) { std::vector<std::string> metrics = diff --git a/tests/test_probabilist.cpp b/tests/test_probabilist.cpp index 66dd066..3746308 100644 --- a/tests/test_probabilist.cpp +++ b/tests/test_probabilist.cpp @@ -172,6 +172,91 @@ TEST(ProbabilistTests, 
TestMasks) } } +TEST(ProbabilistTests, TestMaskingConditions) +{ + xt::xtensor<double, 2> thresholds = {{690, 534, 445}}; + std::vector<std::string> metrics = + {"BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS"}; + + // read in data + // read in data + std::ifstream ifs; + ifs.open("./data/q_obs.csv"); + xt::xtensor<double, 2> observed = xt::load_csv<int>(ifs); + ifs.close(); + + ifs.open("./data/q_prd.csv"); + xt::xtensor<double, 2> predicted = xt::load_csv<double>(ifs); + ifs.close(); + + // generate dummy empty masks required to access next optional argument + xt::xtensor<bool, 3> masks; + + // conditions on streamflow ________________________________________________ + + // compute scores using masking conditions on streamflow to subset whole record + xt::xtensor<std::string, 2> q_conditions = + {{"q{<2000,>3000}"}}; + + std::vector<xt::xarray<double>> metrics_q_conditioned = + evalhyd::evalp( + observed, + xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()), + metrics, thresholds, + masks, q_conditions + ); + + // compute scores using "NaN-ed" time indices where conditions on streamflow met + std::vector<xt::xarray<double>> metrics_q_preconditioned = + evalhyd::evalp( + xt::where((observed < 2000) | (observed > 3000), observed, NAN), + xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()), + metrics, thresholds + ); + + // check results are identical + for (int m = 0; m < metrics.size(); m++) + { + EXPECT_TRUE( + xt::allclose( + metrics_q_conditioned[m], metrics_q_preconditioned[m] + ) + ) << "Failure for (" << metrics[m] << ")"; + } + + // conditions on temporal indices __________________________________________ + + // compute scores using masking conditions on time indices to subset whole record + xt::xtensor<std::string, 2> t_conditions = + {{"t{0,1,2,3,4,5:97,97,98,99}"}}; + + std::vector<xt::xarray<double>> metrics_t_conditioned = + evalhyd::evalp( + observed, + xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), 
xt::all()), + metrics, thresholds, + masks, t_conditions + ); + + // compute scores on already subset time series + std::vector<xt::xarray<double>> metrics_t_subset = + evalhyd::evalp( + xt::view(observed, xt::all(), xt::range(0, 100)), + xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::range(0, 100)), + metrics, thresholds + ); + + // check results are identical + for (int m = 0; m < metrics.size(); m++) + { + EXPECT_TRUE( + xt::allclose( + metrics_t_conditioned[m], metrics_t_subset[m] + ) + ) << "Failure for (" << metrics[m] << ")"; + } +} + TEST(ProbabilistTests, TestMissingData) { xt::xtensor<double, 2> thresholds -- GitLab