From 150388af12d1847a9281cfb6587c960c75969aac Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr>
Date: Wed, 5 Apr 2023 14:50:55 +0200
Subject: [PATCH] model: river: Keep current reach in memory.

---
 src/Model/Network/Graph.py      | 24 +++++++++++++++++-------
 src/Model/River.py              | 13 +++++++++++++
 src/View/Network/GraphWidget.py |  1 +
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/src/Model/Network/Graph.py b/src/Model/Network/Graph.py
index 8cc6de21..ad9d6f67 100644
--- a/src/Model/Network/Graph.py
+++ b/src/Model/Network/Graph.py
@@ -9,6 +9,9 @@ class Graph(object):
     def __init__(self):
         super(Graph, self).__init__()
 
+        self._node_ctor = Node
+        self._edge_ctor = Edge
+
         self._nodes_ids = 0
         self._edges_ids = 0
 
@@ -55,27 +58,34 @@ class Graph(object):
             )
         )[0]
 
+    def _add_node(self, node):
+        self._nodes.append(node)
+        self._nodes_ids += 1
+        return node
+
     def add_node(self, x:float = 0.0, y:float = 0.0):
-        node = Node(
+        node = self._node_ctor(
             self._nodes_ids,
             f"Node {self._nodes_ids}",
             x = x, y = y
         )
-        self._nodes.append(node)
-        self._nodes_ids += 1
-        return node
+        return self._add_node(node)
 
-    def add_edge(self, n1:Node, n2:Node):
+    def _add_edge(self, edge):
         # This edge already exists ?
-        if any(filter(lambda e: (e.node1 == n1 and e.node2 == n2),
+        if any(filter(lambda e: (e.node1 == edge.node1 and
+                                 e.node2 == edge.node2),
                       self._edges)):
             return None
 
-        edge = Edge(self._edges_ids, "", n1, n2)
         self._edges.append(edge)
         self._edges_ids += 1
         return edge
 
+    def add_edge(self, n1:Node, n2:Node):
+        edge = self._edge_ctor(self._edges_ids, "", n1, n2)
+        return self._add_edge(edge)
+
     def remove_node(self, node_name:str):
         self._nodes = list(
             filter(
diff --git a/src/Model/River.py b/src/Model/River.py
index 4094bc91..b38470c1 100644
--- a/src/Model/River.py
+++ b/src/Model/River.py
@@ -31,3 +31,16 @@ class RiverReach(Edge):
 class River(Graph):
     def __init__(self):
         super(River, self).__init__()
+
+        # Replace Node and Edge ctor by custom ctor
+        self._node_ctor = RiverNode
+        self._edge_ctor = RiverReach
+
+        self._current_reach = None
+
+
+    def current_reach(self):
+        return self._current_reach
+
+    def current_reach(self, reach):
+        self._current_reach = reach
diff --git a/src/View/Network/GraphWidget.py b/src/View/Network/GraphWidget.py
index 0bd70655..a0d19214 100644
--- a/src/View/Network/GraphWidget.py
+++ b/src/View/Network/GraphWidget.py
@@ -548,6 +548,7 @@ class GraphWidget(QGraphicsView):
         """
         previous_edge = self._current_edge
         self._current_edge = edge
+        self.graph.current_reach(edge.edge)
 
         if previous_edge:
             previous_edge.update()
-- 
GitLab