import numpy as np
import matplotlib.cm as cm

import matplotlib.pyplot as plt
import seaborn as sns

from typing import Union, List

from extreme_estimator.estimator.full_estimator import AbstractFullEstimator
from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator
from extreme_estimator.extreme_models.margin_model.margin_function.combined_margin_function import \
    CombinedMarginFunction
from extreme_estimator.extreme_models.margin_model.margin_function.utils import error_dict_between_margin_functions
from extreme_estimator.gev_params import GevParams
from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
from spatio_temporal_dataset.slicer.split import Split, ALL_SPLITS_EXCEPT_ALL


class SplitCurve(object):
    """Fit an estimator repeatedly on resampled datasets and compare the fitted margins.

    Subclasses must override :meth:`load_dataset` and :meth:`load_estimator`.
    Each call to :meth:`fit` loads ``nb_fit`` datasets (same margin, resampled
    observations), fits one estimator per dataset, and measures the error of every
    fitted margin function, and of their combination, against the dataset's sample
    margin function.
    """

    def __init__(self, nb_fit: int = 1):
        # Number of resampled datasets / independent fits performed by `fit`.
        self.nb_fit = nb_fit
        self.margin_function_fitted_all = None

    def fit(self, show=True):
        """Run ``nb_fit`` independent fits and compute individual and mean errors.

        Sets ``dataset``, ``estimator``, ``margin_function_sample`` (from the last
        loaded dataset), ``margin_function_fitted_all``, ``error_dict_all``,
        ``mean_margin_function_fitted`` and ``mean_error_dict`` on the instance.

        Args:
            show: when True, call :meth:`visualize` once fitting is done.

        Raises:
            ValueError: if ``nb_fit`` is smaller than 1 (nothing to fit or combine).
        """
        if self.nb_fit < 1:
            raise ValueError('nb_fit must be at least 1, got {}'.format(self.nb_fit))
        self.margin_function_fitted_all = []

        for i in range(self.nb_fit):
            # A new dataset with the same margin, but just the observations are resampled
            self.dataset = self.load_dataset()
            # Both split must be defined
            assert not self.dataset.slicer.some_required_ind_are_not_defined
            self.margin_function_sample = self.dataset.margin_model.margin_function_sample

            print('Fitting {}/{}...'.format(i + 1, self.nb_fit))
            self.estimator = self.load_estimator(self.dataset)
            # Fit the estimator and get the margin_function
            self.estimator.fit()
            self.margin_function_fitted_all.append(self.estimator.margin_function_fitted)

        # Individual error dict
        self.error_dict_all = [error_dict_between_margin_functions(self.margin_function_sample, m)
                               for m in self.margin_function_fitted_all]

        # Mean margin
        self.mean_margin_function_fitted = CombinedMarginFunction.from_margin_functions(self.margin_function_fitted_all)
        self.mean_error_dict = error_dict_between_margin_functions(self.margin_function_sample,
                                                                   self.mean_margin_function_fitted)

        if show:
            self.visualize()

    def load_dataset(self):
        """Return a freshly sampled dataset; must be overridden by subclasses."""
        raise NotImplementedError('{} must implement load_dataset()'.format(type(self).__name__))

    def load_estimator(self, dataset):
        """Return an estimator built on ``dataset``; must be overridden by subclasses."""
        raise NotImplementedError('{} must implement load_estimator()'.format(type(self).__name__))

    @property
    def main_title(self):
        """Figure title: the slicer summary of the current dataset."""
        return self.dataset.slicer.summary(show=False)

    def visualize(self):
        """Draw, for each GEV parameter, a margin plot (left) and an error-density plot (right)."""
        nb_rows = len(GevParams.GEV_VALUE_NAMES)
        # squeeze=False keeps `axes` 2-D even when there is a single GEV parameter
        fig, axes = plt.subplots(nb_rows, 2, squeeze=False)
        fig.subplots_adjust(hspace=0.4, wspace=0.4)
        for i, gev_value_name in enumerate(GevParams.GEV_VALUE_NAMES):
            self.margin_graph(axes[i, 0], gev_value_name)
            self.score_graph(axes[i, 1], gev_value_name)
        fig.suptitle(self.main_title)
        plt.show()

    @staticmethod
    def _error_bins(data, nb_bins):
        """Return ``(limits, levels)`` used to shade datapoints by error.

        ``limits`` holds ``nb_bins + 1`` equally spaced edges from ``min(data)`` to
        ``max(data)``; the last edge is nudged up by 0.01 so the maximum value falls
        inside the last half-open bin ``[left, right)``. ``levels`` holds one value
        per bin, increasing and strictly inside (0, 1), which is the input range
        matplotlib colormaps expect for floats.
        """
        data = np.asarray(data, dtype=float)
        limits = np.linspace(data.min(), data.max(), num=nb_bins + 1)
        limits[-1] += 0.01
        levels = (np.arange(nb_bins) + 0.5) / nb_bins
        return limits, levels

    def margin_graph(self, ax, gev_value_name, nb_bins=10):
        """Plot the sample margin, every fitted margin and their mean for one GEV parameter.

        Train datapoints use 'o' markers and test datapoints use 'x' markers. Each
        datapoint is shaded with the binary colormap according to the bin its
        mean-fit error falls into: lighter means smaller error.

        Args:
            ax: matplotlib axes to draw on.
            gev_value_name: key of the GEV parameter in ``mean_error_dict``.
            nb_bins: number of equal-width error bins used for shading.
        """
        # Create bins of data, each with an associated color corresponding to its error
        errors = self.mean_error_dict[gev_value_name]
        limits, levels = self._error_bins(errors.values, nb_bins)
        colors = cm.binary(levels)

        # Display train/test points
        for split, marker in [(self.dataset.train_split, 'o'), (self.dataset.test_split, 'x')]:
            # Errors of this split's datapoints; independent of the bin, so computed once
            data_ind = errors.loc[self.dataset.coordinates.coordinates_index(split)].values
            for left_limit, right_limit, color in zip(limits[:-1], limits[1:], colors):
                data_filter = np.logical_and(left_limit <= data_ind, data_ind < right_limit)
                self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker,
                                                                             filter=data_filter, color=color)
                self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False)

        # Display the individual fitted curves in a light color
        for m in self.margin_function_fitted_all:
            m.color = 'lightskyblue'
            m.visualize_single_param(gev_value_name, ax, show=False)
        # Display the mean fitted curve on top, in a darker color
        self.mean_margin_function_fitted.color = 'blue'
        self.mean_margin_function_fitted.visualize_single_param(gev_value_name, ax, show=False)

    def score_graph(self, ax, gev_value_name):
        """Plot one kernel density estimate of the mean-fit error per dataset split."""
        # TODO: for the moment only train/test are interesting (the spatio-temporal split isn't working yet)

        sns.set_style('whitegrid')
        s = self.mean_error_dict[gev_value_name]
        for split in self.dataset.splits:
            ind = self.dataset.coordinates_index(split)
            data = s.loc[ind].values
            # NOTE(review): `bw` is deprecated since seaborn 0.11 (bw_method/bw_adjust);
            # kept for compatibility with the seaborn version this file targets.
            sns.kdeplot(data, bw=0.5, ax=ax, label=split.name).set(xlim=0)
        ax.legend()
        # X axis
        ax.set_xlabel('Absolute error in percentage')
        plt.setp(ax.get_xticklabels(), visible=True)
        ax.xaxis.set_tick_params(labelbottom=True)
        # Y axis
        ax.set_ylabel(gev_value_name)
        plt.setp(ax.get_yticklabels(), visible=True)
        ax.yaxis.set_tick_params(labelleft=True)