From 021971816a1410a1e1e8fb4640ee2d7283e0b0b7 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr>
Date: Fri, 20 Oct 2023 15:09:24 +0200
Subject: [PATCH] Results: CustomPlot: Implement plot with time on y axes.

---
 src/View/Results/CustomPlot/Plot.py | 243 +++++++++++++++++++---------
 src/View/Results/Window.py          |  17 +-
 2 files changed, 177 insertions(+), 83 deletions(-)

diff --git a/src/View/Results/CustomPlot/Plot.py b/src/View/Results/CustomPlot/Plot.py
index 6bfe6aba..a539112a 100644
--- a/src/View/Results/CustomPlot/Plot.py
+++ b/src/View/Results/CustomPlot/Plot.py
@@ -68,109 +68,184 @@ class CustomPlot(PamhyrPlot):
             )
         )
 
-    @timer
-    def draw(self):
-        self.canvas.axes.cla()
-        self.canvas.axes.grid(color='grey', linestyle='--', linewidth=0.5)
+        self._axes = {}
 
-        if self.data is None:
-            return
+    def _draw_kp(self):
+        results = self.data
+        reach = results.river.reach(self._reach)
+        kp = reach.geometry.get_kp()
+        z_min = reach.geometry.get_z_min()
 
-        self.canvas.axes.set_xlabel(
-            self._trad[self._x],
-            color='green', fontsize=12
+        self.canvas.axes.set_xlim(
+            left=min(kp), right=max(kp)
         )
 
-        self.canvas.axes.set_ylabel(
-            self._trad[self._y_axes[0]],
-            color='green', fontsize=12
-        )
+        meter_axes = self.canvas.axes
+        m3S_axes = self.canvas.axes
+        if "0-meter" in self._y_axes and "1-m3s" in self._y_axes:
+            m3s_axes = self._axes["1-m3s"]
 
-        self._axes = {}
-        for axes in self._y_axes[1:]:
-            ax_new = self.canvas.axes.twinx()
-            ax_new.set_ylabel(
-                self._trad[axes],
-                color='green', fontsize=12
+        if "elevation" in self._y:
+            meter_axes.set_ylim(
+                bottom=min(0, min(z_min)),
+                top=max(z_min) + 1
             )
-            self._axes[axes] = ax_new
 
-        if self._x == "kp":
-            results = self.data
-            reach = results.river.reach(self._reach)
-            kp = reach.geometry.get_kp()
-            z_min = reach.geometry.get_z_min()
+            meter_axes.plot(
+                kp, z_min,
+                color='grey', lw=1.
+            )
 
-            self.canvas.axes.set_xlim(
-                left=min(kp), right=max(kp)
+        if "water_elevation" in self._y:
+            # Water elevation
+            water_z = list(
+                map(
+                    lambda p: p.get_ts_key(self._timestamp, "Z"),
+                    reach.profiles
+                )
+            )
+
+            meter_axes.set_ylim(
+                bottom=min(0, min(z_min)),
+                top=max(water_z) + 1
             )
 
-            meter_axes = self.canvas.axes
-            m3S_axes = self.canvas.axes
-            if "0-meter" in self._y_axes and "1-m3s" in self._y_axes:
-                m3s_axes = self._axes["1-m3s"]
+            meter_axes.plot(
+                kp, water_z, lw=1.,
+                color='b',
+            )
 
             if "elevation" in self._y:
-                meter_axes.set_ylim(
-                    bottom=min(0, min(z_min)),
-                    top=max(z_min) + 1
+                meter_axes.fill_between(
+                    kp, z_min, water_z,
+                    color='blue', alpha=0.5, interpolate=True
                 )
 
-                meter_axes.plot(
-                    kp, z_min,
-                    color='grey', lw=1.
+        if "discharge" in self._y:
+            q = list(
+                map(
+                    lambda p: p.get_ts_key(self._timestamp, "Q"),
+                    reach.profiles
                 )
+            )
 
-            if "water_elevation" in self._y:
-                # Water elevation
-                water_z = list(
-                    map(
-                        lambda p: p.get_ts_key(self._timestamp, "Z"),
-                        reach.profiles
-                    )
-                )
+            m3s_axes.set_ylim(
+                bottom=min(0, min(q)),
+                top=max(q) + 1
+            )
 
-                meter_axes.set_ylim(
-                    bottom=min(0, min(z_min)),
-                    top=max(water_z) + 1
-                )
+            m3s_axes.plot(
+                kp, q, lw=1.,
+                color='r',
+            )
+
+    def _draw_time(self):
+        results = self.data
+        reach = results.river.reach(self._reach)
+        profile = reach.profile(self._profile)
+
+        meter_axes = self.canvas.axes
+        m3S_axes = self.canvas.axes
+        if "0-meter" in self._y_axes and "1-m3s" in self._y_axes:
+            m3s_axes = self._axes["1-m3s"]
+
+        ts = list(results.get("timestamps"))
+        ts.sort()
+
+        self.canvas.axes.set_xlim(
+            left=min(ts), right=max(ts)
+        )
 
-                meter_axes.plot(
-                    kp, water_z, lw=1.,
-                    color='b',
+        x = ts
+        if "elevation" in self._y:
+            # Z min is constant in time
+            z_min = profile.geometry.z_min()
+            ts_z_min = list(
+                map(
+                    lambda ts:  z_min,
+                    ts
                 )
+            )
 
-                if "elevation" in self._y:
-                    meter_axes.fill_between(
-                        kp, z_min, water_z,
-                        color='blue', alpha=0.5, interpolate=True
-                    )
+            meter_axes.plot(
+                ts, ts_z_min,
+                color='grey', lw=1.
+            )
+
+        if "water_elevation" in self._y:
+            # Water elevation
+            z = profile.get_key("Z")
 
-            if "discharge" in self._y:
-                q = list(
+            meter_axes.set_ylim(
+                bottom=min(0, min(z)),
+                top=max(z) + 1
+            )
+
+            meter_axes.plot(
+                ts, z, lw=1.,
+                color='b',
+            )
+
+            if "elevation" in self._y:
+                z_min = profile.geometry.z_min()
+                ts_z_min = list(
                     map(
-                        lambda p: p.get_ts_key(self._timestamp, "Q"),
-                        reach.profiles
+                        lambda ts:  z_min,
+                        ts
                     )
                 )
 
-                m3s_axes.set_ylim(
-                    bottom=min(0, min(q)),
-                    top=max(q) + 1
+                meter_axes.fill_between(
+                    ts, ts_z_min, z,
+                    color='blue', alpha=0.5, interpolate=True
                 )
 
-                m3s_axes.plot(
-                    kp, q, lw=1.,
-                    color='r',
-                )
+        if "discharge" in self._y:
+            q = profile.get_key("Q")
+
+            m3s_axes.set_ylim(
+                bottom=min(0, min(q)),
+                top=max(q) + 1
+            )
+
+            m3s_axes.plot(
+                ts, q, lw=1.,
+                color='r',
+            )
+
+    @timer
+    def draw(self):
+        self.canvas.axes.cla()
+        self.canvas.axes.grid(color='grey', linestyle='--', linewidth=0.5)
+
+        if self.data is None:
+            return
+
+        self.canvas.axes.set_xlabel(
+            self._trad[self._x],
+            color='green', fontsize=12
+        )
 
+        self.canvas.axes.set_ylabel(
+            self._trad[self._y_axes[0]],
+            color='green', fontsize=12
+        )
+
+        for axes in self._y_axes[1:]:
+            if axes in self._axes:
+                continue
+
+            ax_new = self.canvas.axes.twinx()
+            ax_new.set_ylabel(
+                self._trad[axes],
+                color='green', fontsize=12
+            )
+            self._axes[axes] = ax_new
+
+        if self._x == "kp":
+            self._draw_kp()
         elif self._x == "time":
-            if "elevation" in self._y:
-                logging.info("TODO: time/elevation")
-            if "water_elevation" in self._y:
-                logging.info("TODO: time/water_elevation")
-            if "discharge" in self._y:
-                logging.info("TODO: time/discharge")
+            self._draw_time()
 
         self.canvas.figure.tight_layout()
         self.canvas.figure.canvas.draw_idle()
@@ -178,7 +253,25 @@ class CustomPlot(PamhyrPlot):
             self.toolbar.update()
 
     @timer
-    def update(self, reach, profile, timestamp):
+    def update(self):
         if not self._init:
             self.draw()
             return
+
+    def set_reach(self, reach_id):
+        self._reach = reach_id
+        self._profile = 0
+
+        self.update()
+
+    def set_profile(self, profile_id):
+        self._profile = profile_id
+
+        if self._x != "kp":
+            self.update()
+
+    def set_timestamp(self, timestamp):
+        self._timestamp = timestamp
+
+        if self._x != "time":
+            self.update()
diff --git a/src/View/Results/Window.py b/src/View/Results/Window.py
index ca53d59d..eb6f7aaa 100644
--- a/src/View/Results/Window.py
+++ b/src/View/Results/Window.py
@@ -345,6 +345,9 @@ class ResultsWindow(PamhyrWindow):
                 self.plot_sed_reach.set_reach(reach_id)
                 self.plot_sed_profile.set_reach(reach_id)
 
+            for plot in self._additional_plot:
+                self._additional_plot[plot].set_reach(reach_id)
+
             self.update_table_selection_reach(reach_id)
             self.update_table_selection_profile(0)
 
@@ -358,7 +361,11 @@ class ResultsWindow(PamhyrWindow):
                 self.plot_sed_reach.set_profile(profile_id)
                 self.plot_sed_profile.set_profile(profile_id)
 
+            for plot in self._additional_plot:
+                self._additional_plot[plot].set_profile(profile_id)
+
             self.update_table_selection_profile(profile_id)
+
         if timestamp is not None:
             self.plot_xy.set_timestamp(timestamp)
             self.plot_ac.set_timestamp(timestamp)
@@ -369,14 +376,8 @@ class ResultsWindow(PamhyrWindow):
                 self.plot_sed_reach.set_timestamp(timestamp)
                 self.plot_sed_profile.set_timestamp(timestamp)
 
-        self.plot_xy.draw()
-        self.plot_ac.draw()
-        self.plot_kpc.draw()
-        self.plot_h.draw()
-
-        if self._study.river.has_sediment():
-            self.plot_sed_reach.draw()
-            self.plot_sed_profile.draw()
+            for plot in self._additional_plot:
+                self._additional_plot[plot].set_timestamp(timestamp)
 
         self.update_statusbar()
 
-- 
GitLab