#include <xtensor/xmath.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xsort.hpp>

#include "probabilist/evaluator.hpp"

namespace evalhyd
{
    namespace probabilist
    {
        // Compute, for every threshold and time step, whether the observed
        // streamflow reached (>=) that threshold.
        //
        // \require q_obs:
        //     Streamflow observations.
        //     shape: (time,)
        // \require q_thr:
        //     Streamflow exceedance threshold(s).
        //     shape: (thresholds,)
        // \assign o_k:
        //     Event observed outcome (true/1 where q_obs >= threshold).
        //     shape: (thresholds, time)
        void Evaluator::calc_o_k()
        {
            // turn the thresholds into a column, (thresholds,) -> (thresholds, 1),
            // so that comparing with the (time,) observations broadcasts to a
            // (thresholds, time) grid of outcomes
            auto thr_column = xt::view(q_thr, xt::all(), xt::newaxis());

            o_k = q_obs >= thr_column;
        }

        // Compute the mean observed outcome (event relative frequency) for
        // every temporal subset, bootstrap sample, and threshold.
        //
        // \require o_k:
        //     Event observed outcome.
        //     shape: (thresholds, time)
        // \require t_msk:
        //     Temporal subsets of the whole record.
        //     shape: (subsets, time)
        // \assign bar_o:
        //     Mean event observed outcome.
        //     shape: (subsets, samples, thresholds)
        void Evaluator::calc_bar_o()
        {
            // insert a threshold axis into the masks so they broadcast against
            // o_k: (subsets, time) -> (subsets, 1, time)
            auto msk_bcast =
                    xt::view(t_msk, xt::all(), xt::newaxis(), xt::all());

            // replace masked-out time steps with NaN so the NaN-aware reducer
            // below ignores them (workaround until reducers work on
            // masked_view); resulting shape: (subsets, thresholds, time)
            auto o_k_nan = xt::where(msk_bcast, o_k, NAN);

            // allocate the output, then fill it one bootstrap sample at a time
            bar_o = xt::zeros<double>({n_msk, n_exp, n_thr});

            int smp = 0;
            while (smp < n_exp)
            {
                // select, along the time axis, the steps of bootstrap sample smp
                auto o_k_smp =
                        xt::view(o_k_nan, xt::all(), xt::all(), b_exp[smp]);

                // mean "climatology" relative frequency of the event over the
                // retained (non-NaN) time steps
                // $\bar{o} = \frac{1}{n} \sum_{k=1}^{n} o_k$
                xt::view(bar_o, xt::all(), smp, xt::all()) =
                        xt::nanmean(o_k_smp, -1);

                ++smp;
            }
        }

        // Compute the forecast probability of exceeding each threshold as the
        // fraction of ensemble members reaching (>=) it at each time step.
        //
        // \require q_prd:
        //     Streamflow predictions.
        //     shape: (members, time)
        // \require q_thr:
        //     Streamflow exceedance threshold(s).
        //     shape: (thresholds,)
        // \assign y_k:
        //     Event probability forecast.
        //     shape: (thresholds, time)
        void Evaluator::calc_y_k()
        {
            // shape thresholds as (thresholds, 1, 1) so they broadcast against
            // the (members, time) predictions
            auto thr_bcast =
                    xt::view(q_thr, xt::all(), xt::newaxis(), xt::newaxis());

            // per threshold and time step, count the members at or above it:
            // (thresholds, members, time) -> (thresholds, time)
            auto n_exceeding = xt::sum(q_prd >= thr_bcast, 1);

            // /!\ the count is divided by n (the number of members), not by
            //     n+1 (the number of ranks) as done in other metrics
            y_k = xt::cast<double>(n_exceeding) / n_mbr;
        }

        // Derive the forecast quantiles by ordering the ensemble members.
        //
        // \require q_prd:
        //     Streamflow predictions.
        //     shape: (members, time)
        // \assign q_qnt:
        //     Streamflow forecast quantiles (members sorted in ascending
        //     order independently for each time step).
        //     shape: (quantiles, time)
        void Evaluator::calc_q_qnt()
        {
            // axis 0 holds the members, so sorting along it orders the values
            // of each time step's column
            // NOTE(review): sorting relies on a strict weak ordering; NaN
            //     members would violate it - confirm predictions are NaN-free
            //     at this point
            const auto member_axis = 0;
            q_qnt = xt::sort(q_prd, member_axis);
        }
    }
}