from typing import Dict
import matplotlib.pyplot as plt

from extreme_data.meteo_france_data.scm_models_data.visualization.utils import create_adjusted_axes
from projects.exceeding_snow_loads.utils import dpi_paper1_figure
from extreme_trend.visualizers.study_visualizer_for_non_stationary_trends import StudyVisualizerForNonStationaryTrends


def permute(l, permutation):
    """
    Reorder ``l`` according to a sequence of indices.

    Element ``k`` of the result is ``l[permutation[k]]``, so the result has
    ``len(permutation)`` elements. Example:
    ``permute(['a', 'b', 'c'], [2, 0, 1]) == ['c', 'a', 'b']``.
    """
    reordered = []
    for index in permutation:
        reordered.append(l[index])
    return reordered

def plot_selection_curves(altitude_to_visualizer: Dict[int, StudyVisualizerForNonStationaryTrends]):
    """
    Plot a bar chart of how often each non-stationary trend model is selected.

    The per-altitude counters are merged over all altitudes. For every model:
    - a grey bar shows the percentage of all selections that picked this model;
    - a black bar drawn over it shows the percentage counted in the
      ``selected_and_anderson_goodness_of_fit_trend_test_class_counter`` counters
      (labelled 'Significant model').
    Both percentages use the total number of selections as denominator. Models are
    ordered by decreasing selection count.

    :param altitude_to_visualizer: visualizers keyed by altitude. The first one supplies
        the list of trend tests, their labels and the save-to-file logic.
    :raises ValueError: if ``altitude_to_visualizer`` is empty.
    """
    if not altitude_to_visualizer:
        raise ValueError('altitude_to_visualizer must contain at least one visualizer')
    visualizers = list(altitude_to_visualizer.values())
    visualizer = visualizers[0]

    ax = create_adjusted_axes(1, 1)

    selected_counter = merge_counter([v.selected_trend_test_class_counter for v in visualizers])
    # Other criteria that were tried for the "significant" bars:
    #   v.selected_and_significative_trend_test_class_counter
    #   v.selected_and_kstest_goodness_of_fit_trend_test_class_counter
    selected_and_significative_counter = merge_counter(
        [v.selected_and_anderson_goodness_of_fit_trend_test_class_counter for v in visualizers])
    total_of_selected_models = sum(selected_counter.values())

    # Order models by decreasing selection count. Sorting ascending and then reversing
    # (instead of reverse=True) is kept on purpose: it fixes the order of tied models.
    indexed_tests_by_count = sorted(enumerate(visualizer.non_stationary_trend_test),
                                    key=lambda e: selected_counter[e[1]])
    permutation = [i for i, _ in indexed_tests_by_count][::-1]

    select_list = get_ordered_list_from_counter(selected_counter, total_of_selected_models, visualizer,
                                                permutation)
    selected_and_signifcative_list = get_ordered_list_from_counter(selected_and_significative_counter,
                                                                   total_of_selected_models, visualizer,
                                                                   permutation)
    labels = permute(['${}$'.format(t.label) for t in visualizer.non_stationary_trend_test], permutation)

    print(indexed_tests_by_count)
    print(sum(select_list), select_list)
    print(sum(selected_and_signifcative_list), selected_and_signifcative_list)

    # Plot parameters
    width = 5
    legend_size = 30
    axis_label_fontsize = 30
    labelsize = 25
    linewidth = 3
    x = [10 * (i + 1) for i in range(len(select_list))]
    # The black bars are drawn over the grey ones, so the visible grey part is the
    # "non significant" share of each model's selections.
    ax.bar(x, select_list, width=width, color='grey', edgecolor='grey', label='Non significant model',
           linewidth=linewidth)
    ax.bar(x, selected_and_signifcative_list, width=width, color='black', edgecolor='black',
           label='Significant model',
           linewidth=linewidth)
    ax.legend(loc='upper right', prop={'size': legend_size})
    # Raw string: '\%' is an invalid escape in a normal literal. The value is unchanged
    # (a backslash followed by '%', as needed for LaTeX rendering).
    ax.set_ylabel(' Frequency of selected model w.r.t all time series \n '
                  r'i.e. for all massifs and altitudes (\%)', fontsize=axis_label_fontsize)
    ax.set_xlabel('Models', fontsize=axis_label_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=labelsize)
    ax.set_xticks(x)
    ax.yaxis.grid()
    ax.set_xticklabels(labels)

    # Save plot
    visualizer.plot_name = 'Selection curves'
    visualizer.show_or_save_to_file(no_title=True, dpi=dpi_paper1_figure)
    plt.close()


def get_ordered_list_from_counter(selected_counter, total_of_selected_models, visualizer, permutation):
    """
    Convert per-test counts into percentages, one per trend test, reordered by ``permutation``.

    :param selected_counter: mapping from trend test to count. Tests missing from it count as 0.
    :param total_of_selected_models: denominator of the percentages.
    :param visualizer: object whose ``non_stationary_trend_test`` attribute lists the trend tests.
    :param permutation: index order applied to the resulting list (see ``permute``).
    :return: list of percentages: ``100 * count / total`` for tests present in the counter,
        ``0`` otherwise (and ``0`` for every test when the total is zero).
    """
    trend_tests = visualizer.non_stationary_trend_test
    if not total_of_selected_models:
        # Nothing was selected at all: report 0 everywhere rather than dividing by zero.
        percentages = [0 for _ in trend_tests]
    else:
        percentages = [100 * float(selected_counter[t]) / total_of_selected_models if t in selected_counter else 0
                       for t in trend_tests]
    return permute(percentages, permutation)

def merge_counter(counters_list):
    """
    Sum a list of counters into a new counter, leaving every input counter unmodified.

    The first counter is copied before accumulating. Adding in place into
    ``counters_list[0]`` would otherwise mutate the caller's object (e.g. a visualizer
    attribute), and later calls would count it twice.

    :param counters_list: counters to sum. Presumably ``collections.Counter`` instances:
        ``+=`` is used, which for ``Counter`` keeps only positive counts.
    :return: the merged counter, or an empty ``Counter`` when ``counters_list`` is empty.
    """
    if not counters_list:
        from collections import Counter
        return Counter()
    global_counter = counters_list[0].copy()
    for c in counters_list[1:]:
        global_counter += c
    return global_counter