#include <array>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include <xtensor/xtensor.hpp>
#include <xtensor/xmanipulation.hpp>
#include <xtensor/xcsv.hpp>

#include "evalhyd/evalp.hpp"

using namespace xt::placeholders;  // required for `_` to work

// Checks the Brier-based metrics ("BS", "BSS", "BS_CRD", "BS_LBD") returned by
// evalhyd::evalp against hard-coded reference values. Four thresholds are
// used: three finite ones and one NaN, for which NaN results are expected.
TEST(ProbabilistTests, TestBrier)
{
    // read in data
    std::ifstream ifs;
    ifs.open("./data/q_obs.csv");
    // stop here with a clear message rather than evaluating an empty series
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_obs.csv";
    xt::xtensor<double, 1> observed = xt::squeeze(xt::load_csv<int>(ifs));
    ifs.close();

    ifs.open("./data/q_prd.csv");
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_prd.csv";
    xt::xtensor<double, 2> predicted = xt::load_csv<double>(ifs);
    ifs.close();

    // compute scores
    xt::xtensor<double, 2> thresholds = {{690, 534, 445, NAN}};

    std::vector<xt::xarray<double>> metrics =
            evalhyd::evalp(
                    // shape: (sites [1], time [t])
                    xt::view(observed, xt::newaxis(), xt::all()),
                    // shape: (sites [1], lead times [1], members [m], time [t])
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    {"BS", "BSS", "BS_CRD", "BS_LBD"},
                    thresholds
            );

    // check results
    // NOTE: each check counts the element-wise matches (the trailing `true`
    //       passed to xt::isclose makes NaN compare equal to NaN) and
    //       requires that count to equal the number of expected values

    // Brier scores (one value per threshold)
    xt::xtensor<double, 4> bs =
            {{{{0.10615136, 0.07395622, 0.08669186, NAN}}}};
    EXPECT_TRUE(
            xt::sum(xt::isclose(metrics[0], bs, 1e-05, 1e-08, true))
            == xt::xscalar<double>(4)
    );

    // Brier skill scores (one value per threshold)
    xt::xtensor<double, 4> bss =
            {{{{0.5705594, 0.6661165, 0.5635126, NAN}}}};
    EXPECT_TRUE(
            xt::sum(xt::isclose(metrics[1], bss, 1e-05, 1e-08, true))
            == xt::xscalar<double>(4)
    );

    // Brier calibration-refinement decompositions (three values per threshold)
    xt::xtensor<double, 5> bs_crd =
            {{{{{0.011411758, 0.1524456, 0.2471852},
                {0.005532413, 0.1530793, 0.2215031},
                {0.010139431, 0.1220601, 0.1986125},
                {NAN, NAN, NAN}}}}};
    EXPECT_TRUE(
            xt::sum(xt::isclose(metrics[2], bs_crd, 1e-05, 1e-08, true))
            == xt::xscalar<double>(12)
    );

    // Brier likelihood-base rate decompositions (three values per threshold)
    xt::xtensor<double, 5> bs_lbd =
            {{{{{0.012159881, 0.1506234, 0.2446149},
                {0.008031746, 0.1473869, 0.2133114},
                {0.017191279, 0.1048221, 0.1743227},
                {NAN, NAN, NAN}}}}};
    EXPECT_TRUE(
            xt::sum(xt::isclose(metrics[3], bs_lbd, 1e-05, 1e-08, true))
            == xt::xscalar<double>(12)
    );
}

