// Copyright (c) 2023, INRAE.
// Distributed under the terms of the GPL-3 Licence.
// The full licence is in the file LICENCE, distributed with this software.

#include <fstream>
#include <vector>
#include <tuple>
#include <array>

#include <gtest/gtest.h>

#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xmanipulation.hpp>
#include <xtensor/xcsv.hpp>

#include "evalhyd/evald.hpp"

#ifndef EVALHYD_DATA_DIR
#error "need to define data directory"
#endif

using namespace xt::placeholders;  // required for `_` to work


// names of the deterministic metrics requested by several tests in this file
std::vector<std::string> all_metrics_d{
        "RMSE",
        "NSE",
        "KGE",
        "KGEPRIME"
};

std::tuple<xt::xtensor<double, 2>, xt::xtensor<double, 2>> load_data_d()
{
    // read in data
    std::ifstream ifs;
    ifs.open(EVALHYD_DATA_DIR "/q_obs.csv");
    xt::xtensor<double, 2> observed = xt::load_csv<int>(ifs);
    ifs.close();

    ifs.open(EVALHYD_DATA_DIR "/q_prd.csv");
    xt::xtensor<double, 2> predicted = xt::view(
            xt::load_csv<double>(ifs), xt::range(0, 5), xt::all()
    );
    ifs.close();

    return std::make_tuple(observed, predicted);
}

TEST(DeterministTests, TestMetrics)
{
    // load the observed and predicted streamflow series
    xt::xtensor<double, 2> q_obs;
    xt::xtensor<double, 2> q_prd;
    std::tie(q_obs, q_prd) = load_data_d();

    // evaluate the four metrics in a single call (inputs are 2D tensors);
    // the checks below expect one result per metric, in the requested order
    std::vector<xt::xarray<double>> results =
            evalhyd::evald(q_obs, q_prd, {"RMSE", "NSE", "KGE", "KGEPRIME"});

    // reference values for each metric
    xt::xtensor<double, 3> expected_rmse = {
            {{777.034272}},
            {{776.878479}},
            {{777.800217}},
            {{778.151082}},
            {{778.61487 }}
    };
    xt::xtensor<double, 3> expected_nse = {
            {{0.718912}},
            {{0.719025}},
            {{0.718358}},
            {{0.718104}},
            {{0.717767}}
    };
    xt::xtensor<double, 3> expected_kge = {
            {{0.748088}},
            {{0.746106}},
            {{0.744111}},
            {{0.743011}},
            {{0.741768}}
    };
    xt::xtensor<double, 3> expected_kgeprime = {
            {{0.813141}},
            {{0.812775}},
            {{0.812032}},
            {{0.811787}},
            {{0.811387}}
    };

    // compare each computed metric with its reference values
    EXPECT_TRUE(xt::allclose(results[0], expected_rmse));
    EXPECT_TRUE(xt::allclose(results[1], expected_nse));
    EXPECT_TRUE(xt::allclose(results[2], expected_kge));
    EXPECT_TRUE(xt::allclose(results[3], expected_kgeprime));
}

TEST(DeterministTests, TestTransform)
{
    // load the observed and predicted streamflow series
    xt::xtensor<double, 2> q_obs;
    xt::xtensor<double, 2> q_prd;
    std::tie(q_obs, q_prd) = load_data_d();

    // NSE on square-rooted streamflow series
    std::vector<xt::xarray<double>> res_sqrt =
            evalhyd::evald(q_obs, q_prd, {"NSE"}, "sqrt");

    xt::xtensor<double, 3> expected_sqrt = {
            {{0.882817}},
            {{0.883023}},
            {{0.883019}},
            {{0.883029}},
            {{0.882972}}
    };
    EXPECT_TRUE(xt::all(xt::isclose(res_sqrt[0], expected_sqrt)));

    // NSE on inverted streamflow series
    std::vector<xt::xarray<double>> res_inv =
            evalhyd::evald(q_obs, q_prd, {"NSE"}, "inv");

    xt::xtensor<double, 3> expected_inv = {
            {{0.737323}},
            {{0.737404}},
            {{0.737429}},
            {{0.737546}},
            {{0.737595}}
    };
    EXPECT_TRUE(xt::all(xt::isclose(res_inv[0], expected_inv)));

    // NSE on log-transformed streamflow series
    std::vector<xt::xarray<double>> res_log =
            evalhyd::evald(q_obs, q_prd, {"NSE"}, "log");

    xt::xtensor<double, 3> expected_log = {
            {{0.893344}},
            {{0.893523}},
            {{0.893585}},
            {{0.893758}},
            {{0.893793}}
    };
    EXPECT_TRUE(xt::all(xt::isclose(res_log[0], expected_log)));

    // NSE on power-transformed streamflow series (exponent of 0.2)
    std::vector<xt::xarray<double>> res_pow =
            evalhyd::evald(q_obs, q_prd, {"NSE"}, "pow", 0.2);

    xt::xtensor<double, 3> expected_pow = {
            {{0.899207}},
            {{0.899395}},
            {{0.899451}},
            {{0.899578}},
            {{0.899588}}
    };
    EXPECT_TRUE(xt::all(xt::isclose(res_pow[0], expected_pow)));
}

