# Plot.py -- Pamhyr
# Copyright (C) 2023-2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import logging

from math import dist, sqrt

from tools import timer, trace
from View.Tools.PamhyrPlot import PamhyrPlot

from PyQt5.QtCore import (
    Qt, QCoreApplication, QItemSelectionModel,
    QItemSelection, QItemSelectionRange,
)

from PyQt5.QtWidgets import QApplication
from matplotlib.widgets import RectangleSelector

# Module-level shortcut to Qt's translation function.
_translate = QCoreApplication.translate

# getLogger() is called without a name, so this is the root logger.
logger = logging.getLogger()


class Plot(PamhyrPlot):
    def __init__(self, canvas=None, trad=None, data=None, toolbar=None,
                 table=None, parent=None):
        """Build the plot state and install the rectangle selection tool.

        ``canvas``, ``trad``, ``data``, ``toolbar`` and ``parent`` are
        forwarded unchanged to the ``PamhyrPlot`` constructor.

        ``table`` is the view whose selection is driven from the plot
        (it may be None, in which case nothing is selected).  ``parent``
        is also stored locally: the currently selected rows are read
        from it through ``index_selected_rows()``.
        """
        super().__init__(
            canvas=canvas, trad=trad, data=data,
            toolbar=toolbar, parent=parent,
        )

        # Collaborators used to read and write the row selection
        self._table = table
        self._parent = parent

        # Axis labels are taken from the translation dictionary
        self.label_x = self._trad["transverse_abscissa"]
        self.label_y = self._trad["unit_elevation"]
        self._isometric_axis = False

        # Artists of the water level overlay, created on first use
        self._z_note = None
        self._z_line = None
        self._z_fill_between = None

        self.line_xy = []
        self.line_gl = []

        self.before_plot_selected = None
        self.plot_selected = None
        self.after_plot_selected = None

        # Markers drawn over the selected points
        self.hl_points = []
        # (z, wet_area, wet_perimeter, water_width)
        self.highlight = None
        self._colors = []

        # Pick event recorded by onpick and consumed by onrelease
        self._onpickevent = None
        self._rect_select = RectangleSelector(
            ax=self.canvas.axes,
            onselect=self.rect_select_callback,
            useblit=True,
            button=[1],  # left button only: no middle nor right button
            minspanx=2.0,
            minspany=2.0,
            spancoords='pixels',
            interactive=False
        )

    def onrelease(self, event):
        """Apply the pick recorded by ``onpick`` to the table selection.

        The pick is handled here, on button release, to prevent
        conflicts between ``onpick`` and ``rect_select_callback``.
        ``event`` itself is not used.

        - Ctrl: toggle the picked row in the current selection.
        - Shift: select the range spanning the first selected row, the
          last selected row and the picked row.
        - otherwise: select the picked row only.

        The pending pick is always cleared at the end.
        """
        modifiers = QApplication.keyboardModifiers()
        pending = self._onpickevent

        if pending is not None:
            picked, _ = self._closest_point(pending)

            if modifiers == Qt.ControlModifier:
                selected = self._parent.index_selected_rows()
                if picked in selected:
                    selected.remove(picked)
                    self._select_in_table(selected)
                else:
                    self._select_in_table(selected + [picked])
            elif modifiers == Qt.ShiftModifier:
                selected = self._parent.index_selected_rows()
                if len(selected) > 0:
                    bounds = (selected[0], selected[-1], picked)
                    first, last = min(bounds), max(bounds)
                else:
                    first = last = picked
                self._select_range_in_table(first, last)
            else:
                self._select_in_table([picked])

        self._onpickevent = None

    def onpick(self, event):
        if event.mouseevent.inaxes != self.canvas.axes:
            return
        if event.mouseevent.button.value != 1:
            return

        modifiers = QApplication.keyboardModifiers()
        if modifiers not in [Qt.ControlModifier,
                             Qt.NoModifier,
                             Qt.ShiftModifier]:
            return

        self._onpickevent = event
        return

    def onclick(self, event):
        if event.inaxes != self.canvas.axes:
            return
        if event.button.value == 1:
            return

        z = self._get_z_from_click(event)
        if z < self.data.z_min() or event.button.value == 2:
            self.highlight = None
            self.draw_highligth()
            self.update_idle()
            return

        a, p, w = self._compute_hydraulics(z)

        logger.debug(f"{z, a, p, w}")

        self.highlight = (z, a, p, w)

        self.draw_highligth()
        self.update_idle()
        return

    def select_points_from_indices(self, indices):
        """Refresh the highlight after a selection change.

        ``indices`` is not used: the highlight is rebuilt by
        ``draw_highligth()``, which reads the selected rows itself.
        """
        self.draw_highligth()

    def _select_in_table(self, ind):
        if self._table is not None:
            self._table.blockSignals(True)
            self._table.setFocus()
            selection = self._table.selectionModel()
            index = QItemSelection()
            if len(ind) > 0:
                for i in ind:
                    index.append(QItemSelectionRange(
                        self._table.model().index(i, 0))
                    )
            selection.select(
                index,
                QItemSelectionModel.Rows |
                QItemSelectionModel.ClearAndSelect |
                QItemSelectionModel.Select
            )

            if len(ind) > 0:
                self._table.scrollTo(self._table.model().index(ind[-1], 0))
            self._table.blockSignals(False)

    def _select_range_in_table(self, ind1, ind2):
        if self._table is not None:
            self._table.blockSignals(True)
            self._table.setFocus()
            selection = self._table.selectionModel()
            index = QItemSelection(self._table.model().index(ind1, 0),
                                   self._table.model().index(ind2, 0))
            selection.select(
                index,
                QItemSelectionModel.Rows |
                QItemSelectionModel.ClearAndSelect |
                QItemSelectionModel.Select
            )
            self._table.scrollTo(self._table.model().index(ind2, 0))
            self._table.blockSignals(False)

    def _closest_point(self, event):
        points_ind = event.ind
        axes = self.canvas.axes
        bx, by = axes.get_xlim(), axes.get_ylim()
        ratio = (bx[0] - bx[1]) / (by[0] - by[1])

        x = event.artist.get_xdata()
        y = event.artist.get_ydata()
        points = enumerate(zip(x, y))

        mx = event.mouseevent.xdata
        my = event.mouseevent.ydata

        def dist_mouse(point):
            x, y = point[1]
            d2 = ((mx - x) / ratio) ** 2 + ((my - y) ** 2)
            return d2

        closest = min(
            points, key=dist_mouse
        )

        return closest

    def _get_z_from_click(self, event):
        return event.ydata

    def rect_select_callback(self, eclick, erelease):

        hyd = self.highlight
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata

        if (max(abs(x1-x2), abs(y1-y2)) < 0.001):
            return
        modifiers = QApplication.keyboardModifiers()

        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata

        inds, points2 = self._points_in_rectangle(x1, y1, x2, y2)
        self._onclickevent = None
        if modifiers == Qt.ControlModifier:
            rows = self._parent.index_selected_rows()
            if all(i in rows for i in inds):
                for ind in sorted(inds, reverse=True):
                    rows.remove(ind)
                self.highlight = hyd
                self._select_in_table(rows)
            else:
                self.highlight = hyd
                self._select_in_table(rows+inds)
        else:
            self.highlight = hyd
            self._select_in_table(inds)
        return

    def _points_in_rectangle(self, x1, y1, x2, y2):
        # TODO: use lambdas
        listi = []
        listp = []
        station = self.data._get_station(self.data.points)
        for i, p in enumerate(self.data.points):
            if (min(x1, x2) < station[i] < max(x1, x2)
                    and min(y1, y2) < p.z < max(y1, y2)):
                listi.append(i)
                listp.append((station[i], p.z))
        return listi, listp

    def _compute_hydraulics(self, z):
        profile = self.data

        points = profile.wet_points(z)
        station = profile._get_station(points)
        width = abs(station[0] - station[-1])

        poly = profile.wet_polygon(z)
        area = poly.area
        perimeter = poly.length

        return area, perimeter, width

    @timer
    def draw(self):
        """Draw the profile from scratch.

        Initialises the axes, plots elevation against station as a
        pickable line (stored in ``self.profile_line2D``), adds the
        point labels, then resets and redraws the highlight.

        Note that the water level highlight (``self.highlight``) is
        cleared by every call.
        """
        self.init_axes()

        # Only station and elevation are plotted here
        self.profile_line2D, = self.canvas.axes.plot(
            self.data.get_station(), self.data.z(),
            color=self.color_plot,
            lw=1.5, markersize=7, marker='+',
            picker=10
        )

        self.draw_annotation()
        self.highlight = None
        self.draw_highligth()

        self.idle()

    def draw_annotation(self):
        gl = map(lambda p: p.name, self.data.points)
        x = self.data.get_station()
        y = self.data.z()

        # Add label on graph
        self.annotation = []
        for i, name in enumerate(list(gl)):
            annotation = self.canvas.axes.annotate(
                name, (x[i], y[i]),
                horizontalalignment='left',
                verticalalignment='top',
                annotation_clip=True,
                fontsize=10, color='black'
            )
            annotation.set_position((x[i], y[i]))
            annotation.set_color("black")
            self.annotation.append(annotation)

        al = 8.
        arrowprops = dict(
            clip_on=True,
            headwidth=5.,
            facecolor='k'
        )
        kwargs = dict(
            xycoords='axes fraction',
            textcoords='offset points',
            arrowprops=arrowprops,
        )

        self.canvas.axes.annotate("", (1, 0), xytext=(-al, 0), **kwargs)
        self.canvas.axes.annotate("", (0, 1), xytext=(0, -al), **kwargs)

        self.canvas.axes.spines[['top', 'right']].set_color('none')
        self.canvas.axes.yaxis.tick_left()
        self.canvas.axes.xaxis.tick_bottom()
        self.canvas.axes.set_facecolor('#F9F9F9')
        self.canvas.figure.patch.set_facecolor('white')

    def draw_highligth(self):
        hyd = self.highlight
        for p in self.hl_points:
            p[0].set_data([], [])

        self.hl_points = []
        x = self.data.get_station()
        y = self.data.z()

        for i in self._parent.index_selected_rows():
            self.hl_points.append(
                self.canvas.axes.plot(
                    x[i], y[i],
                    color=self.color_plot_highlight,
                    lw=1.5, markersize=7, marker='+',
                )
            )

        if hyd is not None:
            self.draw_highligth_z_line(*hyd)
        else:
            if self._z_note is not None:
                self._z_note.set_visible(False)
                self._z_line.set_visible(False)
                self._z_fill_between.set_visible(False)
        self.idle()

    def draw_highligth_z_line(self, z, a, p, w):
        text = (
            f"Z = {z:.3f} m, " +
            f"{self._trad['width']} = {w:.3f} m,\n" +
            f"{self._trad['area']} = {a:.3f} m², " +
            f"{self._trad['perimeter']} = {p:.3f} m"
        )

        x = self.data.get_station()
        xlim = (x[0], x[-1])
        ylim = self.canvas.axes.get_ylim()
        pos = (
            xlim[0] + (abs(xlim[0] - xlim[1]) * 0.05),
            z + + (abs(ylim[0] - ylim[1]) * 0.08)
        )
        y = self.data.z()

        if self._z_note is not None:
            self._z_note.remove()
        if self._z_line is not None:
            self._z_line.remove()
        self.draw_highligth_z_line_fill(x, y, z)

        self._z_line, = self.canvas.axes.plot(
            xlim, [z, z],
            color=self.color_plot_river_water
        )
        self._z_line.set_visible(True)
        self._z_fill_between.set_visible(True)

        self._z_note = self.canvas.axes.annotate(
            text, pos,
            horizontalalignment='left',
            verticalalignment='top',
            annotation_clip=True,
            color=self.color_plot_river_water,
            fontsize=9,
            fontweight='bold',
            alpha=0.7
        )
        self._z_note.set_visible(True)

    def draw_highligth_z_line_fill(self, x, y, z):
        if self._z_fill_between is not None:
            self._z_fill_between.remove()

        self._z_fill_between = self.canvas.axes.fill_between(
            x, y, z,
            where=y <= z,
            facecolor=self.color_plot_river_water_zone,
            interpolate=True, alpha=0.7
        )

    @timer
    def update(self):
        """Redraw the plot with ``draw()``, then call ``update_idle()``."""
        self.draw()
        self.update_idle()