// Checks the quantile-based metrics ("QS", "CRPS") returned by evalhyd::evalp
// against hard-coded reference values. No thresholds are passed to evalp.
TEST(ProbabilistTests, TestQuantiles)
{
    // read in data
    std::ifstream ifs;
    ifs.open("./data/q_obs.csv");
    // stop here with a clear message rather than evaluating an empty series
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_obs.csv";
    xt::xtensor<double, 1> observed = xt::squeeze(xt::load_csv<int>(ifs));
    ifs.close();

    ifs.open("./data/q_prd.csv");
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_prd.csv";
    xt::xtensor<double, 2> predicted = xt::load_csv<double>(ifs);
    ifs.close();

    // compute scores
    std::vector<xt::xarray<double>> metrics =
            evalhyd::evalp(
                    // shape: (sites [1], time [t])
                    xt::view(observed, xt::newaxis(), xt::all()),
                    // shape: (sites [1], lead times [1], members [m], time [t])
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    {"QS", "CRPS"}
            );

    // check results
    // Quantile scores (51 reference values)
    xt::xtensor<double, 4> qs =
            {{{{ 345.91578 ,  345.069256,  343.129359,  340.709869,  338.281598,
                 335.973535,  333.555157,  330.332426,  327.333539,  324.325996,
                 321.190082,  318.175117,  315.122186,  311.97205 ,  308.644942,
                 305.612169,  302.169552,  298.445956,  294.974648,  291.273807,
                 287.724586,  284.101905,  280.235592,  276.21865 ,  272.501484,
                 268.652733,  264.740168,  260.8558  ,  256.90329 ,  252.926292,
                 248.931239,  244.986396,  240.662998,  236.328964,  232.089785,
                 227.387089,  222.976008,  218.699975,  214.099678,  209.67252 ,
                 205.189587,  200.395746,  195.2372  ,  190.080139,  185.384244,
                 180.617858,  174.58323 ,  169.154093,  163.110932,  156.274796,
                 147.575315}}}};
    EXPECT_TRUE(xt::allclose(metrics[0], qs));

    // Continuous ranked probability scores (single reference value)
    xt::xtensor<double, 3> crps =
            {{{252.956919}}};
    EXPECT_TRUE(xt::allclose(metrics[1], crps));
}

// Checks that evaluating the whole record with a temporal mask that switches
// off the first 20 time steps gives the same metrics as evaluating a record
// from which those 20 time steps were removed beforehand.
TEST(ProbabilistTests, TestMasks)
{
    // read in data
    std::ifstream ifs;
    ifs.open("./data/q_obs.csv");
    // stop here with a clear message rather than evaluating an empty series
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_obs.csv";
    xt::xtensor<double, 1> observed = xt::squeeze(xt::load_csv<int>(ifs));
    ifs.close();

    ifs.open("./data/q_prd.csv");
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_prd.csv";
    xt::xtensor<double, 2> predicted = xt::load_csv<double>(ifs);
    ifs.close();

    // generate temporal subset by dropping 20 first time steps
    // (the mask is boolean: it starts all-true, then the first 20 time
    //  steps are set to false)
    xt::xtensor<bool, 4> masks =
            xt::ones<bool>({std::size_t {1}, std::size_t {1}, std::size_t {1},
                            std::size_t {observed.size()}});
    xt::view(masks, 0, xt::all(), 0, xt::range(0, 20)) = false;

    // compute scores using masks to subset whole record
    xt::xtensor<double, 2> thresholds = {{690, 534, 445}};
    std::vector<std::string> metrics =
            {"BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS"};

    std::vector<xt::xarray<double>> metrics_masked =
            evalhyd::evalp(
                    // shape: (sites [1], time [t])
                    xt::view(observed, xt::newaxis(), xt::all()),
                    // shape: (sites [1], lead times [1], members [m], time [t])
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    metrics,
                    thresholds,
                    // shape: (sites [1], lead times [1], subsets [1], time [t])
                    masks
            );

    // compute scores on pre-computed subset of whole record
    std::vector<xt::xarray<double>> metrics_subset =
            evalhyd::evalp(
                    // shape: (sites [1], time [t-20])
                    xt::view(observed, xt::newaxis(), xt::range(20, _)),
                    // shape: (sites [1], lead times [1], members [m], time [t-20])
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::range(20, _)),
                    metrics,
                    thresholds
            );

    // check results are identical
    // (unsigned index to match the type returned by `metrics.size()`)
    for (std::size_t m = 0; m < metrics.size(); ++m)
    {
        EXPECT_TRUE(xt::allclose(metrics_masked[m], metrics_subset[m]))
        << "Failure for (" << metrics[m] << ")";
    }
}

