From 4978f20e29befeb8545f4bf0c7a33763d318a715 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr>
Date: Thu, 29 Jun 2023 11:33:59 +0200
Subject: [PATCH] Network: Export model network to sqlite db.

---
 src/Model/DB.py            | 71 ++++++++++++++++++++++----------------
 src/Model/Network/Edge.py  | 11 ++++--
 src/Model/Network/Graph.py | 11 ++----
 src/Model/Network/Node.py  | 31 +++++++++++------
 src/Model/River.py         | 68 +++++++++++++++++++++++++++++++++---
 src/Model/Study.py         | 28 +++++++++++----
 src/View/MainWindow.py     | 12 +++----
 7 files changed, 167 insertions(+), 65 deletions(-)

diff --git a/src/Model/DB.py b/src/Model/DB.py
index 5206555a..013f688e 100644
--- a/src/Model/DB.py
+++ b/src/Model/DB.py
@@ -22,54 +22,65 @@ class SQLModel(SQL):
         self._cur = self._db.cursor()
 
         if is_new:
+            print("CREATE")
             self._create()      # Create db
-            self._save()        # Save
+            # self._save()        # Save
         else:
+            print("UPDATE")
             self._update()      # Update db scheme if necessary
-            self._load()        # Load data
+            # self._load()        # Load data
+
+
+    def __init__(self, filename = None):
+        self._db = None
 
     def _create_submodel(self):
+        fn = lambda sql: self.execute(
+            sql,
+            fetch_one = False,
+            commit = True
+        )
+
         for cls in self._sub_classes:
-            requests = cls._sql_create(
-                lambda sql: self.execute(
-                    sql,
-                    fetch_one = False,
-                    commit = True
-                )
-            )
+            requests = cls._sql_create(fn)
 
     def _create(self):
         raise NotImplementedMethodeError(self, self._create)
 
     def _update_submodel(self, version):
+        fn = lambda sql: self.execute(
+            sql,
+            fetch_one = False,
+            commit = True
+        )
+
+        ok = True
         for cls in self._sub_classes:
-            requests = cls._sql_update(
-                lambda sql: self.execute(
-                    sql,
-                    fetch_one = False,
-                    commit = True
-                ),
-                version
-            )
+            ok &= cls._sql_update(fn, version)
+
+        return ok
 
     def _update(self):
         raise NotImplementedMethodeError(self, self._update)
 
-    def _save_submodel(self, objs):
+    def _save_submodel(self, objs, data = None):
+        fn = lambda sql: self.execute(
+            sql,
+            fetch_one = False,
+            commit = True
+        )
+
+        ok = True
         for obj in objs:
-            requests = obj._sql_save(
-                lambda sql: self.execute(
-                    sql,
-                    fetch_one = False,
-                    commit = True
-                )
-            )
+            ok &= obj._sql_save(fn)
+
+        return ok
 
     def _save(self):
         raise NotImplementedMethodeError(self, self._save)
 
     @classmethod
-    def _load(cls, filename):
+    def _load(cls, filename = None):
         raise NotImplementedMethodeError(cls, cls._load)
 
 # Sub model class
@@ -80,6 +91,8 @@ class SQLSubModel(object):
         # Replace ''' by '&#39;' to preserve SQL injection
         if type(value) == str:
             value = value.replace("'", "&#39;")
+        elif type(value) == bool:
+            value = 'TRUE' if value else 'FALSE'
         return value
 
     @classmethod
@@ -130,9 +143,9 @@ class SQLSubModel(object):
         """
         raise NotImplementedMethodeError(cls, cls._sql_load)
 
-    def _save_submodel(self, execute):
-        for sc in self._sub_classes:
-            sc._sql_update(execute)
+    def _save_submodel(self, execute, objs, data = None):
+        for o in objs:
+            o._sql_save(execute, data = data)
 
     def _sql_save(self, execute, data = None):
         """Save class data to data base
diff --git a/src/Model/Network/Edge.py b/src/Model/Network/Edge.py
index b0918a16..f1ced151 100644
--- a/src/Model/Network/Edge.py
+++ b/src/Model/Network/Edge.py
@@ -3,7 +3,9 @@
 from Model.Network.Node import Node
 
 class Edge(object):
