From 2b0694d506465f9230056274fc9605fd3024d30d Mon Sep 17 00:00:00 2001
From: Thibault Hallouin <thibault.hallouin@inrae.fr>
Date: Tue, 24 May 2022 16:18:08 +0200
Subject: [PATCH] add check on metrics validity

---
 include/evalhyd/probabilist.hpp |  7 +++++++
 include/evalhyd/utils.hpp       | 26 ++++++++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/include/evalhyd/probabilist.hpp b/include/evalhyd/probabilist.hpp
index d70ab34..051f4dd 100644
--- a/include/evalhyd/probabilist.hpp
+++ b/include/evalhyd/probabilist.hpp
@@ -39,6 +39,13 @@ namespace evalhyd
                 const xt::xtensor<double, 1>& q_thr
         )
         {
+            // check that the metrics to be computed are valid
+            utils::check_metrics(
+                    metrics,
+                    {"bs", "bss", "bs_crd", "bs_lbd", "qs", "crps"}
+            );
+
+            // instantiate probabilist evaluator
             eh::probabilist::Evaluator evaluator(q_obs, q_frc, q_thr);
 
             // declare maps for memoisation purposes
diff --git a/include/evalhyd/utils.hpp b/include/evalhyd/utils.hpp
index bc8fbb6..4caa7e5 100644
--- a/include/evalhyd/utils.hpp
+++ b/include/evalhyd/utils.hpp
@@ -4,6 +4,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include <stdexcept>
 #include <xtensor/xtensor.hpp>
 
 namespace evalhyd
@@ -66,6 +67,31 @@ namespace evalhyd
                 }
             }
         }
+
+        /// Procedure to check that all elements in the list of metrics are
+        /// valid metrics.
+        ///
+        /// \param [in] requested_metrics:
+        ///     Vector of strings for the metric(s) to be computed.
+        /// \param [in] valid_metrics:
+        ///     Vector of strings for the metric(s) to can be computed.
+        inline void check_metrics (
+                const std::vector<std::string>& requested_metrics,
+                const std::vector<std::string>& valid_metrics
+        )
+        {
+            for (const auto& metric : requested_metrics)
+            {
+                if (std::find(valid_metrics.begin(), valid_metrics.end(), metric)
+                        == valid_metrics.end())
+                {
+                    throw std::runtime_error(
+                            "invalid evaluation metric: " + metric
+                    );
+                }
+            }
+        }
+
     }
 }
 
-- 
GitLab