From 5edcf1507266127e70c8336bb2ec305e502acd7b Mon Sep 17 00:00:00 2001
From: Theophile Terraz <theophile.terraz@inrae.fr>
Date: Fri, 20 Dec 2024 10:05:56 +0100
Subject: [PATCH] Export Adis to csv

---
 src/View/Results/CustomExportAdis.py  | 129 ++++++++++++++++
 src/View/Results/Window.py            |   8 +-
 src/View/Results/WindowAdisTS.py      | 211 +++++++++++++++++++-------
 src/View/ui/CustomExportAdisDialog.ui | 104 +++++++++++++
 4 files changed, 390 insertions(+), 62 deletions(-)
 create mode 100644 src/View/Results/CustomExportAdis.py
 create mode 100644 src/View/ui/CustomExportAdisDialog.ui

diff --git a/src/View/Results/CustomExportAdis.py b/src/View/Results/CustomExportAdis.py
new file mode 100644
index 00000000..d8b845f3
--- /dev/null
+++ b/src/View/Results/CustomExportAdis.py
@@ -0,0 +1,129 @@
+# CustomPlotValuesSelectionDialog.py -- Pamhyr
+# Copyright (C) 2023-2024  INRAE
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+# -*- coding: utf-8 -*-
+
+from View.Tools.PamhyrWindow import PamhyrDialog
+
+from PyQt5.QtWidgets import (
+    QRadioButton, QCheckBox, QVBoxLayout,
+)
+
+from View.Results.translate import ResultsTranslate
+
+
+class CustomExportAdisDialog(PamhyrDialog):
+    _pamhyr_ui = "CustomExportAdisDialog"
+    _pamhyr_name = "Custom Plot Selection"
+
+    def __init__(self, pollutants, parent=None):
+        trad = ResultsTranslate()
+        super(CustomExportAdisDialog, self).__init__(
+            title=trad[self._pamhyr_name],
+            options=[],
+            trad=trad,
+            parent=parent
+        )
+
+        if pollutants is not None:
+            self.pollutants = pollutants
+            if "total_sediment" in self.pollutants:
+                self.pollutants.remove("total_sediment")
+        else:
+            self.pollutants = pollutants
+
+        self._available_values_x = self._trad.get_dict("values_x")
+        self._available_values_y = self._trad.get_dict("values_y_pol")
+
+        self.setup_radio_buttons_x()
+        self.setup_radio_buttons_pol()
+        self.setup_check_boxes()
+
+        self.value = None
+
+    def setup_radio_buttons_x(self):
+        self._radio = []
+        layout = self.find(QVBoxLayout, "verticalLayout_x")
+
+        for value in self._available_values_x:
+            btn = QRadioButton(
+                self._available_values_x[value],
+                parent=self
+            )
+            self._radio.append((value, btn))
+            layout.addWidget(btn)
+
+        self._radio[0][1].setChecked(True)
+        layout.addStretch()
+
+    def setup_radio_buttons_pol(self):
+        self._radio2 = []
+        layout = self.find(QVBoxLayout, "verticalLayout_pol")
+
+        for value in self.pollutants:
+            btn = QRadioButton(
+                value,
+                parent=self
+            )
+            self._radio2.append((value, btn))
+            layout.addWidget(btn)
+
+        self._radio2[0][1].setChecked(True)
+        layout.addStretch()
+
+    def setup_check_boxes(self):
+        self._check = []
+        layout = self.find(QVBoxLayout, "verticalLayout_y")
+
+        for value in self._available_values_y:
+            btn = QCheckBox(
+                self._available_values_y[value],
+                parent=self
+            )
+            self._check.append((value, btn))
+            layout.addWidget(btn)
+
+        self._check[0][1].setChecked(True)
+        layout.addStretch()
+
+    def accept(self):
+        x = next(
+            filter(
+                lambda r: r[1].isChecked(),
+                self._radio
+            )
+        )[0]
+
+        y = list(
+            map(
+                lambda b: b[0],
+                filter(
+                    lambda b: b[1].isChecked(),
+                    self._check
+                )
+            )
+        )
+
+        pol = next(
+            filter(
+                lambda r: r[1].isChecked(),
+                self._radio2
+            )
+        )[0]
+
+        self.value = x, y, pol
+
+        super().accept()
diff --git a/src/View/Results/Window.py b/src/View/Results/Window.py
index 0ddc87ef..13f15138 100644
--- a/src/View/Results/Window.py
+++ b/src/View/Results/Window.py
@@ -614,13 +614,13 @@ class ResultsWindow(PamhyrWindow):
         if x == "rk":
             timestamp = self._get_current_timestamp()
             first_line.append(f"Time: {timestamp}s")