// Checks that masking conditions given to evalhyd::evalp as strings produce
// the same metrics as an equivalent manual preparation of the inputs, for
// three kinds of conditions: on observed streamflow values, on a statistic of
// the predicted streamflow, and on temporal indices.
TEST(ProbabilistTests, TestMaskingConditions)
{
    xt::xtensor<double, 2> thresholds = {{690, 534, 445}};
    std::vector<std::string> metrics =
            {"BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS"};

    // read in data
    std::ifstream ifs;
    ifs.open("./data/q_obs.csv");
    // stop here with a clear message rather than evaluating an empty series
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_obs.csv";
    xt::xtensor<double, 2> observed = xt::load_csv<int>(ifs);
    ifs.close();

    ifs.open("./data/q_prd.csv");
    ASSERT_TRUE(ifs.is_open()) << "Unable to open ./data/q_prd.csv";
    xt::xtensor<double, 2> predicted = xt::load_csv<double>(ifs);
    ifs.close();

    // generate dummy empty masks required to access next optional argument
    xt::xtensor<bool, 4> masks;

    // conditions on streamflow values _________________________________________

    // compute scores using masking conditions on streamflow to subset whole record
    xt::xtensor<std::array<char, 32>, 2> q_conditions = {
            {{"q_obs{<2000,>3000}"}}
    };

    std::vector<xt::xarray<double>> metrics_q_conditioned =
            evalhyd::evalp(
                    observed,
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    metrics, thresholds,
                    masks, q_conditions
            );

    // compute scores on observations that are kept where the condition on
    // streamflow is met and replaced with NaN everywhere else
    std::vector<xt::xarray<double>> metrics_q_preconditioned =
            evalhyd::evalp(
                    xt::where((observed < 2000) | (observed > 3000), observed, NAN),
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    metrics, thresholds
            );

    // check results are identical
    for (std::size_t m = 0; m < metrics.size(); ++m)
    {
        EXPECT_TRUE(
                xt::allclose(
                        metrics_q_conditioned[m], metrics_q_preconditioned[m]
                )
        ) << "Failure for (" << metrics[m] << ")";
    }

    // conditions on streamflow statistics _____________________________________

    // compute scores using masking conditions on streamflow to subset whole record
    xt::xtensor<std::array<char, 32>, 2> q_conditions_ = {
            {{"q_prd_mean{>=median}"}}
    };

    // mean of the predictions along their first axis (kept as a length-one
    // dimension), and the median of those means
    auto q_prd_mean = xt::mean(predicted, {0}, xt::keep_dims);
    double median = xt::median(q_prd_mean);

    std::vector<xt::xarray<double>> metrics_q_conditioned_ =
            evalhyd::evalp(
                    observed,
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    metrics, thresholds,
                    masks, q_conditions_
            );

    // compute scores on observations that are kept where the condition on
    // the streamflow statistic is met and replaced with NaN everywhere else
    std::vector<xt::xarray<double>> metrics_q_preconditioned_ =
            evalhyd::evalp(
                    xt::where(q_prd_mean >= median, observed, NAN),
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    metrics, thresholds
            );

    // check results are identical
    for (std::size_t m = 0; m < metrics.size(); ++m)
    {
        EXPECT_TRUE(
                xt::allclose(
                        metrics_q_conditioned_[m], metrics_q_preconditioned_[m]
                )
        ) << "Failure for (" << metrics[m] << ")";
    }

    // conditions on temporal indices __________________________________________

    // compute scores using masking conditions on time indices to subset whole record
    // NOTE(review): the condition mixes single indices with a `5:97` range and
    //               is expected to select the same time steps as the
    //               `xt::range(0, 100)` subset below -- presumably `5:97`
    //               excludes its upper bound; confirm against the evalp docs
    xt::xtensor<std::array<char, 32>, 2> t_conditions = {
            {{"t{0,1,2,3,4,5:97,97,98,99}"}}
    };

    std::vector<xt::xarray<double>> metrics_t_conditioned =
            evalhyd::evalp(
                    observed,
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::all()),
                    metrics, thresholds,
                    masks, t_conditions
            );

    // compute scores on already subset time series
    std::vector<xt::xarray<double>> metrics_t_subset =
            evalhyd::evalp(
                    xt::view(observed, xt::all(), xt::range(0, 100)),
                    xt::view(predicted, xt::newaxis(), xt::newaxis(), xt::all(), xt::range(0, 100)),
                    metrics, thresholds
            );

    // check results are identical
    for (std::size_t m = 0; m < metrics.size(); ++m)
    {
        EXPECT_TRUE(
                xt::allclose(
                        metrics_t_conditioned[m], metrics_t_subset[m]
                )
        ) << "Failure for (" << metrics[m] << ")";
    }
}

