"""
QRame
Copyright (C) 2023  INRAE

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import seaborn as sns
import numpy as np
from PyQt5 import QtCore
import matplotlib.patches as mpatches


class FigDischargeUncertainty(object):
    """Class to plot discharge and OURSIN uncertainty graph.

    Attributes
    ----------
    canvas: MplCanvas
        Object of MplCanvas a FigureCanvas
    fig: Object
        Figure object of the canvas
    units: dict
        Dictionary of units from units_conversion ('Q' conversion factor and 'label_Q' are used)
    _translate: QCoreApplication.translate object
        Save words which need to be translated
    hover_connection: int
        Index to data cursor connection, None when not connected
    annot: Annotation
        Annotation object for data cursor; None until an annotation is assigned
    """

    def __init__(self, canvas, units):
        """Initialize object using the specified canvas.

        Parameters
        ----------
        canvas: MplCanvas
            Object of MplCanvas
        units: dict
            Dictionary of units from units_conversion
        """

        # Initialize attributes
        self.canvas = canvas
        self.fig = canvas.fig
        self.units = units
        self.hover_connection = None
        self.annot = None
        self._translate = QtCore.QCoreApplication.translate

    def create(self, mean_selected_meas, deviation, intercomparison_process, discharge_ref,
               column_grouped=None, column_sorted=None, ascending=True, selected_meas=None):
        """Create the axes and lines for the figure.

        Parameters
        ----------
        mean_selected_meas: pandas DataFrame
            Measurement results dataframe, one row per measurement; its index is used as the
            x tick labels. Columns read: 'tr_q_total', 'meas_oursin_95' and, for
            intercomparisons, 'U_Q_n', 'U_Q_n_min', 'U_Q_n_max' (plus 'meas_mean_q' when
            grouped, and 'meas_name' when grouped without an explicit sort column).
        deviation: float
            Discharge maximum tolerated deviation from reference, in percent
        intercomparison_process: bool
            Specify if intercomparison computations are applied
        discharge_ref: float
            Reference discharge. It is multiplied by units['Q'] exactly like 'tr_q_total',
            so it is expected in the same (unconverted) units as that column.
        column_grouped: str
            Column name that could be used for grouped analysis
        column_sorted: str
            Name of the selected column to sort measurement
        ascending: bool
            Specify if current data are sorted ascending or descending
        selected_meas: str
            Name (index label) of the selected measurement, highlighted in green
        """
        tr = self._translate

        # Clear the plot and configure axis
        self.fig.clear()
        self.fig.ax = self.fig.add_subplot(1, 1, 1)
        ax = self.fig.ax
        ax.xaxis.label.set_fontsize(12)
        ax.yaxis.label.set_fontsize(12)

        # Every plotted discharge (measurements, reference, empirical bands) is expressed in
        # the display units, so the reference is converted exactly like 'tr_q_total'.
        q_factor = self.units['Q']
        ref_q = discharge_ref * q_factor

        palette = sns.color_palette("rocket")
        ref_label = tr("Main", 'Reference discharge')
        if intercomparison_process:
            if deviation == 0:
                first_label = ref_label
            else:
                first_label = ref_label + " ±" + str(deviation) + "%"
            label_txt = [first_label,
                         tr("Main", 'Discharge and\nuncertainty'),
                         tr("Main", 'Empirical uncertainty\nand its uncertainty')]
        else:
            label_txt = [ref_label,
                         tr("Main", 'Discharge and\nuncertainty')]
        label_y = tr("Main", 'Discharge') + ' (' + self.units['label_Q'] + ')'
        label_ref = tr("Main", 'Medians')

        # Sort measurements: the grouping column (if any) is always the primary, ascending key
        if column_grouped is not None:
            secondary_key = 'meas_name' if column_sorted is None else column_sorted
            mean_selected_meas = mean_selected_meas.sort_values([column_grouped, secondary_key],
                                                                ascending=[True, ascending])
        elif column_sorted is not None:
            mean_selected_meas = mean_selected_meas.sort_values(column_sorted, ascending=ascending)

        # Work on plain arrays so that element access is positional (the index holds names)
        q_meas = (mean_selected_meas['tr_q_total'] * q_factor).to_numpy(dtype=float, na_value=np.nan)
        oursin_series = mean_selected_meas['meas_oursin_95']
        oursin_median = np.nanmax([0, oursin_series.median()])
        oursin_meas = oursin_series.fillna(0).to_numpy(dtype=float)
        q_low = q_meas * (1 - oursin_meas / 100)
        q_high = q_meas * (1 + oursin_meas / 100)

        x_ticks_label = list(mean_selected_meas.index)
        len_meas = len(x_ticks_label)

        errorbar_color = ['darkorange'] * len_meas
        scatter_color = ['red'] * len_meas
        if selected_meas and selected_meas in x_ticks_label:
            selected_idx = x_ticks_label.index(selected_meas)
            errorbar_color[selected_idx] = 'seagreen'
            scatter_color[selected_idx] = 'darkgreen'

        # Position -1 is the reference point, labelled as the median
        x_ticks_label.insert(0, label_ref)

        # Empirical uncertainty
        u_q = None            # empirical uncertainty (%): scalar, or one value per group
        u_q_min = None        # lower bound of the empirical uncertainty, same shape as u_q
        centers = None        # discharge each empirical band is centred on
        half_widths = None    # largest half-width of each empirical band (for the y limits)
        p1 = p2 = None        # legend proxies: dashed empirical line and grey band
        if intercomparison_process:
            if column_grouped is None:
                if not np.isnan(np.nanmean(mean_selected_meas['U_Q_n'])):
                    u_q = mean_selected_meas['U_Q_n'].iloc[0]
                    u_q_min = mean_selected_meas['U_Q_n_min'].iloc[0]
                    u_q_max = mean_selected_meas['U_Q_n_max'].iloc[0]
                    q_mean = u_q * ref_q * 0.01
                    q_min = u_q_min * ref_q * 0.01
                    q_max = u_q_max * ref_q * 0.01
                    x_span = [-2, len_meas + 1]

                    if not np.isnan(u_q_min):
                        # High empirical uncertainty fill
                        ax.fill_between(x_span, [ref_q + q_min] * 2, [ref_q + q_max] * 2,
                                        color='grey', alpha=0.3)
                        # Low empirical uncertainty fill
                        ax.fill_between(x_span, [ref_q - q_min] * 2, [ref_q - q_max] * 2,
                                        color='grey', alpha=0.3)
                    else:
                        label_txt[-1] = tr("Main", 'Empirical uncertainty')

                    # Mean empirical uncertainty dash-dot lines
                    ax.plot(x_span, [ref_q + q_mean] * 2, color='k', linestyle='-.')
                    p1 = ax.plot(x_span, [ref_q - q_mean] * 2, color='k', linestyle='-.')
                    p2 = ax.fill(np.nan, np.nan, color='grey', alpha=0.3)

                    centers = np.array([ref_q], dtype=float)
                    half_widths = np.array([np.fmax(q_max, q_mean)], dtype=float)
            else:
                # Missing group values form their own 'Unknown' group (frame is not mutated)
                group_key = mean_selected_meas[column_grouped].fillna(tr("Main", 'Unknown'))
                group_values = group_key.tolist()
                stats = mean_selected_meas[['meas_mean_q', 'U_Q_n', 'U_Q_n_min', 'U_Q_n_max']].groupby(
                    group_key, sort=False).mean()

                # Groups are the contiguous runs of the (sorted) grouping column, in plot order
                starts = [i for i in range(len_meas) if i == 0 or group_values[i] != group_values[i - 1]]
                bounds = starts + [len_meas]
                y_labels_2 = np.nanmin(q_low) if len_meas > 0 else np.nan

                q_groups, q_mean_groups, q_max_groups = [], [], []
                u_q, u_q_min = [], []
                for start, stop in zip(bounds[:-1], bounds[1:]):
                    label = group_values[start]
                    group = stats.loc[label]
                    q_grp = group['meas_mean_q'] * q_factor
                    q_mean = group['U_Q_n'] * q_grp * 0.01
                    q_min = group['U_Q_n_min'] * q_grp * 0.01
                    q_max = group['U_Q_n_max'] * q_grp * 0.01
                    left, right = start - 0.5, stop - 0.5

                    # Mean empirical uncertainty dash-dot lines
                    p1 = ax.plot([left, right], [q_grp + q_mean] * 2, color='k', linestyle='-.')
                    ax.plot([left, right], [q_grp - q_mean] * 2, color='k', linestyle='-.')
                    p2 = ax.fill(np.nan, np.nan, color='grey', alpha=0.3)

                    # High empirical uncertainty fill
                    ax.fill_between([left, right], [q_grp + q_min] * 2, [q_grp + q_max] * 2,
                                    color='grey', alpha=0.3)
                    # Low empirical uncertainty fill
                    ax.fill_between([left, right], [q_grp - q_min] * 2, [q_grp - q_max] * 2,
                                    color='grey', alpha=0.3)

                    # Group separator and label
                    ax.axvline(x=left, color='k')
                    ax.text(x=(left + right) / 2, y=y_labels_2, s=label, fontsize=11, ha="center")

                    q_groups.append(q_grp)
                    q_mean_groups.append(q_mean)
                    q_max_groups.append(q_max)
                    u_q.append(group['U_Q_n'])
                    u_q_min.append(group['U_Q_n_min'])

                centers = np.array(q_groups, dtype=float)
                half_widths = np.fmax(np.array(q_max_groups, dtype=float),
                                      np.array(q_mean_groups, dtype=float))

            # Tolerated deviation from reference (low then high)
            for sign in (-1, 1):
                ax.axhline(y=ref_q * (1 + sign * deviation / 100), linestyle=':', color=palette[2])

        # Reference discharge
        ax.axhline(y=ref_q, linestyle='--', color=palette[2])
        p4 = ax.plot([np.nan, np.nan], [np.nan, np.nan], color=palette[2], linestyle='--')

        # Reference Q and its (median OURSIN) uncertainty
        if not np.isnan(discharge_ref):
            ax.errorbar(-1, ref_q, yerr=[ref_q * oursin_median / 100],
                        fmt='o', color='darkgreen', ecolor='seagreen', elinewidth=4, markersize=6)

        # One error bar per measurement; infinite uncertainties are drawn as plain markers
        for i, (q_i, u_i) in enumerate(zip(q_meas, oursin_meas)):
            if np.isinf(u_i):
                ax.scatter(i, q_i, marker='o', color=scatter_color[i], s=14, zorder=1)
            else:
                ax.errorbar(i, q_i, yerr=[np.abs(q_i * u_i / 100)],
                            fmt='o', color=scatter_color[i], ecolor=errorbar_color[i],
                            elinewidth=4, markersize=6)
        p3 = ax.errorbar([np.nan], [np.nan], yerr=[0], fmt='o', color='red', ecolor='orange',
                         elinewidth=4, markersize=6)

        # Legend and y limits
        legend_kw = dict(fontsize=11, loc='upper center', fancybox=True, shadow=True,
                         bbox_to_anchor=(0.5, 1))
        has_empirical = u_q is not None and np.size(u_q) > 0
        if intercomparison_process and has_empirical:
            if np.any(np.logical_not(np.isnan(np.asarray(u_q_min, dtype=float)))):
                # At least one grey band was drawn
                ax.legend([p4[0], p3, (p2[0], p1[0])], label_txt, ncol=3, **legend_kw)
            elif not np.any(np.isnan(np.asarray(u_q, dtype=float))):
                # Only the mean empirical lines are meaningful
                ax.legend([p4[0], p3, p1[0]], label_txt, ncol=3, **legend_kw)
            else:
                ax.legend([p4[0], p3], label_txt, ncol=2, **legend_kw)

            lower = np.nanmin([ref_q * 0.9, np.nanmin(centers - half_widths), np.nanmin(q_low)]) * 0.99
            upper = np.nanmax([ref_q * 1.1, np.nanmax(centers + half_widths), np.nanmax(q_high)]) * 1.01
            ax.set_ylim(lower, upper)
        else:
            ax.legend([p4[0], p3], label_txt, ncol=3, **legend_kw)
            if len_meas > 0:
                ax.set_ylim(np.nanmin(q_low) * 0.99, np.nanmax(q_high) * 1.01)

        ax.set_ylabel(label_y, fontsize=11)
        ax.set_xticks(np.arange(-1, len_meas))

        # Rotate tick labels when the measurement names are long
        if sum(len(str(s)) for s in x_ticks_label[1:]) < 80:
            ax.set_xticklabels(x_ticks_label, rotation=0, fontsize=10)
        else:
            ax.set_xticklabels(x_ticks_label, rotation=80, fontsize=10, ha='right')

        ax.tick_params(axis='y', labelsize=10)
        ax.set_xlim(-1.5, len_meas - 0.5)
        ax.grid(linestyle='--')
        # NOTE: no data-cursor annotation is created here, so self.annot is left untouched
        # (None unless set elsewhere) and hover handling stays a no-op in that case.

    def hover(self, event):
        """Show the data-cursor annotation for a clicked plotted point, or hide it.

        If the click is inside the axes (and not a right click) and lies on one of the plotted
        lines, the annotation is moved to that point and shown; otherwise a visible annotation
        is hidden. Does nothing when no annotation exists.

        Parameters
        ----------
        event: MouseEvent
            Triggered when mouse button is pressed.
        """
        if self.annot is None:
            # No annotation to show or hide
            return

        was_visible = self.annot.get_visible()

        # Ignore clicks outside the axes and right clicks
        if event.inaxes != self.fig.ax or event.button == 3:
            return

        # Find the first plotted line that contains the mouse click
        for plotted_line in self.fig.ax.lines:
            contains, ind = plotted_line.contains(event)
            if contains:
                self.update_annot(ind, plotted_line)
                self.annot.set_visible(True)
                self.canvas.draw_idle()
                return

        # The click is not associated with plotted data: hide the annotation
        if was_visible:
            self.annot.set_visible(False)
            self.canvas.draw_idle()

    def update_annot(self, ind, plt_ref):
        """Move the annotation to the selected point and set its text.

        The text box is offset so that it stays inside the axes: it is shifted to the left
        when the point lies in the right half of the view and below the point when it lies
        in the upper half (taking inverted axes into account).

        Parameters
        ----------
        ind: dict
            Result of Line2D.contains; ind["ind"][0] is the index of the selected point.
        plt_ref: Line2D
            Reference containing plotted data
        """
        pos = plt_ref.get_xydata()[ind["ind"][0]]

        x_start, x_end = plt_ref.axes.viewLim.intervalx
        y_start, y_end = plt_ref.axes.viewLim.intervaly
        in_left_half = pos[0] < (x_start + x_end) / 2
        in_upper_half = pos[1] > (y_start + y_end) / 2

        # Offsets (in points) of the text box relative to the annotated point
        dx = -20 if in_left_half == (x_start < x_end) else -80
        dy = -40 if in_upper_half == (y_start < y_end) else 20

        self.annot.xyann = (dx, dy)
        self.annot.xy = pos
        self.annot.set_text('x: {:.2f}, y: {:.2f}'.format(pos[0], pos[1]))

    def set_hover_connection(self, setting):
        """Turns the connection to the mouse event on or off.

        Parameters
        ----------
        setting: bool
            Boolean to specify whether the connection for the mouse event is active or not.
        """
        if setting and self.hover_connection is None:
            self.hover_connection = self.canvas.mpl_connect('button_press_event', self.hover)
        elif not setting:
            if self.hover_connection is not None:
                self.canvas.mpl_disconnect(self.hover_connection)
            self.hover_connection = None
            if self.annot is not None:
                self.annot.set_visible(False)
            self.canvas.draw_idle()