-    def __init__(self, id:str, name:str,
+    _id_cnt = 0
+
+    def __init__(self, id:int, name:str,
                  node1:Node = None,
                  node2:Node = None,
                  status = None):
@@ -11,7 +13,12 @@ class Edge(object):
 
         self._status = status
 
-        self.id = id
+        if id == -1:
+            type(self)._id_cnt += 1
+            self.id = type(self)._id_cnt
+        else:
+            self.id = id
+
         self._name = name
 
         self.node1 = node1
diff --git a/src/Model/Network/Graph.py b/src/Model/Network/Graph.py
index eb91614a..0f432495 100644
--- a/src/Model/Network/Graph.py
+++ b/src/Model/Network/Graph.py
@@ -14,9 +14,6 @@ class Graph(object):
         self._node_ctor = Node
         self._edge_ctor = Edge
 
-        self._nodes_ids = 0
-        self._edges_ids = 0
-
         self._nodes = []
         self._edges = []
 
@@ -83,12 +80,11 @@ class Graph(object):
 
     def _create_node(self, x:float, y:float):
         node = self._node_ctor(
-            self._nodes_ids,
-            f"Node {self._nodes_ids}",
+            -1,
+            "",
             x = x, y = y,
             status = self._status
         )
-        self._nodes_ids += 1
         return node
 
     def _add_node(self, node):
@@ -118,11 +114,10 @@ class Graph(object):
 
     def _create_edge(self, n1:Node, n2:Node):
         edge = self._edge_ctor(
-            self._edges_ids,
+            -1,
             "", n1, n2,
             status = self._status
         )
-        self._edges_ids += 1
         return edge
 
 
diff --git a/src/Model/Network/Node.py b/src/Model/Network/Node.py
index 5281a236..5f37101f 100644
--- a/src/Model/Network/Node.py
+++ b/src/Model/Network/Node.py
@@ -3,33 +3,44 @@
 from Model.Network.Point import Point
 
 class Node(object):
-    def __init__(self, id:str, name:str,
+    _id_cnt = 0
+
+    def __init__(self, id:int, name:str,
                  x:float = 0.0, y:float = 0.0,
                  status = None):
         super(Node, self).__init__()
 
         self._status = status
 
-        self.id = id
-        self._name = name
+        if id == -1:
+            type(self)._id_cnt += 1
+            self.id = type(self)._id_cnt
+        else:
+            self.id = id
+
+        if name == "":
+            self._name = f"Node {self.id}"
+        else:
+            self._name = name
+
         self.pos = Point(x, y)
 
-    def __getitem__(self, name):
+    def __getitem__(self, key):
         ret = None
 
-        if name == "name":
+        if key == "name":
             ret = self._name
-        elif name == "id":
+        elif key == "id":
             ret = self.id
-        elif name == "pos":
+        elif key == "pos":
             ret = f"({self.pos.x},{self.pos.y})"
 
         return ret
 
-    def __setitem__(self, name, value):
-        if name == "name":
+    def __setitem__(self, key, value):
+        if key == "name":
             self._name = value
-        elif name == "id":
+        elif key == "id":
             self.id = value
 
         self._status.modified()
diff --git a/src/Model/River.py b/src/Model/River.py
index 63137639..09ce95c1 100644
--- a/src/Model/River.py
+++ b/src/Model/River.py
@@ -52,9 +52,26 @@ class RiverNode(Node, SQLSubModel):
 
     @classmethod
     def _sql_load(cls, execute, data = None):
-        return True
+        nodes = []
+
+        table = execute("SELECT id, name, x, y FROM river_node")
+        for row in table:
+            # Update id counter
+            cls._id_cnt = max(cls._id_cnt, row[0])
+            # Create new node
+            nodes.append(cls(*row, **data))
+
+        return nodes
 
     def _sql_save(self, execute, data = None):
+        sql = (
+            "INSERT OR REPLACE INTO river_node(id, name, x, y) VALUES (" +
+            f"{self.id}, '{self._sql_format(self.name)}', " +
+            f"{self.x}, {self.y}"+
+            ")"
+        )
+        execute(sql)
+
         return True
 
     @property
@@ -91,6 +108,7 @@ class RiverReach(Edge, SQLSubModel):
           CREATE TABLE river_reach(
             id INTEGER NOT NULL PRIMARY KEY,
             name TEXT NOT NULL,
+            enable BOOLEAN NOT NULL,
             node1 INTEGER,
             node2 INTEGER,
             FOREIGN KEY(node1) REFERENCES river_node(id),
@@ -107,9 +125,38 @@ class RiverReach(Edge, SQLSubModel):
 
     @classmethod
     def _sql_load(cls, execute, data = None):
-        return None
+        reachs = []
+
+        table = execute("SELECT id, name, enable, node1, node2 FROM river_reach")
+        for row in table:
+            # Update id counter
+            cls._id_cnt = max(cls._id_cnt, row[0])
+            # Create new reach
+            id = row[0]
+            name = row[1]
+            enable = (row[2] == 1)
+            # Get nodes corresponding to db foreign key id
+            node1 = next(filter(lambda n: n.id == row[3], data["nodes"]))
+            node2 = next(filter(lambda n: n.id == row[4], data["nodes"]))
+
+            new = cls(id, name, node1, node2, status = data["status"])
+            new.enable(enable = enable)
+            reachs.append(new)
+
+        return reachs
 
     def _sql_save(self, execute, data = None):
+        sql = (
+            "INSERT OR REPLACE INTO " +
+            "river_reach(id, name, enable, node1, node2) "+
+            "VALUES (" +
+            f"{self.id}, '{self._sql_format(self._name)}', " +
+            f"{self._sql_format(self.is_enable())},"
+            f"{self.node1.id}, {self.node2.id}"+
+            ")"
+        )
+        execute(sql)
+
         return True
 
     @property
@@ -152,13 +199,27 @@ class River(Graph, SQLSubModel):
 
     @classmethod
     def _sql_update(cls, execute, version):
+        cls._update_submodel(execute, version)
         return True
 
     @classmethod
     def _sql_load(cls, execute, data = None):
-        return None
+        new = cls(data["status"])
+        new._nodes = RiverNode._sql_load(
+            execute,
+            data
+        )
+        data["nodes"] = new.nodes()
+        new._edges = RiverReach._sql_load(
+            execute,
+            data
+        )
+
+        return new
 
     def _sql_save(self, execute, data = None):
+        objs = self._nodes + self._edges
+        self._save_submodel(execute, objs, data)
         return True
 
     @property
@@ -169,7 +230,6 @@ class River(Graph, SQLSubModel):
     def sections(self):
         return self._sections
 
-
     @property
     def boundary_condition(self):
         return self._boundary_condition
diff --git a/src/Model/Study.py b/src/Model/Study.py
index 238da3b4..50edc5ce 100644
--- a/src/Model/Study.py
+++ b/src/Model/Study.py
@@ -38,6 +38,8 @@ class Study(SQLModel):
         if init_new:
             # Study data
             self._river = River(status = self.status)
+        else:
+            self._init_db_file(filename, is_new = False)
 
     @classmethod
     def checkers(cls):
@@ -141,22 +143,36 @@ class Study(SQLModel):
     def _update(self):
         version = self.execute(f"SELECT value FROM info WHERE key='version'")
 
-        print(f"{version} == {self._version}")
-        if version == self._version:
+        print(f"{version[0]} == {self._version}")
+        if version[0] == self._version:
             return True
 
-        print("TODO: update")
+        if self._update_submodel(version):
+            self.execute(f"UPDATE info SET value='{self._version}' WHERE key='version'")
+            return True
+
+        print("TODO: update failed")
         raise NotImplementedMethodeError(self, self._update)
 
     @classmethod
     def _load(cls, filename):
-        new = cls(init_new = False)
+        new = cls(init_new = False, filename = filename)
+
+        # TODO: Load metadata
+        print("TODO: Load metadata")
 
         # Load river data
-        self._river = River.load()
+        new._river = River._sql_load(
+            lambda sql: new.execute(
+                sql,
+                fetch_one = False,
+                commit = True
+            ),
+            data = {"status": new.status}
+        )
 
         return new
 
     def _save(self):
-
+        self._save_submodel([self._river])
         self.commit()
diff --git a/src/View/MainWindow.py b/src/View/MainWindow.py
index 340e3cf0..49387884 100644
--- a/src/View/MainWindow.py
+++ b/src/View/MainWindow.py
@@ -216,13 +216,13 @@ class ApplicationWindow(QMainWindow, ListedSubWindow, WindowToolKit):
         if self.model.filename is None or self.model.filename == "":
             file_name, _ = QFileDialog.getSaveFileName(
                 self, "Save File",
-                "", "Pamhyr(*.pkl)"
+                "", "Pamhyr(*.pamhyr)"
             )
 
-            if file_name[-4:] == ".pkl":
+            if file_name[-4:] == ".pamhyr":
                 self.model.filename = file_name
             else:
-                self.model.filename = file_name + ".pkl"
+                self.model.filename = file_name + ".pamhyr"
 
         self.model.save()
 
@@ -237,13 +237,13 @@ class ApplicationWindow(QMainWindow, ListedSubWindow, WindowToolKit):
         """
         file_name, _ = QFileDialog.getSaveFileName(
             self, "Save File",
-            "", "Pamhyr(*.pkl)"
+            "", "Pamhyr(*.pamhyr)"
         )
 
-        if file_name[-4:] == ".pkl":
+        if file_name[-4:] == ".pamhyr":
             self.model.filename = file_name
         else:
-            self.model.filename = file_name + ".pkl"
+            self.model.filename = file_name + ".pamhyr"
 
         self.model.save()
 
-- 
GitLab