From 45ee8a96628cc5926ea91240101cdf125a49412b Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr> Date: Wed, 6 Dec 2023 11:05:55 +0100 Subject: [PATCH] Friction: Fix load method in case of multiple reaches. --- src/Model/Friction/Friction.py | 21 ++++++++++++++----- src/Model/Friction/FrictionList.py | 4 +++- src/Model/Geometry/PointXYZ.py | 2 +- src/Model/Geometry/ProfileXYZ.py | 9 +------- src/Model/Geometry/Reach.py | 2 +- .../HydraulicStructures.py | 10 ++++++++- src/Model/River.py | 9 ++++---- 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/Model/Friction/Friction.py b/src/Model/Friction/Friction.py index db1a60db..2e6ec259 100644 --- a/src/Model/Friction/Friction.py +++ b/src/Model/Friction/Friction.py @@ -16,10 +16,13 @@ # -*- coding: utf-8 -*- +import logging + from tools import trace, timer from Model.Tools.PamhyrDB import SQLSubModel +logger = logging.getLogger() class Friction(SQLSubModel): def __init__(self, name: str = "", status=None): @@ -60,7 +63,10 @@ class Friction(SQLSubModel): @classmethod def _db_load(cls, execute, data=None): new = [] - reach = data["parent"] # Reach object + + logger.info(data) + + reach = data["reach"] status = data["status"] stricklers = data["stricklers"].stricklers @@ -69,9 +75,6 @@ class Friction(SQLSubModel): f"FROM friction WHERE reach = {reach.id}" ) - for _ in table: - new.append(None) - for row in table: ind = row[0] # Get stricklers @@ -86,7 +89,11 @@ class Friction(SQLSubModel): sec.begin_strickler = bs sec.end_strickler = es - yield ind, sec + new.append((ind, sec)) + + logger.info(new) + + return new def _db_save(self, execute, data=None): ind = data["ind"] @@ -116,6 +123,10 @@ class Friction(SQLSubModel): def edge(self): return self._edge + @property + def reach(self): + return self._edge + @edge.setter def edge(self, edge): self._edge = edge diff --git a/src/Model/Friction/FrictionList.py b/src/Model/Friction/FrictionList.py index 20c33197..6a5a6c09 100644 --- a/src/Model/Friction/FrictionList.py +++ b/src/Model/Friction/FrictionList.py @@ -50,10 +50,12 @@ class FrictionList(PamhyrModelList): def _db_load(cls, execute, data=None): new = cls(status=data['status']) - new._lst = Friction._db_load( + ilst = Friction._db_load( execute, data ) + new._lst = list(map(lambda x: x[1], sorted(ilst))) + return new def _db_save(self, execute, data=None): diff --git a/src/Model/Geometry/PointXYZ.py b/src/Model/Geometry/PointXYZ.py index 06a6ccf6..57c13a37 100644 --- a/src/Model/Geometry/PointXYZ.py +++ b/src/Model/Geometry/PointXYZ.py @@ -116,7 +116,7 @@ class PointXYZ(Point, SQLSubModel): sl = self._sl.id if self._sl is not None else -1 sql = ( - "INSERT OR REPLACE INTO " + + "INSERT INTO " + "geometry_pointXYZ(ind, name, x, y, z, profile, sl) " + "VALUES (" + f"{ind}, '{self._db_format(self._name)}', " + diff --git a/src/Model/Geometry/ProfileXYZ.py b/src/Model/Geometry/ProfileXYZ.py index ad3ebb5e..495fbac1 100644 --- a/src/Model/Geometry/ProfileXYZ.py +++ b/src/Model/Geometry/ProfileXYZ.py @@ -115,9 +115,6 @@ class ProfileXYZ(Profile, SQLSubModel): f"WHERE reach = {reach.id}" ) - for _ in table: - profiles.append(None) - for row in table: id = row[0] ind = row[1] @@ -132,7 +129,7 @@ class ProfileXYZ(Profile, SQLSubModel): id=id, num=num, name=name, kp=kp, code1=code1, code2=code2, - reach=data["parent"], + reach=reach, status=status ) @@ -151,10 +148,6 @@ class ProfileXYZ(Profile, SQLSubModel): yield ind, new - # profiles[ind] = new - - # return profiles - def _db_save(self, execute, data=None): ok = True ind = data["ind"] diff --git a/src/Model/Geometry/Reach.py b/src/Model/Geometry/Reach.py index 5351a9cb..3c02f16a 100644 --- a/src/Model/Geometry/Reach.py +++ b/src/Model/Geometry/Reach.py @@ -61,7 +61,7 @@ class Reach(SQLSubModel): @classmethod def _db_load(cls, execute, data=None): - new = cls(status=data["status"], parent=data["parent"]) + new = cls(status=data["status"], parent=data["reach"]) new._profiles = ProfileXYZ._db_load( execute, diff --git a/src/Model/HydraulicStructures/HydraulicStructures.py b/src/Model/HydraulicStructures/HydraulicStructures.py index f5fa356c..9234d511 100644 --- a/src/Model/HydraulicStructures/HydraulicStructures.py +++ b/src/Model/HydraulicStructures/HydraulicStructures.py @@ -140,6 +140,14 @@ class HydraulicStructure(SQLSubModel): if self._output_reach is not None: output_reach_id = self._output_reach.id + input_kp = -1 + if self.input_kp is not None: + input_kp = self.input_kp + + output_kp = -1 + if self.output_kp is not None: + output_kp = self.output_kp + sql = ( "INSERT INTO " + "hydraulic_structures(" + @@ -149,7 +157,7 @@ class HydraulicStructure(SQLSubModel): "VALUES (" + f"{self.id}, '{self._db_format(self._name)}', " + f"{self._db_format(self.enabled)}, " + - f"{self.input_kp}, {self.output_kp}, " + + f"{input_kp}, {output_kp}, " + f"{input_reach_id}, {output_reach_id}" + ")" ) diff --git a/src/Model/River.py b/src/Model/River.py index d995d2c3..2ba5a54f 100644 --- a/src/Model/River.py +++ b/src/Model/River.py @@ -157,7 +157,9 @@ class RiverReach(Edge, SQLSubModel): data = {} table = execute( - "SELECT id, name, enable, node1, node2 FROM river_reach") + "SELECT id, name, enable, node1, node2 FROM river_reach" + ) + for row in table: # Update id counter cls._id_cnt = max(cls._id_cnt, row[0]) @@ -172,10 +174,9 @@ class RiverReach(Edge, SQLSubModel): new = cls(id, name, node1, node2, status=data["status"]) new.enable(enable=enable) - data["reach"] = id - data["parent"] = new - new._reach = Reach._db_load(execute, data) + data["reach"] = new + new._reach = Reach._db_load(execute, data) new._frictions = FrictionList._db_load(execute, data) reachs.append(new) -- GitLab