TEST(DeterministTests, TestMasks)
{
    // load the observed and predicted streamflow series
    xt::xtensor<double, 2> q_obs;
    xt::xtensor<double, 2> q_prd;
    std::tie(q_obs, q_prd) = load_data_d();

    // build an all-true mask, then switch off the first 20 time steps
    std::size_t n_prd = q_prd.shape(0);
    std::size_t n_obs = q_obs.size();

    xt::xtensor<bool, 3> t_msk =
            xt::ones<bool>({n_prd, std::size_t {1}, n_obs});
    xt::view(t_msk, xt::all(), 0, xt::range(0, 20)) = 0;

    // scores obtained by masking the whole record
    std::vector<xt::xarray<double>> scores_masked =
            evalhyd::evald(q_obs, q_prd, all_metrics_d, {}, {}, {}, t_msk);

    // scores obtained on series trimmed beforehand to the same period
    xt::xtensor<double, 2> q_obs_trimmed =
            xt::view(q_obs, xt::all(), xt::range(20, _));
    xt::xtensor<double, 2> q_prd_trimmed =
            xt::view(q_prd, xt::all(), xt::range(20, _));

    std::vector<xt::xarray<double>> scores_trimmed =
            evalhyd::evald(q_obs_trimmed, q_prd_trimmed, all_metrics_d);

    // both approaches must give the same value for every metric
    for (std::size_t i = 0; i < all_metrics_d.size(); ++i)
    {
        EXPECT_TRUE(
                xt::all(xt::isclose(scores_masked[i], scores_trimmed[i]))
        ) << "Failure for (" << all_metrics_d[i] << ")";
    }
}

