#ifndef EVALHYD_UTILS_HPP
#define EVALHYD_UTILS_HPP

#include <algorithm>
#include <ctime>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <xtensor/xtensor.hpp>
#include <xtensor/xrandom.hpp>

namespace evalhyd
{
    namespace utils
    {
        /// Procedure to determine, based on a list of metrics, which elements
        /// and which metrics (and their associated elements) need to be
        /// pre-computed for memoisation purposes.
        ///
        /// Elements and dependencies are appended to the output vectors in
        /// order of first appearance, without duplicates among the entries
        /// appended by this call. A metric (or dependency) that has no entry
        /// in *elements* or *dependencies* simply contributes nothing. The
        /// input maps are only read, never modified.
        ///
        /// \param [in] metrics:
        ///     Vector of strings for the metric(s) to be computed.
        /// \param [in] elements:
        ///     Map between metrics and their required computation elements.
        /// \param [in] dependencies:
        ///     Map between metrics and their dependencies.
        /// \param [out] required_elements:
        ///     Set of elements identified as required to be pre-computed
        ///     (appended to).
        /// \param [out] required_dependencies:
        ///     Set of metrics identified as required to be pre-computed
        ///     (appended to).
        inline void find_requirements (
                const std::vector<std::string>& metrics,
                const std::unordered_map<std::string, std::vector<std::string>>& elements,
                const std::unordered_map<std::string, std::vector<std::string>>& dependencies,
                std::vector<std::string>& required_elements,
                std::vector<std::string>& required_dependencies
        )
        {
            std::unordered_set<std::string> found_elements;
            std::unordered_set<std::string> found_dependencies;

            // append the not-yet-recorded elements required by *name*;
            // `find` is used (not `operator[]`) so that unknown names do not
            // insert empty entries into the caller's map
            auto add_elements_of = [&](const std::string& name)
            {
                const auto it = elements.find(name);
                if (it == elements.end())
                    return;
                for (const auto& element : it->second)
                    // `insert(...).second` is true only on first insertion
                    if (found_elements.insert(element).second)
                        required_elements.push_back(element);
            };

            for (const auto& metric : metrics)
            {
                // add elements to pre-computation set
                add_elements_of(metric);

                // add metric dependencies (and their own elements) to
                // pre-computation set
                const auto dep = dependencies.find(metric);
                if (dep == dependencies.end())
                    continue;

                for (const auto& dependency : dep->second)
                {
                    if (found_dependencies.insert(dependency).second)
                        required_dependencies.push_back(dependency);
                    add_elements_of(dependency);
                }
            }
        }

        /// Procedure to check that every requested metric is one of the
        /// valid metrics.
        ///
        /// \param [in] requested_metrics:
        ///     Vector of strings for the metric(s) to be computed.
        /// \param [in] valid_metrics:
        ///     Vector of strings for the metric(s) that can be computed.
        /// \throws std::runtime_error
        ///     For the first requested metric (in order) that does not
        ///     appear in *valid_metrics*.
        inline void check_metrics (
                const std::vector<std::string>& requested_metrics,
                const std::vector<std::string>& valid_metrics
        )
        {
            const auto valid_begin = valid_metrics.begin();
            const auto valid_end = valid_metrics.end();

            for (auto it = requested_metrics.begin();
                 it != requested_metrics.end(); ++it)
            {
                const std::string& candidate = *it;
                const bool is_known =
                        std::find(valid_begin, valid_end, candidate) != valid_end;
                if (!is_known)
                    throw std::runtime_error(
                            "invalid evaluation metric: " + candidate
                    );
            }
        }

        /// Procedure to check that all elements for a bootstrap experiment
        /// are provided and valid, and to seed the random number generation
        /// used by xtensor.
        ///
        /// \param [in] bootstrap:
        ///     Map of parameters for the bootstrap experiment. The keys
        ///     "n_samples", "len_sample" and "summary" are mandatory; the
        ///     key "seed" is optional.
        /// \throws std::runtime_error
        ///     If a mandatory key is missing, or if "summary" holds any value
        ///     other than 0 (the only value currently accepted).
        ///
        /// \note Side effect: calls `xt::random::seed` with the value of
        ///     "seed" when present, or with the current calendar time
        ///     (`std::time(nullptr)`) otherwise.
        inline void check_bootstrap (
                const std::unordered_map<std::string, int>& bootstrap
        )
        {
            // check n_samples
            if (bootstrap.find("n_samples") == bootstrap.end())
                throw std::runtime_error(
                        "number of samples missing for bootstrap"
                );
            // check len_sample
            if (bootstrap.find("len_sample") == bootstrap.end())
                throw std::runtime_error(
                        "length of sample missing for bootstrap"
                );
            // check summary (looked up once, value reused below)
            const auto summary = bootstrap.find("summary");
            if (summary == bootstrap.end())
                throw std::runtime_error(
                        "summary missing for bootstrap"
                );
            const int s = summary->second;
            // only summary 0 is accepted for now
            // TODO: change upper bound when mean+stddev and quantiles implemented
            if ((s < 0) || (s > 0))
                throw std::runtime_error(
                        "invalid value for bootstrap summary"
                );
            // set seed: user-provided if any, otherwise current time
            const auto seed = bootstrap.find("seed");
            if (seed == bootstrap.end())
                xt::random::seed(std::time(nullptr));
            else
                xt::random::seed(seed->second);
        }

        namespace evalp
        {
            /// Procedure to check that the optional parameters needed by the
            /// requested metrics have been provided as arguments.
            ///
            /// \param [in] metrics:
            ///     Vector of strings for the metric(s) to be computed.
            /// \param [in] thresholds:
            ///     Array of thresholds for metrics based on exceedance events.
            /// \throws std::runtime_error
            ///     If *thresholds* is empty while at least one of the
            ///     threshold-based metrics ("BS", "BS_CRD", "BS_LBD", "BSS")
            ///     is requested; the message names the first such metric.
            inline void check_optionals (
                    const std::vector<std::string>& metrics,
                    const xt::xtensor<double, 1>& thresholds
            )
            {
                // thresholds were given, so nothing can be missing
                if (thresholds.size() != 0)
                    return;

                const std::vector<std::string> needs_thresholds =
                        {"BS", "BS_CRD", "BS_LBD", "BSS"};

                for (const std::string& requested : metrics)
                {
                    const auto hit = std::find(needs_thresholds.begin(),
                                               needs_thresholds.end(),
                                               requested);
                    if (hit != needs_thresholds.end())
                        throw std::runtime_error(
                                "missing thresholds *q_thr* required to "
                                "compute " + requested
                        );
                }
            }
        }
    }
}

#endif //EVALHYD_UTILS_HPP