From 45ee8a96628cc5926ea91240101cdf125a49412b Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr>
Date: Wed, 6 Dec 2023 11:05:55 +0100
Subject: [PATCH] Friction: Fix load method in case of multiple reaches.

---
 src/Model/Friction/Friction.py                | 21 ++++++++++++++-----
 src/Model/Friction/FrictionList.py            |  4 +++-
 src/Model/Geometry/PointXYZ.py                |  2 +-
 src/Model/Geometry/ProfileXYZ.py              |  9 +-------
 src/Model/Geometry/Reach.py                   |  2 +-
 .../HydraulicStructures.py                    | 10 ++++++++-
 src/Model/River.py                            |  9 ++++----
 7 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/src/Model/Friction/Friction.py b/src/Model/Friction/Friction.py
index db1a60db..2e6ec259 100644
--- a/src/Model/Friction/Friction.py
+++ b/src/Model/Friction/Friction.py
@@ -16,10 +16,13 @@
 
 # -*- coding: utf-8 -*-
 
+import logging
+
 from tools import trace, timer
 
 from Model.Tools.PamhyrDB import SQLSubModel
 
+logger = logging.getLogger()
 
 class Friction(SQLSubModel):
     def __init__(self, name: str = "", status=None):
@@ -60,7 +63,10 @@ class Friction(SQLSubModel):
     @classmethod
     def _db_load(cls, execute, data=None):
         new = []
-        reach = data["parent"]  # Reach object
+
+        logger.info(data)
+
+        reach = data["reach"]
         status = data["status"]
         stricklers = data["stricklers"].stricklers
 
@@ -69,9 +75,6 @@ class Friction(SQLSubModel):
             f"FROM friction WHERE reach = {reach.id}"
         )
 
-        for _ in table:
-            new.append(None)
-
         for row in table:
             ind = row[0]
             # Get stricklers
@@ -86,7 +89,11 @@ class Friction(SQLSubModel):
             sec.begin_strickler = bs
             sec.end_strickler = es
 
-            yield ind, sec
+            new.append((ind, sec))
+
+        logger.info(new)
+
+        return new
 
     def _db_save(self, execute, data=None):
         ind = data["ind"]
@@ -116,6 +123,10 @@ class Friction(SQLSubModel):
     def edge(self):
         return self._edge
 
+    @property
+    def reach(self):
+        return self._edge
+
     @edge.setter
     def edge(self, edge):
         self._edge = edge
diff --git a/src/Model/Friction/FrictionList.py b/src/Model/Friction/FrictionList.py
index 20c33197..6a5a6c09 100644
--- a/src/Model/Friction/FrictionList.py
+++ b/src/Model/Friction/FrictionList.py
@@ -50,10 +50,12 @@ class FrictionList(PamhyrModelList):
     def _db_load(cls, execute, data=None):
         new = cls(status=data['status'])
 
-        new._lst = Friction._db_load(
+        ilst = Friction._db_load(
             execute, data
         )
 
+        new._lst = list(map(lambda x: x[1], sorted(ilst)))
+
         return new
 
     def _db_save(self, execute, data=None):
diff --git a/src/Model/Geometry/PointXYZ.py b/src/Model/Geometry/PointXYZ.py
index 06a6ccf6..57c13a37 100644
--- a/src/Model/Geometry/PointXYZ.py
+++ b/src/Model/Geometry/PointXYZ.py
@@ -116,7 +116,7 @@ class PointXYZ(Point, SQLSubModel):
         sl = self._sl.id if self._sl is not None else -1
 
         sql = (
-            "INSERT OR REPLACE INTO " +
+            "INSERT INTO " +
             "geometry_pointXYZ(ind, name, x, y, z, profile, sl) " +
             "VALUES (" +
             f"{ind}, '{self._db_format(self._name)}', " +
diff --git a/src/Model/Geometry/ProfileXYZ.py b/src/Model/Geometry/ProfileXYZ.py
index ad3ebb5e..495fbac1 100644
--- a/src/Model/Geometry/ProfileXYZ.py
+++ b/src/Model/Geometry/ProfileXYZ.py
@@ -115,9 +115,6 @@ class ProfileXYZ(Profile, SQLSubModel):
             f"WHERE reach = {reach.id}"
         )
 
-        for _ in table:
-            profiles.append(None)
-
         for row in table:
             id = row[0]
             ind = row[1]
@@ -132,7 +129,7 @@ class ProfileXYZ(Profile, SQLSubModel):
                 id=id, num=num,
                 name=name, kp=kp,
                 code1=code1, code2=code2,
-                reach=data["parent"],
+                reach=reach,
                 status=status
             )
 
