Plot.py 9.15 KiB
# Plot.py -- Pamhyr
# Copyright (C) 2023-2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import logging

from functools import reduce
from datetime import datetime, timedelta

from tools import timer
from View.Tools.PamhyrPlot import PamhyrPlot

from View.Results.CustomPlot.Translate import CustomPlotTranslate

# Module logger: this is the root logger (getLogger() without a name).
logger = logging.getLogger()

# Maps each plottable Y variable to a unit key. Unit keys are sorted by
# CustomPlot, so the numeric prefix decides which unit is drawn on the main
# Y axis (the first one) and which ones get a twin Y axis.
unit = dict(
    elevation="0-meter",
    water_elevation="0-meter",
    discharge="1-m3s",
)


class CustomPlot(PamhyrPlot):
    """Results plot whose X and Y variables are chosen by the user.

    The X axis is either ``"kp"`` (values along the selected reach at the
    selected timestamp) or ``"time"`` (values over every result timestamp
    at the selected profile). Each Y variable is mapped to a unit through
    the module-level ``unit`` table. The first unit (in sorted order) uses
    the main Y axis; every other unit gets its own twin Y axis.
    """

    def __init__(self, x, y, reach, profile, timestamp,
                 data=None, canvas=None, trad=None,
                 toolbar=None, parent=None):
        """Create the plot.

        Args:
            x: X axis variable, ``"kp"`` or ``"time"``.
            y: Collection of Y variable names, each a key of ``unit``.
            reach: Reach identifier passed to ``results.river.reach``.
            profile: Profile identifier passed to ``reach.profile``.
            timestamp: Result timestamp drawn by the ``"kp"`` plot.
            data: Results object; nothing is drawn while it is None.
            canvas, toolbar, parent: Forwarded to ``PamhyrPlot``.
            trad: Currently ignored, see the note below.

        Raises:
            KeyError: If a Y variable has no entry in ``unit``.
        """
        super(CustomPlot, self).__init__(
            canvas=canvas,
            # NOTE(review): the ``trad`` argument is ignored and a
            # CustomPlotTranslate is always used - confirm this is intended.
            trad=CustomPlotTranslate(),
            data=data,
            toolbar=toolbar,
            parent=parent
        )

        self._x = x
        self._y = y
        self._reach = reach
        self._profile = profile
        self._timestamp = timestamp

        logger.debug(
            "Create custom plot for: " +
            f"{x} -> {','.join(y)}: " +
            f"reach={reach}, profile={profile}, " +
            f"timestamp={timestamp}"
        )

        # Distinct units of the requested variables; index 0 is the main
        # Y axis, the others are twin axes created by draw().
        self._y_axes = sorted({unit[name] for name in self._y})

        # Unit key -> twin matplotlib axes, filled lazily by draw()
        self._axes = {}

    def _get_unit_axes(self):
        """Return the ``(meter_axes, m3s_axes)`` pair to draw curves on.

        Both are the main canvas axes, unless meters and m3/s are plotted
        together; then discharge goes on the "1-m3s" twin axes, which
        draw() creates before calling the drawing methods.
        """
        meter_axes = self.canvas.axes
        m3s_axes = self.canvas.axes
        if "0-meter" in self._y_axes and "1-m3s" in self._y_axes:
            m3s_axes = self._axes["1-m3s"]
        return meter_axes, m3s_axes

    def _draw_lines_legend(self, lines):
        """Draw the legend on the main axes.

        Args:
            lines: Maps each plotted Y variable name to the list returned
                by ``Axes.plot`` for it.
        """
        handles = [
            handle
            for name in lines
            for handle in lines[name]
        ]
        labels = [self._trad[name] for name in lines]
        self.canvas.axes.legend(handles, labels, loc="lower left")

    def _draw_kp(self):
        """Plot the selected variables along the reach at the timestamp."""
        results = self.data
        reach = results.river.reach(self._reach)
        kp = reach.geometry.get_kp()
        z_min = reach.geometry.get_z_min()

        meter_axes, m3s_axes = self._get_unit_axes()

        lines = {}
        if "elevation" in self._y:
            lines["elevation"] = meter_axes.plot(
                kp, z_min,
                color='grey', lw=1.,
            )

        if "water_elevation" in self._y:
            water_z = [
                p.get_ts_key(self._timestamp, "Z")
                for p in reach.profiles
            ]
            lines["water_elevation"] = meter_axes.plot(
                kp, water_z, lw=1.,
                color='blue',
            )

            if "elevation" in self._y:
                # Fill between the z_min curve and the water elevation
                meter_axes.fill_between(
                    kp, z_min, water_z,
                    color='blue', alpha=0.5, interpolate=True
                )

        if "discharge" in self._y:
            q = [
                p.get_ts_key(self._timestamp, "Q")
                for p in reach.profiles
            ]
            lines["discharge"] = m3s_axes.plot(
                kp, q, lw=1.,
                color='r',
            )

        self._draw_lines_legend(lines)

    def _customize_x_axes_time(self, ts, mode="time"):
        """Label about five evenly spaced ticks on the time X axis.

        Args:
            ts: Sequence of timestamps in seconds (ticks follow its order).
            mode: ``"time"`` labels ticks with the elapsed duration (only
                the days part once it reaches one day); any other value
                labels them with the local calendar date.
        """
        nb = len(ts)
        # Keep one timestamp every ``step``; fall back to nb for < 5 values
        step = nb // 5 or nb
        ticks = [t for i, t in enumerate(ts) if i % step == 0]

        if mode == "time":
            # Build the duration straight from the seconds count: the
            # difference of two naive local datetime.fromtimestamp values
            # is off by one hour when a DST change lies between them.
            labels = [
                str(timedelta(seconds=float(t))).split(",")[0]
                .replace("days", self._trad["days"])
                .replace("day", self._trad["day"])
                for t in ticks
            ]
        else:
            labels = [str(datetime.fromtimestamp(t).date()) for t in ticks]

        self.canvas.axes.set_xticks(ticks=ticks, labels=labels, rotation=45)

    def _draw_time(self):
        """Plot the selected variables over time at the selected profile."""
        results = self.data
        reach = results.river.reach(self._reach)
        profile = reach.profile(self._profile)

        meter_axes, m3s_axes = self._get_unit_axes()

        ts = sorted(results.get("timestamps"))
        # NOTE(review): assumes profile.get_key() returns values in sorted
        # timestamp order, matching ``ts`` - confirm.

        lines = {}
        ts_z_min = None
        if "elevation" in self._y:
            # Z min is constant in time: repeat it once per timestamp
            ts_z_min = [profile.geometry.z_min()] * len(ts)
            lines["elevation"] = meter_axes.plot(
                ts, ts_z_min,
                color='grey', lw=1.
            )

        if "water_elevation" in self._y:
            z = profile.get_key("Z")
            lines["water_elevation"] = meter_axes.plot(
                ts, z, lw=1.,
                color='b',
            )

            if ts_z_min is not None:
                # Fill between the z_min curve and the water elevation
                meter_axes.fill_between(
                    ts, ts_z_min, z,
                    color='blue', alpha=0.5, interpolate=True
                )

        if "discharge" in self._y:
            q = profile.get_key("Q")
            lines["discharge"] = m3s_axes.plot(
                ts, q, lw=1.,
                color='r',
            )

        self._customize_x_axes_time(ts)
        self._draw_lines_legend(lines)

    @timer
    def draw(self):
        """Clear the canvas and redraw the plot from ``self.data``."""
        self.canvas.axes.cla()
        self.canvas.axes.grid(color='grey', linestyle='--', linewidth=0.5)

        if self.data is None:
            return

        self.canvas.axes.set_xlabel(
            self._trad[self._x],
            color='black', fontsize=10
        )

        # Guard against an empty Y selection (no unit to label)
        if self._y_axes:
            self.canvas.axes.set_ylabel(
                self._trad[self._y_axes[0]],
                color='black', fontsize=10
            )

        # Each extra unit gets a twin Y axis, created once and then only
        # cleared on later redraws.
        for axes in self._y_axes[1:]:
            ax = self._axes.get(axes)
            if ax is None:
                ax = self.canvas.axes.twinx()
                self._axes[axes] = ax
            else:
                ax.clear()

            # clear() also resets the axis label, so set it on every draw
            ax.set_ylabel(
                self._trad[axes],
                color='black', fontsize=10
            )

        if self._x == "kp":
            self._draw_kp()
        elif self._x == "time":
            self._draw_time()

        self.canvas.figure.tight_layout()
        self.canvas.figure.canvas.draw_idle()
        if self.toolbar is not None:
            self.toolbar.update()

    @timer
    def update(self):
        """Redraw the plot when ``self._init`` is false, else do nothing.

        ``_init`` is not set in this class (presumably by PamhyrPlot).
        """
        if not self._init:
            self.draw()
            return

    def set_reach(self, reach_id):
        """Select another reach, reset the profile to 0 and refresh."""
        self._reach = reach_id
        self._profile = 0

        self.update()

    def set_profile(self, profile_id):
        """Select another profile; the kp plot does not use it."""
        self._profile = profile_id

        if self._x != "kp":
            self.update()

    def set_timestamp(self, timestamp):
        """Select another timestamp; the time plot does not use it."""
        self._timestamp = timestamp

        if self._x != "time":
            self.update()