# PlotKPC.py -- Pamhyr
# Copyright (C) 2023-2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import logging

from functools import reduce

from tools import timer
from View.Tools.PamhyrPlot import PamhyrPlot

from PyQt5.QtCore import (
    QCoreApplication
)

# Module-level logger; with no name given, getLogger returns the root logger.
logger = logging.getLogger(name=None)


class PlotKPC(PamhyrPlot):
    """Plot of elevations against KP for one reach of simulation results.

    For the current reach it draws:
    - the river bottom (built from sediment layers when the reach has
      sediment, otherwise the minimum elevation of each profile);
    - the water line and water area at the current timestamp;
    - the maximum water line over all timestamps (dotted);
    - 'x' markers on profiles where the water reaches a geometry edge;
    - a vertical segment on the currently selected profile;
    - a vertical line and label for each hydraulic structure of the reach.
    """

    def __init__(self, canvas=None, trad=None, toolbar=None,
                 results=None, reach_id=0, profile_id=0,
                 parent=None):
        super(PlotKPC, self).__init__(
            canvas=canvas,
            trad=trad,
            data=results,
            toolbar=toolbar,
            parent=parent
        )

        # Cached result timestamps and the displayed timestamp (the last
        # one by default). They stay empty/None when no results are given,
        # so the plot can be built before results exist (draw() already
        # handles results being None).
        self._timestamps = []
        self._current_timestamp = None
        if results is not None:
            self._timestamps = results.get("timestamps")
            self._current_timestamp = max(self._timestamps)

        self._current_reach_id = reach_id
        self._current_profile_id = profile_id

        self.label_x = self._trad["unit_kp"]
        self.label_y = self._trad["unit_elevation"]

        self.label_bottom = self._trad["label_bottom"]
        self.label_water = self._trad["label_water"]
        self.label_water_max = self._trad["label_water_max"]

        self._isometric_axis = False

    @property
    def results(self):
        """Results object displayed by this plot (stored as self.data)."""
        return self.data

    @results.setter
    def results(self, results):
        self.data = results

        if results is None:
            self._timestamps = []
            self._current_timestamp = None
            return

        # Refresh the cached timestamp list too. sl_compute_initial and
        # draw_water_elevation_overflow read it, so keeping the previous
        # results' list would make them use stale timestamps.
        self._timestamps = results.get("timestamps")
        self._current_timestamp = max(self._timestamps)

    @timer
    def draw(self, highlight=None):
        """Clear the axes and draw every layer for the current reach.

        `highlight` is accepted for interface compatibility but unused.
        """
        self.init_axes()

        if self.results is None:
            return

        reach = self.results.river.reach(self._current_reach_id)

        # The bottom is drawn first because it sets self._river_bottom,
        # which the water fill uses as its lower bound.
        self.draw_bottom(reach)
        self.draw_water_elevation(reach)
        self.draw_water_elevation_max(reach)
        self.draw_water_elevation_overflow(reach)
        self.draw_current(reach)
        self.draw_profiles_hs(reach)

        # self.enable_legend()

        self.idle()
        self._init = True

    def draw_bottom(self, reach):
        """Draw the river bottom, using sediment layers if present."""
        if reach.has_sediment():
            self.draw_bottom_with_bedload(reach)
        else:
            self.draw_bottom_geometry(reach)

    def draw_bottom_with_bedload(self, reach):
        """Draw the bottom as bedrock plus current sediment layers.

        Caches the bedrock elevations in self._bedrock. Both
        sl_compute_current_z and update_bottom_with_bedload use it.
        """
        self._bedrock = self.sl_compute_bedrock(reach)

        kp = reach.geometry.get_kp()
        z = self.sl_compute_current_z(reach)

        self._plot_bottom_line(kp, z)

    def _plot_bottom_line(self, kp, z):
        """Plot the sediment bottom line and record it as river bottom."""
        self.line_bottom, = self.canvas.axes.plot(
            kp, z,
            linestyle="solid", lw=1.,
            color=self.color_plot_river_bottom,
        )

        self._river_bottom = z

    def draw_profiles_hs(self, reach):
        """Draw a labelled vertical line at each hydraulic structure's KP.

        Only structures whose input reach geometry is this reach's
        geometry are drawn. Each line spans from the lowest to the highest
        elevation of the reach geometry.
        """
        structures = [
            hs for hs in self.results.study.river.hydraulic_structures.lst
            if hs._input_reach.reach is reach.geometry
        ]
        if not structures:
            return

        # The vertical extent is the same for every structure, so it is
        # computed once.
        z_bottom = min(reach.geometry.get_z_min())
        z_top = max(reach.geometry.get_z_max())

        for hs in structures:
            x = hs.input_kp

            self.canvas.axes.plot(
                [x, x],
                [z_bottom, z_top],
                linestyle="solid",
                lw=1.,
                color=self.color_plot_previous,
            )

            self.canvas.axes.annotate(
                " > " + hs.name,
                (x, z_top),
                horizontalalignment='left',
                verticalalignment='top',
                annotation_clip=True,
                fontsize=9, color=self.color_plot_previous,
            )

    def sl_compute_bedrock(self, reach):
        """Return the bedrock elevation of each profile.

        The bedrock elevation is the profile's minimum geometry elevation
        minus the first value (h[0]) of every sediment layer at the first
        timestamp. h[0] is presumably the layer thickness.
        """
        z_min = reach.geometry.get_z_min()
        initial_layers = self.sl_compute_initial(reach)

        return [
            reduce(lambda z, h: z - h[0], layers, z0)
            for z0, layers in zip(z_min, initial_layers)
        ]

    def sl_compute_current_z(self, reach):
        """Return the bottom elevation of each profile at current time.

        The value is self._bedrock plus the first value (h[0]) of every
        sediment layer at the current timestamp. self._bedrock must have
        been computed first (see draw_bottom_with_bedload).
        """
        current_layers = self.sl_compute_current(reach)

        return [
            reduce(lambda z, h: z + h[0], layers, z0)
            for z0, layers in zip(self._bedrock, current_layers)
        ]

    def sl_compute_initial(self, reach):
        """
        Get SL list for profile p at initial time (initial data)
        """
        first_ts = min(self._timestamps)
        return map(
            lambda p: p.get_ts_key(first_ts, "sl")[0],
            reach.profiles
        )

    def sl_compute_current(self, reach):
        """
        Get SL list for profile p at current time
        """
        return map(
            lambda p: p.get_ts_key(self._current_timestamp, "sl")[0],
            reach.profiles
        )

    def draw_bottom_geometry(self, reach):
        """Draw each profile's minimum elevation as the river bottom."""
        kp = reach.geometry.get_kp()
        z_min = reach.geometry.get_z_min()

        self.line_kp_zmin = self.canvas.axes.plot(
            kp, z_min,
            color=self.color_plot_river_bottom,
            lw=1.
        )

        self._river_bottom = z_min

    def _current_water_z(self, reach):
        """Return each profile's "Z" value at the current timestamp."""
        return [
            p.get_ts_key(self._current_timestamp, "Z")
            for p in reach.profiles
        ]

    def _fill_water(self, kp, water_z):
        """Fill between the river bottom and water_z.

        The resulting artist is stored in self.water_fill.
        """
        self.water_fill = self.canvas.axes.fill_between(
            kp, self._river_bottom, water_z,
            color=self.color_plot_river_water_zone,
            alpha=0.7,
            interpolate=True
        )

    def draw_water_elevation(self, reach):
        """Draw the water line and water area at the current timestamp.

        Nothing is drawn when the reach geometry has no profiles.
        """
        if len(reach.geometry.profiles) == 0:
            return

        kp = reach.geometry.get_kp()
        water_z = self._current_water_z(reach)

        self.water = self.canvas.axes.plot(
            kp, water_z,
            lw=1., color=self.color_plot_river_water,
        )

        self._fill_water(kp, water_z)

    def draw_water_elevation_max(self, reach):
        """Draw (dotted) the maximum "Z" of each profile over all times.

        Nothing is drawn when the reach geometry has no profiles.
        """
        if len(reach.geometry.profiles) == 0:
            return

        kp = reach.geometry.get_kp()
        water_z = [max(p.get_key("Z")) for p in reach.profiles]

        self.canvas.axes.plot(
            kp, water_z, lw=1.,
            color=self.color_plot_river_water,
            linestyle='dotted',
        )

    def _current_profile_segment(self, reach):
        """Return (xs, ys) of the segment marking the selected profile.

        The segment is vertical at the profile's KP and runs from its
        maximum to its minimum elevation.
        """
        kp = reach.geometry.get_kp()
        z_min = reach.geometry.get_z_min()
        z_max = reach.geometry.get_z_max()
        cid = self._current_profile_id

        return [kp[cid], kp[cid]], [z_max[cid], z_min[cid]]

    def draw_current(self, reach):
        """Draw the vertical segment on the currently selected profile."""
        xs, ys = self._current_profile_segment(reach)

        self.profile, = self.canvas.axes.plot(
            xs, ys,
            color=self.color_plot,
            lw=1.
        )

    def draw_water_elevation_overflow(self, reach):
        """Mark profiles where the water reaches a geometry edge.

        For each profile:
        - take its maximum "Z" over all timestamps;
        - find the first timestamp at which that maximum is reached;
        - check whether either water limit at that timestamp is the first
          or last point of the profile geometry.
        When it is, an 'x' marker is drawn at (profile KP, maximum Z).
        """
        overflow = []

        for profile in reach.profiles:
            z_max = max(profile.get_key("Z"))

            # First timestamp reaching the maximum. The fallback is
            # timestamp 0 when no timestamp matches.
            z_max_ts = next(
                (
                    ts for ts in self._timestamps
                    if profile.get_ts_key(ts, "Z") == z_max
                ),
                0
            )

            pt_left, pt_right = profile.get_ts_key(z_max_ts, "water_limits")

            if (self.is_overflow_point(profile, pt_left)
                    or self.is_overflow_point(profile, pt_right)):
                overflow.append((profile, z_max))

        for profile, z in overflow:
            self.canvas.axes.plot(
                profile.kp, z,
                lw=1.,
                color=self.color_plot,
                markersize=3,
                marker='x'
            )

    def is_overflow_point(self, profile, point):
        """Return True if point is the first or last geometry point."""
        left_limit = profile.geometry.point(0)
        right_limit = profile.geometry.point(
            profile.geometry.number_points - 1
        )

        return (
            point == left_limit
            or point == right_limit
        )

    def set_reach(self, reach_id):
        """Select another reach, reset to its first profile, and redraw."""
        self._current_reach_id = reach_id
        self._current_profile_id = 0
        self.draw()

    def set_profile(self, profile_id):
        """Select another profile and move its marker segment."""
        self._current_profile_id = profile_id
        self.update_current()

    def set_timestamp(self, timestamp):
        """Display another timestamp and refresh time-dependent layers."""
        self._current_timestamp = timestamp
        self.update()

    def update(self):
        """Refresh the time-dependent layers for the current timestamp.

        Everything is drawn first if the plot has not been drawn yet.
        Then the sediment bottom (if any) and the water line and fill are
        refreshed.
        """
        if not self._init:
            self.draw()

        # Without results there is nothing to refresh (draw() returned
        # before creating any artist).
        if self.results is None:
            return

        reach = self.results.river.reach(self._current_reach_id)
        if reach.has_sediment():
            self.update_bottom_with_bedload()

        self.update_water_elevation()

        self.update_idle()

    def update_water_elevation(self):
        """Move the water line and rebuild the water fill."""
        reach = self.results.river.reach(self._current_reach_id)
        kp = reach.geometry.get_kp()
        water_z = self._current_water_z(reach)

        self.water[0].set_data(
            kp, water_z
        )

        # Replace the previous fill with one matching the new water line.
        self.water_fill.remove()
        self._fill_water(kp, water_z)

    def update_current(self):
        """Move the selected-profile segment and request a redraw."""
        reach = self.results.river.reach(self._current_reach_id)
        xs, ys = self._current_profile_segment(reach)

        self.profile.set_data(xs, ys)
        self.canvas.figure.canvas.draw_idle()

    def update_bottom_with_bedload(self):
        """Replace the sediment bottom line with the current-time one.

        The line is rebuilt from the cached self._bedrock.
        """
        reach = self.results.river.reach(self._current_reach_id)
        kp = reach.geometry.get_kp()
        z = self.sl_compute_current_z(reach)

        self.line_bottom.remove()

        self._plot_bottom_line(kp, z)