@@ -151,10 +148,6 @@ class ProfileXYZ(Profile, SQLSubModel):
 
             yield ind, new
 
-        #     profiles[ind] = new
-
-        # return profiles
-
     def _db_save(self, execute, data=None):
         ok = True
         ind = data["ind"]
diff --git a/src/Model/Geometry/Reach.py b/src/Model/Geometry/Reach.py
index 5351a9cb..3c02f16a 100644
--- a/src/Model/Geometry/Reach.py
+++ b/src/Model/Geometry/Reach.py
@@ -61,7 +61,7 @@ class Reach(SQLSubModel):
 
     @classmethod
     def _db_load(cls, execute, data=None):
-        new = cls(status=data["status"], parent=data["parent"])
+        new = cls(status=data["status"], parent=data["reach"])
 
         new._profiles = ProfileXYZ._db_load(
             execute,
diff --git a/src/Model/HydraulicStructures/HydraulicStructures.py b/src/Model/HydraulicStructures/HydraulicStructures.py
index f5fa356c..9234d511 100644
--- a/src/Model/HydraulicStructures/HydraulicStructures.py
+++ b/src/Model/HydraulicStructures/HydraulicStructures.py
@@ -140,6 +140,14 @@ class HydraulicStructure(SQLSubModel):
         if self._output_reach is not None:
             output_reach_id = self._output_reach.id
 
+        input_kp = -1
+        if self.input_kp is not None:
+            input_kp = self.input_kp
+
+        output_kp = -1
+        if self.output_kp is not None:
+            output_kp = self.output_kp
+
         sql = (
             "INSERT INTO " +
             "hydraulic_structures(" +
@@ -149,7 +157,7 @@ class HydraulicStructure(SQLSubModel):
             "VALUES (" +
             f"{self.id}, '{self._db_format(self._name)}', " +
             f"{self._db_format(self.enabled)}, " +
-            f"{self.input_kp}, {self.output_kp}, " +
+            f"{input_kp}, {output_kp}, " +
             f"{input_reach_id}, {output_reach_id}" +
             ")"
         )
diff --git a/src/Model/River.py b/src/Model/River.py
index d995d2c3..2ba5a54f 100644
--- a/src/Model/River.py
+++ b/src/Model/River.py
@@ -157,7 +157,9 @@ class RiverReach(Edge, SQLSubModel):
             data = {}
 
         table = execute(
-            "SELECT id, name, enable, node1, node2 FROM river_reach")
+            "SELECT id, name, enable, node1, node2 FROM river_reach"
+        )
+
         for row in table:
             # Update id counter
             cls._id_cnt = max(cls._id_cnt, row[0])
@@ -172,10 +174,9 @@ class RiverReach(Edge, SQLSubModel):
             new = cls(id, name, node1, node2, status=data["status"])
             new.enable(enable=enable)
 
-            data["reach"] = id
-            data["parent"] = new
-            new._reach = Reach._db_load(execute, data)
+            data["reach"] = new
 
+            new._reach = Reach._db_load(execute, data)
             new._frictions = FrictionList._db_load(execute, data)
 
             reachs.append(new)
-- 
GitLab