From 4d95e42d208e892ca3a08978547304f4d586dcbc Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr>
Date: Mon, 8 Jan 2024 17:11:10 +0100
Subject: [PATCH] Geometry: Make import button an undo command.

---
 src/Model/Geometry/Reach.py      | 12 ++++++---
 src/View/Geometry/Table.py       | 14 ++++++++++
 src/View/Geometry/UndoCommand.py | 24 +++++++++++++++++
 src/View/Geometry/Window.py      | 45 ++++++++++++++++----------------
 4 files changed, 68 insertions(+), 27 deletions(-)

diff --git a/src/Model/Geometry/Reach.py b/src/Model/Geometry/Reach.py
index 663f3be0..8b61b88a 100644
--- a/src/Model/Geometry/Reach.py
+++ b/src/Model/Geometry/Reach.py
@@ -487,8 +487,7 @@ class Reach(SQLSubModel):
         Returns:
             Nothing.
         """
-        list_profile = []
-        list_header = []
+        imported_profiles = []
 
         try:
             list_profile, list_header = self.read_file_st(str(file_path_name))
@@ -505,9 +504,8 @@ class Reach(SQLSubModel):
                         **d, reach=self, status=self._status
                     )
                     prof.import_points(profile)
-                    self.profiles.append(prof)
-                    self._update_profile_numbers()
 
+                    imported_profiles.append(prof)
                 self._status.modified()
         except FileNotFoundError as e:
             logger.error(e)
@@ -515,6 +513,12 @@ class Reach(SQLSubModel):
         except FileFormatError as e:
             logger.error(e)
             e.alert()
+        finally:
+            self.profiles = imported_profiles + self.profiles
+
+            self._update_profile_numbers()
+            return imported_profiles
+
 
     @timer
     def read_file_st(self, filename):
diff --git a/src/View/Geometry/Table.py b/src/View/Geometry/Table.py
index 31970968..a256291d 100644
--- a/src/View/Geometry/Table.py
+++ b/src/View/Geometry/Table.py
@@ -185,6 +185,20 @@ class GeometryReachTableModel(PamhyrTableModel):
         self.endMoveRows()
         self.layoutChanged.emit()
 
+    def import_geometry(self, row, filename):
+        self.layoutAboutToBeChanged.emit()
+
+        self._undo.push(
+            ImportCommand(
+                self._data, row,
+                filename
+            )
+        )
+
+        self.layoutAboutToBeChanged.emit()
+        self.layoutChanged.emit()
+
+
     def duplicate(self, rows, profiles):
         self.layoutAboutToBeChanged.emit()
 
diff --git a/src/View/Geometry/UndoCommand.py b/src/View/Geometry/UndoCommand.py
index c62ae7dc..8f5a0cbd 100644
--- a/src/View/Geometry/UndoCommand.py
+++ b/src/View/Geometry/UndoCommand.py
@@ -16,6 +16,8 @@
 
 # -*- coding: utf-8 -*-
 
+import logging
+
 from copy import deepcopy
 from tools import trace, timer
 
@@ -27,6 +29,8 @@ from Model.Geometry import Reach
 
 from Meshing.Mage import MeshingWithMage
 
+logger = logging.getLogger()
+
 
 class SetDataCommand(QUndoCommand):
     def __init__(self, reach, index, old_value, new_value):
@@ -190,6 +194,26 @@ class DuplicateCommand(QUndoCommand):
         for profile in self._profiles:
             self._reach.insert_profile(self._rows[0], profile)
 
+class ImportCommand(QUndoCommand):
+    def __init__(self, reach, row, filename):
+        QUndoCommand.__init__(self)
+
+        self._reach = reach
+        self._row = row
+        self._filename = filename
+        self._profiles = None
+
+    def undo(self):
+        self._reach.delete_profiles(self._profiles)
+
+    def redo(self):
+        if self._profiles is None:
+            self._profiles = self._reach.import_geometry(self._filename)
+            self._profiles.reverse()
+        else:
+            for profile in self._profiles:
+                self._reach.insert_profile(self._row, profile)
+
 
 class MeshingCommand(QUndoCommand):
     def __init__(self, reach, mesher, step):
diff --git a/src/View/Geometry/Window.py b/src/View/Geometry/Window.py
index 10a5e6b5..20d92b3b 100644
--- a/src/View/Geometry/Window.py
+++ b/src/View/Geometry/Window.py
@@ -81,7 +81,7 @@ class GeometryWindow(PamhyrWindow):
         # Add reach to hash computation data
         self._hash_data.append(self._reach)
 
-        self._tablemodel = None
+        self._table = None
         self._profile_window = []
 
         self.setup_table()
@@ -94,14 +94,14 @@ class GeometryWindow(PamhyrWindow):
         table_headers = self._trad.get_dict("table_headers")
 
         table = self.find(QTableView, "tableView")
-        self._tablemodel = GeometryReachTableModel(
+        self._table = GeometryReachTableModel(
             table_view=table,
             table_headers=table_headers,
             editable_headers=["name", "kp"],
             data=self._reach,
             undo=self._undo_stack
         )
-        table.setModel(self._tablemodel)
+        table.setModel(self._table)
         table.setSelectionBehavior(QAbstractItemView.SelectRows)
         table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
         table.setAlternatingRowColors(True)
@@ -214,8 +214,7 @@ class GeometryWindow(PamhyrWindow):
 
         if filename != "":
             size = os.stat(filename).st_size
-            self._reach.import_geometry(filename)
-            self._tablemodel.layoutChanged.emit()
+            self._table.import_geometry(0, filename)
 
             self.update_profile_windows()
             self.plot_xy()
@@ -251,7 +250,7 @@ class GeometryWindow(PamhyrWindow):
 
     def edit_meshing(self):
         mesher = MeshingWithMage()
-        self._tablemodel.meshing(mesher, -1)
+        self._table.meshing(mesher, -1)
 
         self.update_profile_windows()
         self.plot_xy()
@@ -370,9 +369,9 @@ class GeometryWindow(PamhyrWindow):
     def changed_slider_value(self):
         self.tableView.model().blockSignals(True)
 
-        if self._tablemodel.rowCount() != 0:
+        if self._table.rowCount() != 0:
             slider = self.find(QSlider, "verticalSlider")
-            slider.setMaximum(self._tablemodel.rowCount() - 1)
+            slider.setMaximum(self._table.rowCount() - 1)
 
             slider_value = slider.value()
             kp = self._reach.profile(slider_value).kp
@@ -386,20 +385,20 @@ class GeometryWindow(PamhyrWindow):
 
     def increment_value_slider(self):
         slider = self.find(QSlider, "verticalSlider")
-        if 0 <= slider.value() < self._tablemodel.rowCount() - 1:
+        if 0 <= slider.value() < self._table.rowCount() - 1:
             slider.setValue(slider.value() + 1)
 
     def decrement_value_slider(self):
         slider = self.find(QSlider, "verticalSlider")
-        if 0 < slider.value() < self._tablemodel.rowCount():
+        if 0 < slider.value() < self._table.rowCount():
             slider.setValue(slider.value() - 1)
 
     def add(self):
         if len(self.tableView.selectedIndexes()) == 0:
-            self._tablemodel.add(self._tablemodel.rowCount())
+            self._table.add(self._table.rowCount())
         else:
             row = self.index_selected_row()
-            self._tablemodel.add(row + 1)
+            self._table.add(row + 1)
 
     def delete(self):
         rows = sorted(
@@ -411,7 +410,7 @@ class GeometryWindow(PamhyrWindow):
         )
 
         if len(rows) > 0:
-            self._tablemodel.delete(rows)
+            self._table.delete(rows)
 
         self.update_plot_xy()
         self.select_current_profile()
@@ -426,24 +425,24 @@ class GeometryWindow(PamhyrWindow):
                    .row()
 
     def sort_ascending(self):
-        self._tablemodel.sort_profiles(False)
+        self._table.sort_profiles(False)
         self.select_current_profile()
         self.changed_slider_value()
 
     def sort_descending(self):
-        self._tablemodel.sort_profiles(True)
+        self._table.sort_profiles(True)
 
         self.select_current_profile()
         self.changed_slider_value()
 
     def move_up(self):
         row = self.index_selected_row()
-        self._tablemodel.move_up(row)
+        self._table.move_up(row)
         self.select_current_profile()
 
     def move_down(self):
         row = self.index_selected_row()
-        self._tablemodel.move_down(row)
+        self._table.move_down(row)
         self.select_current_profile()
 
     def duplicate(self):
@@ -461,7 +460,7 @@ class GeometryWindow(PamhyrWindow):
         if len(profiles) == 0:
             return
 
-        self._tablemodel.duplicate(rows, profiles)
+        self._table.duplicate(rows, profiles)
         self.select_current_profile()
 
     def _copy(self):
@@ -507,18 +506,18 @@ class GeometryWindow(PamhyrWindow):
             row.append(self._study.river._status)
 
         row = self.index_selected_row()
-        # self._tablemodel.paste(row, header, data)
-        self._tablemodel.paste(row, [], data)
+        # self._table.paste(row, header, data)
+        self._table.paste(row, [], data)
         self.select_current_profile()
 
     def _undo(self):
-        self._tablemodel.undo()
+        self._table.undo()
         self.select_current_profile()
         self.update_plot_xy()
         self.update_plot_kpc()
 
     def _redo(self):
-        self._tablemodel.redo()
+        self._table.redo()
         self.select_current_profile()
         self.update_plot_xy()
         self.update_plot_kpc()
@@ -541,4 +540,4 @@ class GeometryWindow(PamhyrWindow):
         current_dir = os.path.split(filename)[0] or DEFAULT_DIRECTORY
 
         if filename != '':
-            self._tablemodel.export_reach(filename)
+            self._table.export_reach(filename)
-- 
GitLab