TEST(DeterministTests, TestMaskingConditions)
{
    // load the observed and predicted streamflow series
    xt::xtensor<double, 2> q_obs;
    xt::xtensor<double, 2> q_prd;
    std::tie(q_obs, q_prd) = load_data_d();

    // empty masks, passed only so that the masking conditions (the next
    // optional argument) can be provided
    xt::xtensor<bool, 3> no_masks;

    // (1) conditions on streamflow values _____________________________________

    xt::xtensor<std::array<char, 32>, 2> value_cdt = {{
            {{std::array<char, 32>{"q_obs{<2000,>3000}"}}},
            {{std::array<char, 32>{"q_obs{<2000,>3000}"}}},
            {{std::array<char, 32>{"q_obs{<2000,>3000}"}}},
            {{std::array<char, 32>{"q_obs{<2000,>3000}"}}},
            {{std::array<char, 32>{"q_obs{<2000,>3000}"}}}
    }};

    // scores obtained by letting the conditions subset the whole record
    std::vector<xt::xarray<double>> scores_value_cdt =
            evalhyd::evald(
                    q_obs, q_prd, all_metrics_d,
                    {}, {}, {}, no_masks, value_cdt
            );

    // scores obtained on observations where every value that does not
    // satisfy the conditions was replaced by NaN beforehand
    xt::xtensor<double, 2> q_obs_by_value =
            xt::where((q_obs < 2000) | (q_obs > 3000), q_obs, NAN);

    std::vector<xt::xarray<double>> scores_value_nan =
            evalhyd::evald(q_obs_by_value, q_prd, all_metrics_d);

    // both approaches must give the same value for every metric
    for (std::size_t i = 0; i < all_metrics_d.size(); ++i)
    {
        EXPECT_TRUE(
                xt::all(xt::isclose(
                        scores_value_cdt[i], scores_value_nan[i]
                ))
        ) << "Failure for (" << all_metrics_d[i] << ")";
    }

    // (2) conditions on streamflow statistics _________________________________

    xt::xtensor<std::array<char, 32>, 2> stat_cdt = {{
            {{std::array<char, 32>{"q_obs{>=mean}"}}},
            {{std::array<char, 32>{"q_obs{>=mean}"}}},
            {{std::array<char, 32>{"q_obs{>=mean}"}}},
            {{std::array<char, 32>{"q_obs{>=mean}"}}},
            {{std::array<char, 32>{"q_obs{>=mean}"}}}
    }};

    // mean of the observations along the second axis (first element taken)
    double obs_mean = xt::mean(q_obs, {1})();

    // scores obtained by letting the conditions subset the whole record
    std::vector<xt::xarray<double>> scores_stat_cdt =
            evalhyd::evald(
                    q_obs, q_prd, all_metrics_d,
                    {}, {}, {}, no_masks, stat_cdt
            );

    // scores obtained on observations where every value below the mean
    // was replaced by NaN beforehand
    xt::xtensor<double, 2> q_obs_by_stat =
            xt::where(q_obs >= obs_mean, q_obs, NAN);

    std::vector<xt::xarray<double>> scores_stat_nan =
            evalhyd::evald(q_obs_by_stat, q_prd, all_metrics_d);

    // both approaches must give the same value for every metric
    for (std::size_t i = 0; i < all_metrics_d.size(); ++i)
    {
        EXPECT_TRUE(
                xt::all(xt::isclose(
                        scores_stat_cdt[i], scores_stat_nan[i]
                ))
        ) << "Failure for (" << all_metrics_d[i] << ")";
    }

    // (3) conditions on temporal indices ______________________________________

    xt::xtensor<std::array<char, 32>, 2> index_cdt = {{
            {{std::array<char, 32>{"t{0,1,2,3,4,5:97,97,98,99}"}}},
            {{std::array<char, 32>{"t{0,1,2,3,4,5:97,97,98,99}"}}},
            {{std::array<char, 32>{"t{0,1,2,3,4,5:97,97,98,99}"}}},
            {{std::array<char, 32>{"t{0,1,2,3,4,5:97,97,98,99}"}}},
            {{std::array<char, 32>{"t{0,1,2,3,4,5:97,97,98,99}"}}}
    }};

    // scores obtained by letting the conditions subset the whole record
    std::vector<xt::xarray<double>> scores_index_cdt =
            evalhyd::evald(
                    q_obs, q_prd, all_metrics_d,
                    {}, {}, {}, no_masks, index_cdt
            );

    // scores obtained on series trimmed beforehand to the first 100 steps
    xt::xtensor<double, 2> q_obs_first100 =
            xt::view(q_obs, xt::all(), xt::range(0, 100));
    xt::xtensor<double, 2> q_prd_first100 =
            xt::view(q_prd, xt::all(), xt::range(0, 100));

    std::vector<xt::xarray<double>> scores_index_trimmed =
            evalhyd::evald(q_obs_first100, q_prd_first100, all_metrics_d);

    // both approaches must give the same value for every metric
    for (std::size_t i = 0; i < all_metrics_d.size(); ++i)
    {
        EXPECT_TRUE(
                xt::all(xt::isclose(
                        scores_index_cdt[i], scores_index_trimmed[i]
                ))
        ) << "Failure for (" << all_metrics_d[i] << ")";
    }
}