-            val_dict = self._export_rk(timestamp, y, envelop, filename)
+            val_dict = self._export_rk(timestamp, y, envelop)
         elif x == "time":
             profile_id = self._get_current_profile()
             profile = reach.profile(profile_id)
             pname = profile.name if profile.name != "" else profile.rk
             first_line.append(f"Profile: {pname}")
-            val_dict = self._export_time(profile_id, y, filename)
+            val_dict = self._export_time(profile_id, y)
 
         with open(filename, 'w', newline='') as csvfile:
             writer = csv.writer(csvfile, delimiter=',',
@@ -675,7 +675,7 @@ class ResultsWindow(PamhyrWindow):
         self._additional_plot.pop(tab_widget.tabText(index))
         tab_widget.removeTab(index)
 
-    def _export_rk(self, timestamp, y, envelop, filename):
+    def _export_rk(self, timestamp, y, envelop):
         reach = self._results.river.reachs[self._get_current_reach()]
         dict_x = self._trad.get_dict("values_x")
         dict_y = self._trad.get_dict("values_y")
@@ -807,7 +807,7 @@ class ResultsWindow(PamhyrWindow):
 
         return my_dict
 
-    def _export_time(self, profile, y, filename):
+    def _export_time(self, profile, y):
         reach = self._results.river.reachs[self._get_current_reach()]
         profile = reach.profile(profile)
         ts = list(self._results.get("timestamps"))
diff --git a/src/View/Results/WindowAdisTS.py b/src/View/Results/WindowAdisTS.py
index d11cf8d9..28c77932 100644
--- a/src/View/Results/WindowAdisTS.py
+++ b/src/View/Results/WindowAdisTS.py
@@ -49,8 +49,8 @@ from View.Tools.Plot.PamhyrToolbar import PamhyrPlotToolbar
 from View.Results.PlotSedAdis import PlotAdis_dx, PlotAdis_dt
 
 from View.Results.CustomPlot.Plot import CustomPlot
-from View.Results.CustomPlot.CustomPlotValuesSelectionDialog import (
-    CustomPlotValuesSelectionDialog,
+from View.Results.CustomExportAdis import (
+    CustomExportAdisDialog,
 )
 
 from View.Results.TableAdisTS import TableModel
@@ -100,7 +100,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
         self._hash_data.append(self._results)
 
         self._additional_plot = {}
-        self._pol_id = [1]
+        self._current_pol_id = [1]
         self._reach_id = 0
         self._profile_id = 0
 
@@ -114,22 +114,25 @@ class ResultsWindowAdisTS(PamhyrWindow):
             self.setup_connections()
             self.update_table_selection_reach(self._reach_id)
             self.update_table_selection_profile(self._profile_id)
-            self.update_table_selection_pol(self._pol_id)
+            self.update_table_selection_pol(self._current_pol_id)
         except Exception as e:
             logger_exception(e)
             return
 
     def set_type_pol(self):
         self._type_pol = []
+        self._pol_id_dict = {}
         tmp_list = self._results.river.reach(0).profiles
         for pol_index in range(self._results.nb_pollutants):
-            if self._results.pollutants_list[pol_index] == "total_sediment":
+            pol_name = self._results.pollutants_list[pol_index]
+            if pol_name == "total_sediment":
                 self._type_pol.append(-1)
             else:
                 self._type_pol.append(len(
                     tmp_list[0].get_ts_key(
                         self._timestamps[0], "pols")[pol_index])
                 )
+            self._pol_id_dict[pol_name] = pol_index
 
     def setup_table(self):
         self._table = {}
@@ -187,7 +190,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
             results=self._results,
             reach_id=self._reach_id,
             profile_id=self._profile_id,
-            pol_id=self._pol_id,
+            pol_id=self._current_pol_id,
             key="C",
             type_pol=self._type_pol,
             trad=self._trad,
@@ -213,7 +216,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
             results=self._results,
             reach_id=self._reach_id,
             profile_id=self._profile_id,
-            pol_id=self._pol_id,
+            pol_id=self._current_pol_id,
             key="C",
             type_pol=self._type_pol,
             trad=self._trad,
@@ -239,7 +242,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
             results=self._results,
             reach_id=self._reach_id,
             profile_id=self._profile_id,
-            pol_id=self._pol_id,
+            pol_id=self._current_pol_id,
             key="M",
             type_pol=self._type_pol,
             trad=self._trad,
@@ -266,7 +269,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
             results=self._results,
             reach_id=self._reach_id,
             profile_id=self._profile_id,
-            pol_id=self._pol_id,
+            pol_id=self._current_pol_id,
             key="M",
             type_pol=self._type_pol,
             trad=self._trad,
@@ -475,6 +478,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
     def update(self, reach_id=None, profile_id=None,
                pol_id=None, timestamp=None):
         if reach_id is not None:
+            self._reach_id = reach_id
             self.plot_cdt.set_reach(reach_id)
             self.plot_cdx.set_reach(reach_id)
             self.plot_mdx.set_reach(reach_id)
@@ -486,6 +490,7 @@ class ResultsWindowAdisTS(PamhyrWindow):
             self.update_table_selection_profile(0)
 
         if profile_id is not None:
+            self._profile_id = profile_id
             self.plot_cdt.set_profile(profile_id)
             self.plot_cdx.set_profile(profile_id)
             self.plot_mdx.set_profile(profile_id)
@@ -496,11 +501,11 @@ class ResultsWindowAdisTS(PamhyrWindow):
             self.update_table_selection_profile(profile_id)
 
         if pol_id is not None:
-            self._pol_id = [p+1 for p in pol_id]  # remove total_sediment
-            self.plot_cdt.set_pollutant(self._pol_id)
-            self.plot_cdx.set_pollutant(self._pol_id)
-            self.plot_mdx.set_pollutant(self._pol_id)
-            self.plot_mdt.set_pollutant(self._pol_id)
+            self._current_pol_id = [p+1 for p in pol_id]  # remove total_sediment
+            self.plot_cdt.set_pollutant(self._current_pol_id)
+            self.plot_cdx.set_pollutant(self._current_pol_id)
+            self.plot_mdx.set_pollutant(self._current_pol_id)
+            self.plot_mdt.set_pollutant(self._current_pol_id)
 
         if timestamp is not None:
             self.plot_cdt.set_timestamp(timestamp)
@@ -635,53 +640,143 @@ class ResultsWindowAdisTS(PamhyrWindow):
             self._button_play.setIcon(self._icon_start)
 
     def export(self):
-        self.file_dialog(
-            select_file=False,
-            callback=lambda d: self.export_to(d[0])
-        )
-
-    def export_to(self, directory):
-        for reach in self._results.river.reachs:
-            self.export_reach(reach, directory)
 
-    def export_reach(self, reach, directory):
-        name = reach.name
-        name = name.replace(" ", "-")
+        pols = self._results.pollutants_list.copy()
+        dlg = CustomExportAdisDialog(pollutants=pols,
+                                     parent=self)
+        if dlg.exec():
+            x, y, pol = dlg.value
+        else:
+            return
 
-        file_name = os.path.join(
-            directory,
-            f"reach_{name}.csv"
+        self.file_dialog(
+            select_file="AnyFile",
+            callback=lambda f: self.export_to(f[0], x, y, pol),
+            default_suffix=".csv",
+            file_filter=["CSV (*.csv)"],
         )
 
-        with open(file_name, 'w', newline='') as csvfile:
-            writer = csv.writer(csvfile, delimiter=',',
-                                quotechar='|', quoting=csv.QUOTE_MINIMAL)
-            writer.writerow(["name", "rk", "data-file"])
-            for profile in reach.profiles:
-                p_file_name = os.path.join(
-                    directory,
-                    f"cs_{profile.geometry.id}.csv"
-                )
-
-                writer.writerow([
-                    profile.name,
-                    profile.rk,
-                    p_file_name
-                ])
-
-                self.export_profile(reach, profile, p_file_name)
-
-    def export_profile(self, reach, profile, file_name):
-        with open(file_name, 'w', newline='') as csvfile:
+    def export_to(self, filename, x, y, pol):
+        timestamps = sorted(self._results.get("timestamps"))
+        reach = self._results.river.reachs[self._reach_id]
+        first_line = [f"Study: {self._results.study.name}",
+                      f"Reach: {reach.name}"]
+        if x == "rk":
+            timestamp = self._get_current_timestamp()
+            first_line.append(f"Time: {timestamp}s")
+            val_dict = self._export_rk(timestamp, y, pol)
+        elif x == "time":
+            profile = reach.profile(self._profile_id)
+            pname = profile.name if profile.name != "" else profile.rk
+            first_line.append(f"Profile: {pname}")
+            val_dict = self._export_time(self._profile_id, y, pol)
+        with open(filename, 'w', newline='') as csvfile:
             writer = csv.writer(csvfile, delimiter=',',
                                 quotechar='|', quoting=csv.QUOTE_MINIMAL)
-
-            writer.writerow(["timestamp", "z", "q"])
-            timestamps = sorted(self._results.get("timestamps"))
-
-            for ts in timestamps:
-                writer.writerow([
-                    ts,
-                    profile.get_ts_key(ts, "Z"),
-                    profile.get_ts_key(ts, "Q"),
-                ])
+            dict_x = self._trad.get_dict("values_x")
+            header = []
+            writer.writerow(first_line)
+            for text in val_dict.keys():
+                header.append(text)
+            writer.writerow(header)
+            for row in range(len(val_dict[dict_x[x]])):
+                line = []
+                for var in val_dict.keys():
+                    line.append(val_dict[var][row])
+                writer.writerow(line)
+
+    def _export_rk(self, timestamp, y, pol):
+        reach = self._results.river.reachs[self._reach_id]
+        dict_x = self._trad.get_dict("values_x")
+        dict_y = self._trad.get_dict("values_y_pol")
+        my_dict = {}
+        my_dict[dict_x["rk"]] = reach.geometry.get_rk()
+        val_id = {}
+        val_id["unit_C"] = 0
+        val_id["unit_M"] = 2
+        val_id["unit_thickness"] = 2
+        for unit in y:
+            if unit == "unit_thickness":
+                pol_id = 0
+            else:
+                pol_id = self._pol_id_dict[pol]
+            if unit == "unit_M" and self._type_pol[pol_id] == 1:
+                my_dict[dict_y[unit]] = [0.0]*len(reach.profiles)
+            else:
+                my_dict[dict_y[unit]] = list(map(lambda p:
+                                             p.get_ts_key(
+                                             timestamp, "pols"
+                                             )[pol_id][val_id[unit]],
+                                             reach.profiles))
+
+        return my_dict
+
+    def _export_time(self, profile, y, pol):
+        reach = self._results.river.reachs[self._reach_id]
+        profile = reach.profile(profile)
+        ts = list(self._results.get("timestamps"))
+        ts.sort()
+        dict_x = self._trad.get_dict("values_x")
+        dict_y = self._trad.get_dict("values_y_pol")
+        my_dict = {}
+        my_dict[dict_x["time"]] = ts
+        val_id = {}
+        val_id["unit_C"] = 0
+        val_id["unit_M"] = 2
+        val_id["unit_thickness"] = 2
+        for unit in y:
+            if unit == "unit_thickness":
+                pol_id = 0
+            else:
+                pol_id = self._pol_id_dict[pol]
+            if unit == "unit_M" and self._type_pol[pol_id] == 1:
+                my_dict[dict_y[unit]] = [0.0]*len(ts)
+            else:
+                my_dict[dict_y[unit]] = list(map(lambda data_el:
+                                             data_el[pol_id][val_id[unit]],
+                                             profile.get_key("pols")
+                                             ))
+
+        return my_dict
+
+    #def export_reach(self, reach, directory):
+        #name = reach.name
+        #name = name.replace(" ", "-")
+
+        #file_name = os.path.join(
+            #directory,
+            #f"reach_{name}.csv"
+        #)
+
+        #with open(file_name, 'w', newline='') as csvfile:
+            #writer = csv.writer(csvfile, delimiter=',',
+                                #quotechar='|', quoting=csv.QUOTE_MINIMAL)
+            #writer.writerow(["name", "rk", "data-file"])
+            #for profile in reach.profiles:
+                #p_file_name = os.path.join(
+                    #directory,
+                    #f"cs_{profile.geometry.id}.csv"
+                #)
+
+                #writer.writerow([
+                    #profile.name,
+                    #profile.rk,
+                    #p_file_name
+                #])
+
+                #self.export_profile(reach, profile, p_file_name)
+
+    #def export_profile(self, reach, profile, file_name):
+        #with open(file_name, 'w', newline='') as csvfile:
+            #writer = csv.writer(csvfile, delimiter=',',
+                                #quotechar='|', quoting=csv.QUOTE_MINIMAL)
+
+            #writer.writerow(["timestamp", "z", "q"])
+            #timestamps = sorted(self._results.get("timestamps"))
+
+            #for ts in timestamps:
+                #writer.writerow([
+                    #ts,
+                    #profile.get_ts_key(ts, "Z"),
+                    #profile.get_ts_key(ts, "Q"),
+                #])
diff --git a/src/View/ui/CustomExportAdisDialog.ui b/src/View/ui/CustomExportAdisDialog.ui
new file mode 100644
index 00000000..b1c0f419
--- /dev/null
+++ b/src/View/ui/CustomExportAdisDialog.ui
@@ -0,0 +1,104 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<ui version="4.0">
+ <class>Dialog</class>
+ <widget class="QDialog" name="Dialog">
+  <property name="geometry">
+   <rect>
+    <x>0</x>
+    <y>0</y>
+    <width>194</width>
+    <height>70</height>
+   </rect>
+  </property>
+  <property name="windowTitle">
+   <string>Dialog</string>
+  </property>
+  <layout class="QGridLayout" name="gridLayout">
+   <item row="4" column="0">
+    <widget class="QDialogButtonBox" name="buttonBox">
+     <property name="orientation">
+      <enum>Qt::Horizontal</enum>
+     </property>
+     <property name="standardButtons">
+      <set>QDialogButtonBox::Cancel|QDialogButtonBox::Ok</set>
+     </property>
+    </widget>
+   </item>
+   <item row="3" column="0">
+    <widget class="QSplitter" name="splitter">
+     <property name="orientation">
+      <enum>Qt::Horizontal</enum>
+     </property>
+     <widget class="QWidget" name="verticalLayoutWidget">
+      <layout class="QVBoxLayout" name="verticalLayout_x">
+       <item>
+        <widget class="QLabel" name="label">
+         <property name="text">
+          <string>X axis:</string>
+         </property>
+        </widget>
+       </item>
+      </layout>
+     </widget>
+     <widget class="QWidget" name="verticalLayoutWidget_2">
+      <layout class="QVBoxLayout" name="verticalLayout_y">
+       <item>
+        <widget class="QLabel" name="label_2">
+         <property name="text">
+          <string>Y axis:</string>
+         </property>
+        </widget>
+       </item>
+      </layout>
+     </widget>
+     <widget class="QWidget" name="verticalLayoutWidget_3">
+      <layout class="QVBoxLayout" name="verticalLayout_pol">
+       <item>
+        <widget class="QLabel" name="label_3">
+         <property name="text">
+          <string>Pollutant:</string>
+         </property>
+        </widget>
+       </item>
+      </layout>
+     </widget>
+    </widget>
+   </item>
+  </layout>
+ </widget>
+ <resources/>
+ <connections>
+  <connection>
+   <sender>buttonBox</sender>
+   <signal>rejected()</signal>
+   <receiver>Dialog</receiver>
+   <slot>reject()</slot>
+   <hints>
+    <hint type="sourcelabel">
+     <x>316</x>
+     <y>260</y>
+    </hint>
+    <hint type="destinationlabel">
+     <x>286</x>
+     <y>274</y>
+    </hint>
+   </hints>
+  </connection>
+  <connection>
+   <sender>buttonBox</sender>
+   <signal>accepted()</signal>
+   <receiver>Dialog</receiver>
+   <slot>accept()</slot>
+   <hints>
+    <hint type="sourcelabel">
+     <x>248</x>
+     <y>254</y>
+    </hint>
+    <hint type="destinationlabel">
+     <x>157</x>
+     <y>274</y>
+    </hint>
+   </hints>
+  </connection>
+ </connections>
+</ui>
-- 
GitLab