// Checks that NaN values in the observations and predictions are handled per
// lead time: metrics computed on series containing NaN must match metrics
// computed on series from which the affected time steps were removed by hand.
TEST(ProbabilistTests, TestMissingData)
{
    xt::xtensor<double, 2> thresholds
        {{ 4., 5. }};
    std::vector<std::string> metrics =
            {"BS", "BSS", "BS_CRD", "BS_LBD", "QS", "CRPS"};

    // compute metrics on series with NaN
    // (lead time 1 is missing its last time step, lead time 2 its first one)
    xt::xtensor<double, 4> forecast_nan {{
        // leadtime 1
        {{ 5.3, 4.2, 5.7, 2.3, NAN },
         { 4.3, 4.2, 4.7, 4.3, NAN },
         { 5.3, 5.2, 5.7, 2.3, NAN }},
        // leadtime 2
        {{ NAN, 4.2, 5.7, 2.3, 3.1 },
         { NAN, 4.2, 4.7, 4.3, 3.3 },
         { NAN, 5.2, 5.7, 2.3, 3.9 }}
    }};

    // (the observations are missing their third time step)
    xt::xtensor<double, 2> observed_nan
        {{ 4.7, 4.3, NAN, 2.7, 4.1 }};

    std::vector<xt::xarray<double>> metrics_nan =
        evalhyd::evalp(
                observed_nan,
                forecast_nan,
                metrics,
                thresholds
        );

    // compute metrics on manually subset series (one leadtime at a time)
    // (for lead time 1, the third and fifth time steps are removed)
    xt::xtensor<double, 4> forecast_pp1 {{
        // leadtime 1
        {{ 5.3, 4.2, 2.3 },
         { 4.3, 4.2, 4.3 },
         { 5.3, 5.2, 2.3 }},
    }};

    xt::xtensor<double, 2> observed_pp1
        {{ 4.7, 4.3, 2.7 }};

    std::vector<xt::xarray<double>> metrics_pp1 =
        evalhyd::evalp(
                observed_pp1,
                forecast_pp1,
                metrics,
                thresholds
        );

    // (for lead time 2, the first and third time steps are removed)
    xt::xtensor<double, 4> forecast_pp2 {{
        // leadtime 2
        {{ 4.2, 2.3, 3.1 },
         { 4.2, 4.3, 3.3 },
         { 5.2, 2.3, 3.9 }}
    }};

    xt::xtensor<double, 2> observed_pp2
        {{ 4.3, 2.7, 4.1 }};

    std::vector<xt::xarray<double>> metrics_pp2 =
        evalhyd::evalp(
                observed_pp2,
                forecast_pp2,
                metrics,
                thresholds
        );

    // check that numerical results are identical
    // (unsigned index to match the type returned by `metrics.size()`)
    for (std::size_t m = 0; m < metrics.size(); ++m) {
        // for leadtime 1
        // (index 0 on the second axis of the NaN-based results is compared
        //  with the only entry of the manually subset results)
        EXPECT_TRUE(
                xt::allclose(
                        xt::view(metrics_nan[m], xt::all(), 0),
                        xt::view(metrics_pp1[m], xt::all(), 0)
                )
        ) << "Failure for (" << metrics[m] << ", " << "leadtime 1)";

        // for leadtime 2
        // (index 1 on the second axis of the NaN-based results)
        EXPECT_TRUE(
                xt::allclose(
                        xt::view(metrics_nan[m], xt::all(), 1),
                        xt::view(metrics_pp2[m], xt::all(), 0)
                )
        ) << "Failure for (" << metrics[m] << ", " << "leadtime 2)";
    }
}