From 4978f20e29befeb8545f4bf0c7a33763d318a715 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr> Date: Thu, 29 Jun 2023 11:33:59 +0200 Subject: [PATCH] Network: Export model network to sqlite db. --- src/Model/DB.py | 71 ++++++++++++++++++++++---------------- src/Model/Network/Edge.py | 11 ++++-- src/Model/Network/Graph.py | 11 ++---- src/Model/Network/Node.py | 31 +++++++++++------ src/Model/River.py | 68 +++++++++++++++++++++++++++++++++--- src/Model/Study.py | 28 +++++++++++---- src/View/MainWindow.py | 12 +++---- 7 files changed, 167 insertions(+), 65 deletions(-) diff --git a/src/Model/DB.py b/src/Model/DB.py index 5206555a..013f688e 100644 --- a/src/Model/DB.py +++ b/src/Model/DB.py @@ -22,54 +22,65 @@ class SQLModel(SQL): self._cur = self._db.cursor() if is_new: + print("CREATE") self._create() # Create db - self._save() # Save + # self._save() # Save else: + print("UPDATE") self._update() # Update db scheme if necessary - self._load() # Load data + # self._load() # Load data + + + def __init__(self, filename = None): + self._db = None def _create_submodel(self): + fn = lambda sql: self.execute( + sql, + fetch_one = False, + commit = True + ) + for cls in self._sub_classes: - requests = cls._sql_create( - lambda sql: self.execute( - sql, - fetch_one = False, - commit = True - ) - ) + requests = cls._sql_create(fn) def _create(self): raise NotImplementedMethodeError(self, self._create) def _update_submodel(self, version): + fn = lambda sql: self.execute( + sql, + fetch_one = False, + commit = True + ) + + ok = True for cls in self._sub_classes: - requests = cls._sql_update( - lambda sql: self.execute( - sql, - fetch_one = False, - commit = True - ), - version - ) + ok &= cls._sql_update(fn, version) + + return ok def _update(self): raise NotImplementedMethodeError(self, self._update) - def _save_submodel(self, objs): + def _save_submodel(self, objs, data = None): + fn = lambda sql: self.execute( + sql, + fetch_one = False, + commit = True + ) + + ok = True for obj in objs: - requests = obj._sql_save( - lambda sql: self.execute( - sql, - fetch_one = False, - commit = True - ) - ) + ok &= obj._sql_save(fn) + + return ok def _save(self): raise NotImplementedMethodeError(self, self._save) @classmethod - def _load(cls, filename): + def _load(cls, filename = None): raise NotImplementedMethodeError(cls, cls._load) # Sub model class @@ -80,6 +91,8 @@ class SQLSubModel(object): # Replace ''' by ''' to preserve SQL injection if type(value) == str: value = value.replace("'", "'") + elif type(value) == bool: + value = 'TRUE' if value else 'FALSE' return value @classmethod @@ -130,9 +143,9 @@ class SQLSubModel(object): """ raise NotImplementedMethodeError(cls, cls._sql_load) - def _save_submodel(self, execute): - for sc in self._sub_classes: - sc._sql_update(execute) + def _save_submodel(self, execute, objs, data = None): + for o in objs: + o._sql_save(execute, data = data) def _sql_save(self, execute, data = None): """Save class data to data base diff --git a/src/Model/Network/Edge.py b/src/Model/Network/Edge.py index b0918a16..f1ced151 100644 --- a/src/Model/Network/Edge.py +++ b/src/Model/Network/Edge.py @@ -3,7 +3,9 @@ from Model.Network.Node import Node class Edge(object): - def __init__(self, id:str, name:str, + _id_cnt = 0 + + def __init__(self, id:int, name:str, node1:Node = None, node2:Node = None, status = None): @@ -11,7 +13,12 @@ class Edge(object): self._status = status - self.id = id + if id == -1: + type(self)._id_cnt += 1 + self.id = type(self)._id_cnt + else: + self.id = id + self._name = name self.node1 = node1 diff --git a/src/Model/Network/Graph.py b/src/Model/Network/Graph.py index eb91614a..0f432495 100644 --- a/src/Model/Network/Graph.py +++ b/src/Model/Network/Graph.py @@ -14,9 +14,6 @@ class Graph(object): self._node_ctor = Node self._edge_ctor = Edge - self._nodes_ids = 0 - self._edges_ids = 0 - self._nodes = [] self._edges = [] @@ -83,12 +80,11 @@ class Graph(object): def _create_node(self, x:float, y:float): node = self._node_ctor( - self._nodes_ids, - f"Node {self._nodes_ids}", + -1, + "", x = x, y = y, status = self._status ) - self._nodes_ids += 1 return node def _add_node(self, node): @@ -118,11 +114,10 @@ class Graph(object): def _create_edge(self, n1:Node, n2:Node): edge = self._edge_ctor( - self._edges_ids, + -1, "", n1, n2, status = self._status ) - self._edges_ids += 1 return edge diff --git a/src/Model/Network/Node.py b/src/Model/Network/Node.py index 5281a236..5f37101f 100644 --- a/src/Model/Network/Node.py +++ b/src/Model/Network/Node.py @@ -3,33 +3,44 @@ from Model.Network.Point import Point class Node(object): - def __init__(self, id:str, name:str, + _id_cnt = 0 + + def __init__(self, id:int, name:str, x:float = 0.0, y:float = 0.0, status = None): super(Node, self).__init__() self._status = status - self.id = id - self._name = name + if id == -1: + type(self)._id_cnt += 1 + self.id = type(self)._id_cnt + else: + self.id = id + + if name == "": + self._name = f"Node {self.id}" + else: + self._name = name + self.pos = Point(x, y) - def __getitem__(self, name): + def __getitem__(self, key): ret = None - if name == "name": + if key == "name": ret = self._name - elif name == "id": + elif key == "id": ret = self.id - elif name == "pos": + elif key == "pos": ret = f"({self.pos.x},{self.pos.y})" return ret - def __setitem__(self, name, value): - if name == "name": + def __setitem__(self, key, value): + if key == "name": self._name = value - elif name == "id": + elif key == "id": self.id = value self._status.modified() diff --git a/src/Model/River.py b/src/Model/River.py index 63137639..09ce95c1 100644 --- a/src/Model/River.py +++ b/src/Model/River.py @@ -52,9 +52,26 @@ class RiverNode(Node, SQLSubModel): @classmethod def _sql_load(cls, execute, data = None): - return True + nodes = [] + + table = execute("SELECT id, name, x, y FROM river_node") + for row in table: + # Update id counter + cls._id_cnt = max(cls._id_cnt, row[0]) + # Create new node + nodes.append(cls(*row, **data)) + + return nodes def _sql_save(self, execute, data = None): + sql = ( + "INSERT OR REPLACE INTO river_node(id, name, x, y) VALUES (" + + f"{self.id}, '{self._sql_format(self.name)}', " + + f"{self.x}, {self.y}"+ + ")" + ) + execute(sql) + return True @property @@ -91,6 +108,7 @@ class RiverReach(Edge, SQLSubModel): CREATE TABLE river_reach( id INTEGER NOT NULL PRIMARY KEY, name TEXT NOT NULL, + enable BOOLEAN NOT NULL, node1 INTEGER, node2 INTEGER, FOREIGN KEY(node1) REFERENCES river_node(id), @@ -107,9 +125,38 @@ class RiverReach(Edge, SQLSubModel): @classmethod def _sql_load(cls, execute, data = None): - return None + reachs = [] + + table = execute("SELECT id, name, enable, node1, node2 FROM river_reach") + for row in table: + # Update id counter + cls._id_cnt = max(cls._id_cnt, row[0]) + # Create new reach + id = row[0] + name = row[1] + enable = (row[2] == 1) + # Get nodes corresponding to db foreign key id + node1 = next(filter(lambda n: n.id == row[3], data["nodes"])) + node2 = next(filter(lambda n: n.id == row[4], data["nodes"])) + + new = cls(id, name, node1, node2, status = data["status"]) + new.enable(enable = enable) + reachs.append(new) + + return reachs def _sql_save(self, execute, data = None): + sql = ( + "INSERT OR REPLACE INTO " + + "river_reach(id, name, enable, node1, node2) "+ + "VALUES (" + + f"{self.id}, '{self._sql_format(self._name)}', " + + f"{self._sql_format(self.is_enable())}," + f"{self.node1.id}, {self.node2.id}"+ + ")" + ) + execute(sql) + return True @property @@ -152,13 +199,27 @@ class River(Graph, SQLSubModel): @classmethod def _sql_update(cls, execute, version): + cls._update_submodel(execute, version) return True @classmethod def _sql_load(cls, execute, data = None): - return None + new = cls(data["status"]) + new._nodes = RiverNode._sql_load( + execute, + data + ) + data["nodes"] = new.nodes() + new._edges = RiverReach._sql_load( + execute, + data + ) + + return new def _sql_save(self, execute, data = None): + objs = self._nodes + self._edges + self._save_submodel(execute, objs, data) return True @property @@ -169,7 +230,6 @@ class River(Graph, SQLSubModel): def sections(self): return self._sections - @property def boundary_condition(self): return self._boundary_condition diff --git a/src/Model/Study.py b/src/Model/Study.py index 238da3b4..50edc5ce 100644 --- a/src/Model/Study.py +++ b/src/Model/Study.py @@ -38,6 +38,8 @@ class Study(SQLModel): if init_new: # Study data self._river = River(status = self.status) + else: + self._init_db_file(filename, is_new = False) @classmethod def checkers(cls): @@ -141,22 +143,36 @@ class Study(SQLModel): def _update(self): version = self.execute(f"SELECT value FROM info WHERE key='version'") - print(f"{version} == {self._version}") - if version == self._version: + print(f"{version[0]} == {self._version}") + if version[0] == self._version: return True - print("TODO: update") + if self._update_submodel(version): + self.execute(f"UPDATE info SET value='{self._version}' WHERE key='version'") + return True + + print("TODO: update failed") raise NotImplementedMethodeError(self, self._update) @classmethod def _load(cls, filename): - new = cls(init_new = False) + new = cls(init_new = False, filename = filename) + + # TODO: Load metadata + print("TODO: Load metadata") # Load river data - self._river = River.load() + new._river = River._sql_load( + lambda sql: new.execute( + sql, + fetch_one = False, + commit = True + ), + data = {"status": new.status} + ) return new def _save(self): - + self._save_submodel([self._river]) self.commit() diff --git a/src/View/MainWindow.py b/src/View/MainWindow.py index 340e3cf0..49387884 100644 --- a/src/View/MainWindow.py +++ b/src/View/MainWindow.py @@ -216,13 +216,13 @@ class ApplicationWindow(QMainWindow, ListedSubWindow, WindowToolKit): if self.model.filename is None or self.model.filename == "": file_name, _ = QFileDialog.getSaveFileName( self, "Save File", - "", "Pamhyr(*.pkl)" + "", "Pamhyr(*.pamhyr)" ) - if file_name[-4:] == ".pkl": + if file_name[-4:] == ".pamhyr": self.model.filename = file_name else: - self.model.filename = file_name + ".pkl" + self.model.filename = file_name + ".pamhyr" self.model.save() @@ -237,13 +237,13 @@ class ApplicationWindow(QMainWindow, ListedSubWindow, WindowToolKit): """ file_name, _ = QFileDialog.getSaveFileName( self, "Save File", - "", "Pamhyr(*.pkl)" + "", "Pamhyr(*.pamhyr)" ) - if file_name[-4:] == ".pkl": + if file_name[-4:] == ".pamhyr": self.model.filename = file_name else: - self.model.filename = file_name + ".pkl" + self.model.filename = file_name + ".pamhyr" self.model.save() -- GitLab