// Check that evaluating series containing NaN values gives the same scores
// as evaluating the same series with the NaN-affected time steps removed.
TEST(DeterministTests, TestMissingData)
{
    // read in data
    xt::xtensor<double, 2> observed;
    xt::xtensor<double, 2> predicted;
    std::tie(observed, predicted) = load_data_d();

    // add some missing observations artificially by assigning NaN values
    // (first 20 time steps)
    xt::view(observed, xt::all(), xt::range(0, 20)) = NAN;
    // add some missing predictions artificially by assigning NaN values
    // (series p is missing 3*(p+1) further time steps, starting at step 20;
    //  these must be set on the predictions, one row per series, so that
    //  they match the per-series trimming applied below)
    xt::view(predicted, 0, xt::range(20, 23)) = NAN;
    xt::view(predicted, 1, xt::range(20, 26)) = NAN;
    xt::view(predicted, 2, xt::range(20, 29)) = NAN;
    xt::view(predicted, 3, xt::range(20, 32)) = NAN;
    xt::view(predicted, 4, xt::range(20, 35)) = NAN;

    // compute metrics with observations and predictions containing NaN values
    std::vector<xt::xarray<double>> metrics_nan =
            evalhyd::evald(observed, predicted, all_metrics_d);

    for (std::size_t m = 0; m < all_metrics_d.size(); m++)
    {
        for (std::size_t p = 0; p < predicted.shape(0); p++)
        {
            // first time step free of NaN in both the observations and
            // the predicted series p
            const std::size_t first_valid = 20 + (3 * (p + 1));

            // compute metrics on subset of observations and predictions (i.e.
            // eliminating region with NaN in observations or predictions manually)
            xt::xtensor<double, 1> obs =
                    xt::view(observed, 0, xt::range(first_valid, _));
            xt::xtensor<double, 1> prd =
                    xt::view(predicted, p, xt::range(first_valid, _));

            // series are given a leading axis to form the 2D inputs
            std::vector<xt::xarray<double>> metrics_sbs =
                    evalhyd::evald(
                            xt::eval(xt::view(obs, xt::newaxis(), xt::all())),
                            xt::eval(xt::view(prd, xt::newaxis(), xt::all())),
                            {all_metrics_d[m]}
                    );

            // compare to check results are the same
            EXPECT_TRUE(
                    xt::all(xt::isclose(
                            xt::view(metrics_nan[m], p),
                            metrics_sbs[0]
                    ))
            ) << "Failure for (" << all_metrics_d[m] << ")";
        }
    }
}

TEST(DeterministTests, TestBootstrap)
{
    // datetimes: first row of the one-year observation file
    std::ifstream dts_file(EVALHYD_DATA_DIR "/q_obs_1yr.csv");
    xt::xtensor<std::string, 1> first_row =
            xt::squeeze(xt::load_csv<std::string>(dts_file, ',', 0, 1));
    dts_file.close();
    std::vector<std::string> datetimes (first_row.begin(), first_row.end());

    // observations: same file, skipping its first row
    std::ifstream obs_file(EVALHYD_DATA_DIR "/q_obs_1yr.csv");
    xt::xtensor<double, 1> q_obs =
            xt::squeeze(xt::load_csv<double>(obs_file, ',', 1));
    obs_file.close();

    // predictions: skipping the first row of the file as well
    std::ifstream prd_file(EVALHYD_DATA_DIR "/q_prd_1yr.csv");
    xt::xtensor<double, 2> q_prd = xt::load_csv<double>(prd_file, ',', 1);
    prd_file.close();

    // scores obtained via bootstrap
    std::unordered_map<std::string, int> bootstrap_cfg =
            {{"n_samples", 10}, {"len_sample", 3}, {"summary", 0}};

    xt::xtensor<double, 2> q_obs_2d =
            xt::view(q_obs, xt::newaxis(), xt::all());

    std::vector<xt::xarray<double>> scores_bootstrap =
            evalhyd::evald(
                    q_obs_2d,
                    q_prd,
                    all_metrics_d,
                    {},  // transform
                    {},  // exponent
                    {},  // epsilon
                    xt::xtensor<bool, 3>({}),  // t_msk
                    xt::xtensor<std::array<char, 32>, 2>({}),  // m_cdt
                    bootstrap_cfg,
                    datetimes
            );

    // scores obtained by repeating the year of data 3 times
    // (with a single year of data available and a bootstrap working on
    //  one-year blocks, every sample can only be made of that year; the
    //  sample length is the number of times the year is repeated, so
    //  repeating the input data that many times should result in the same
    //  numerical results)
    xt::xtensor<double, 1> q_obs_x3 =
            xt::concatenate(xt::xtuple(q_obs, q_obs, q_obs), 0);
    xt::xtensor<double, 2> q_prd_x3 =
            xt::concatenate(xt::xtuple(q_prd, q_prd, q_prd), 1);

    xt::xtensor<double, 2> q_obs_x3_2d =
            xt::view(q_obs_x3, xt::newaxis(), xt::all());

    std::vector<xt::xarray<double>> scores_repeated =
            evalhyd::evald(q_obs_x3_2d, q_prd_x3, all_metrics_d);

    // both approaches must give the same value for every metric
    for (std::size_t i = 0; i < all_metrics_d.size(); ++i)
    {
        EXPECT_TRUE(
                xt::all(xt::isclose(
                        scores_bootstrap[i], scores_repeated[i]
                ))
        ) << "Failure for (" << all_metrics_d[i] << ")";
    }
}