Commit 4978f20e authored by Pierre-Antoine Rouby's avatar Pierre-Antoine Rouby
Browse files

Network: Export model network to sqlite db.

Showing with 167 additions and 65 deletions
+167 -65
......@@ -22,54 +22,65 @@ class SQLModel(SQL):
self._cur = self._db.cursor()
if is_new:
print("CREATE")
self._create() # Create db
self._save() # Save
# self._save() # Save
else:
print("UPDATE")
self._update() # Update db scheme if necessary
self._load() # Load data
# self._load() # Load data
def __init__(self, filename = None):
self._db = None
def _create_submodel(self):
fn = lambda sql: self.execute(
sql,
fetch_one = False,
commit = True
)
for cls in self._sub_classes:
requests = cls._sql_create(
lambda sql: self.execute(
sql,
fetch_one = False,
commit = True
)
)
requests = cls._sql_create(fn)
def _create(self):
raise NotImplementedMethodeError(self, self._create)
def _update_submodel(self, version):
fn = lambda sql: self.execute(
sql,
fetch_one = False,
commit = True
)
ok = True
for cls in self._sub_classes:
requests = cls._sql_update(
lambda sql: self.execute(
sql,
fetch_one = False,
commit = True
),
version
)
ok &= cls._sql_update(fn, version)
return ok
def _update(self):
raise NotImplementedMethodeError(self, self._update)
def _save_submodel(self, objs):
def _save_submodel(self, objs, data = None):
fn = lambda sql: self.execute(
sql,
fetch_one = False,
commit = True
)
ok = True
for obj in objs:
requests = obj._sql_save(
lambda sql: self.execute(
sql,
fetch_one = False,
commit = True
)
)
ok &= obj._sql_save(fn)
return ok
def _save(self):
raise NotImplementedMethodeError(self, self._save)
@classmethod
def _load(cls, filename):
def _load(cls, filename = None):
raise NotImplementedMethodeError(cls, cls._load)
# Sub model class
......@@ -80,6 +91,8 @@ class SQLSubModel(object):
# Replace ''' by ''' to preserve SQL injection
if type(value) == str:
value = value.replace("'", "'")
elif type(value) == bool:
value = 'TRUE' if value else 'FALSE'
return value
@classmethod
......@@ -130,9 +143,9 @@ class SQLSubModel(object):
"""
raise NotImplementedMethodeError(cls, cls._sql_load)
def _save_submodel(self, execute):
for sc in self._sub_classes:
sc._sql_update(execute)
def _save_submodel(self, execute, objs, data = None):
for o in objs:
o._sql_save(execute, data = data)
def _sql_save(self, execute, data = None):
"""Save class data to data base
......
......@@ -3,7 +3,9 @@
from Model.Network.Node import Node
class Edge(object):
def __init__(self, id:str, name:str,
_id_cnt = 0
def __init__(self, id:int, name:str,
node1:Node = None,
node2:Node = None,
status = None):
......@@ -11,7 +13,12 @@ class Edge(object):
self._status = status
self.id = id
if id == -1:
type(self)._id_cnt += 1
self.id = type(self)._id_cnt
else:
self.id = id
self._name = name
self.node1 = node1
......
......@@ -14,9 +14,6 @@ class Graph(object):
self._node_ctor = Node
self._edge_ctor = Edge
self._nodes_ids = 0
self._edges_ids = 0
self._nodes = []
self._edges = []
......@@ -83,12 +80,11 @@ class Graph(object):
def _create_node(self, x:float, y:float):
node = self._node_ctor(
self._nodes_ids,
f"Node {self._nodes_ids}",
-1,
"",
x = x, y = y,
status = self._status
)
self._nodes_ids += 1
return node
def _add_node(self, node):
......@@ -118,11 +114,10 @@ class Graph(object):
def _create_edge(self, n1:Node, n2:Node):
edge = self._edge_ctor(
self._edges_ids,
-1,
"", n1, n2,
status = self._status
)
self._edges_ids += 1
return edge
......
......@@ -3,33 +3,44 @@
from Model.Network.Point import Point
class Node(object):
def __init__(self, id:str, name:str,
_id_cnt = 0
def __init__(self, id:int, name:str,
x:float = 0.0, y:float = 0.0,
status = None):
super(Node, self).__init__()
self._status = status
self.id = id
self._name = name
if id == -1:
type(self)._id_cnt += 1
self.id = type(self)._id_cnt
else:
self.id = id
if name == "":
self._name = f"Node {self.id}"
else:
self._name = name
self.pos = Point(x, y)
def __getitem__(self, name):
def __getitem__(self, key):
ret = None
if name == "name":
if key == "name":
ret = self._name
elif name == "id":
elif key == "id":
ret = self.id
elif name == "pos":
elif key == "pos":
ret = f"({self.pos.x},{self.pos.y})"
return ret
def __setitem__(self, name, value):
if name == "name":
def __setitem__(self, key, value):
if key == "name":
self._name = value
elif name == "id":
elif key == "id":
self.id = value
self._status.modified()
......
......@@ -52,9 +52,26 @@ class RiverNode(Node, SQLSubModel):
@classmethod
def _sql_load(cls, execute, data = None):
return True
nodes = []
table = execute("SELECT id, name, x, y FROM river_node")
for row in table:
# Update id counter
cls._id_cnt = max(cls._id_cnt, row[0])
# Create new node
nodes.append(cls(*row, **data))
return nodes
def _sql_save(self, execute, data = None):
sql = (
"INSERT OR REPLACE INTO river_node(id, name, x, y) VALUES (" +
f"{self.id}, '{self._sql_format(self.name)}', " +
f"{self.x}, {self.y}"+
")"
)
execute(sql)
return True
@property
......@@ -91,6 +108,7 @@ class RiverReach(Edge, SQLSubModel):
CREATE TABLE river_reach(
id INTEGER NOT NULL PRIMARY KEY,
name TEXT NOT NULL,
enable BOOLEAN NOT NULL,
node1 INTEGER,
node2 INTEGER,
FOREIGN KEY(node1) REFERENCES river_node(id),
......@@ -107,9 +125,38 @@ class RiverReach(Edge, SQLSubModel):
@classmethod
def _sql_load(cls, execute, data = None):
return None
reachs = []
table = execute("SELECT id, name, enable, node1, node2 FROM river_reach")
for row in table:
# Update id counter
cls._id_cnt = max(cls._id_cnt, row[0])
# Create new reach
id = row[0]
name = row[1]
enable = (row[2] == 1)
# Get nodes corresponding to db foreign key id
node1 = next(filter(lambda n: n.id == row[3], data["nodes"]))
node2 = next(filter(lambda n: n.id == row[4], data["nodes"]))
new = cls(id, name, node1, node2, status = data["status"])
new.enable(enable = enable)
reachs.append(new)
return reachs
def _sql_save(self, execute, data = None):
sql = (
"INSERT OR REPLACE INTO " +
"river_reach(id, name, enable, node1, node2) "+
"VALUES (" +
f"{self.id}, '{self._sql_format(self._name)}', " +
f"{self._sql_format(self.is_enable())},"
f"{self.node1.id}, {self.node2.id}"+
")"
)
execute(sql)
return True
@property
......@@ -152,13 +199,27 @@ class River(Graph, SQLSubModel):
@classmethod
def _sql_update(cls, execute, version):
cls._update_submodel(execute, version)
return True
@classmethod
def _sql_load(cls, execute, data = None):
return None
new = cls(data["status"])
new._nodes = RiverNode._sql_load(
execute,
data
)
data["nodes"] = new.nodes()
new._edges = RiverReach._sql_load(
execute,
data
)
return new
def _sql_save(self, execute, data = None):
objs = self._nodes + self._edges
self._save_submodel(execute, objs, data)
return True
@property
......@@ -169,7 +230,6 @@ class River(Graph, SQLSubModel):
def sections(self):
return self._sections
@property
def boundary_condition(self):
return self._boundary_condition
......
......@@ -38,6 +38,8 @@ class Study(SQLModel):
if init_new:
# Study data
self._river = River(status = self.status)
else:
self._init_db_file(filename, is_new = False)
@classmethod
def checkers(cls):
......@@ -141,22 +143,36 @@ class Study(SQLModel):
def _update(self):
version = self.execute(f"SELECT value FROM info WHERE key='version'")
print(f"{version} == {self._version}")
if version == self._version:
print(f"{version[0]} == {self._version}")
if version[0] == self._version:
return True
print("TODO: update")
if self._update_submodel(version):
self.execute(f"UPDATE info SET value='{self._version}' WHERE key='version'")
return True
print("TODO: update failed")
raise NotImplementedMethodeError(self, self._update)
@classmethod
def _load(cls, filename):
new = cls(init_new = False)
new = cls(init_new = False, filename = filename)
# TODO: Load metadata
print("TODO: Load metadata")
# Load river data
self._river = River.load()
new._river = River._sql_load(
lambda sql: new.execute(
sql,
fetch_one = False,
commit = True
),
data = {"status": new.status}
)
return new
def _save(self):
self._save_submodel([self._river])
self.commit()
......@@ -216,13 +216,13 @@ class ApplicationWindow(QMainWindow, ListedSubWindow, WindowToolKit):
if self.model.filename is None or self.model.filename == "":
file_name, _ = QFileDialog.getSaveFileName(
self, "Save File",
"", "Pamhyr(*.pkl)"
"", "Pamhyr(*.pamhyr)"
)
if file_name[-4:] == ".pkl":
if file_name[-4:] == ".pamhyr":
self.model.filename = file_name
else:
self.model.filename = file_name + ".pkl"
self.model.filename = file_name + ".pamhyr"
self.model.save()
......@@ -237,13 +237,13 @@ class ApplicationWindow(QMainWindow, ListedSubWindow, WindowToolKit):
"""
file_name, _ = QFileDialog.getSaveFileName(
self, "Save File",
"", "Pamhyr(*.pkl)"
"", "Pamhyr(*.pamhyr)"
)
if file_name[-4:] == ".pkl":
if file_name[-4:] == ".pamhyr":
self.model.filename = file_name
else:
self.model.filename = file_name + ".pkl"
self.model.filename = file_name + ".pamhyr"
self.model.save()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment