diff --git a/src/config.py b/src/config.py index 599289229f485a51680e62349710f4bf8e57ab00..61c52272d0d30a773638fe4847b3eade7fc082ed 100644 --- a/src/config.py +++ b/src/config.py @@ -61,14 +61,17 @@ class Config(SQL): def _update(self): version = self.execute(f"SELECT value FROM info WHERE key='version'") - if version.fetchone()[0] != self._version: + if version != self._version: print("update") def _load_solver(self): self._solvers = [] - solvers = self.execute("SELECT * FROM solver") - for solver in solvers.fetchall(): + solvers = self.execute( + "SELECT * FROM solver", + fetch_one = False + ) + for solver in solvers: solver_type = solver[0] ctor = solver_type_list[solver_type] @@ -87,8 +90,11 @@ class Config(SQL): self.stricklers = StricklersList() id = 0 - stricklers = self.execute("SELECT * FROM stricklers") - for strickler in stricklers.fetchall(): + stricklers = self.execute( + "SELECT * FROM stricklers", + fetch_one = False + ) + for strickler in stricklers: new = Stricklers() new._name = strickler[0] new._comment = strickler[1] @@ -101,31 +107,31 @@ class Config(SQL): def _load(self): # Meshing tool v = self.execute("SELECT value FROM data WHERE key='meshing_tool'") - self.meshing_tool = v.fetchone()[0] + self.meshing_tool = v[0] # Const v = self.execute("SELECT value FROM data WHERE key='segment'") - self.segment = int(v.fetchone()[0]) + self.segment = int(v[0]) v = self.execute("SELECT value FROM data WHERE key='max_listing'") - self.max_listing = int(v.fetchone()[0]) + self.max_listing = int(v[0]) # Backup v = self.execute("SELECT value FROM data WHERE key='backup_enable'") - self.backup_enable = v.fetchone()[0] == "True" + self.backup_enable = v[0] == "True" v = self.execute("SELECT value FROM data WHERE key='backup_path'") - self.backup_path = v.fetchone()[0] + self.backup_path = v[0] v = self.execute("SELECT value FROM data WHERE key='backup_frequence'") - self.backup_frequence = v.fetchone()[0] + self.backup_frequence = v[0] v = self.execute("SELECT value FROM data WHERE key='backup_max'") - self.backup_max = int(v.fetchone()[0]) + self.backup_max = int(v[0]) # Editor v = self.execute("SELECT value FROM data WHERE key='editor'") - self.editor = v.fetchone()[0] + self.editor = v[0] # Languages v = self.execute("SELECT value FROM data WHERE key='lang'") - self.lang = v.fetchone()[0] + self.lang = v[0] self._load_solver() self._load_stricklers() @@ -137,18 +143,17 @@ class Config(SQL): self.execute(f""" INSERT INTO solver VALUES ( '{solver._type}', - '{solver._name}', - '{solver._description}', - '{solver._path_input}', - '{solver._path_solver}', - '{solver._path_output}', - '{solver._cmd_input}', - '{solver._cmd_solver}', - '{solver._cmd_output}' + '{self._sql_format(solver._name)}', + '{self._sql_format(solver._description)}', + '{self._sql_format(solver._path_input)}', + '{self._sql_format(solver._path_solver)}', + '{self._sql_format(solver._path_output)}', + '{self._sql_format(solver._cmd_input)}', + '{self._sql_format(solver._cmd_solver)}', + '{self._sql_format(solver._cmd_output)}' ) - """) - - self.commit() + """, + commit = True) def _save_stricklers(self): self.execute(f"DELETE FROM stricklers") @@ -156,8 +161,8 @@ class Config(SQL): for stricklers in self.stricklers.stricklers: self.execute(f""" INSERT INTO stricklers VALUES ( - '{stricklers._name}', - '{stricklers._comment}', + '{self._sql_format(stricklers._name)}', + '{self._sql_format(stricklers._comment)}', '{stricklers._minor}', '{stricklers._medium}' ) @@ -179,8 +184,15 @@ class Config(SQL): } for key in data: - self.execute(f"INSERT OR IGNORE INTO data VALUES ('{key}', '{data[key]}')") - self.execute(f"UPDATE data SET value='{data[key]}' WHERE key='{key}'") + self.execute( + f"INSERT OR IGNORE INTO data VALUES " + + f" ('{key}', '{self._sql_format(data[key])}')" + ) + self.execute( + f"UPDATE data SET " + + f"value='{self._sql_format(data[key])}' " + + f"WHERE key='{key}'" + ) self.commit() self._save_solver() diff --git a/src/tools.py b/src/tools.py index 83a147aaf1087a79f3ec8fd9d6d9954ef0996254..a3f20d049e382dd132bc2667a2dd4f22c72031ab 100644 --- a/src/tools.py +++ b/src/tools.py @@ -181,10 +181,56 @@ class SQL(object): def commit(self): self._db.commit() - def execute(self, cmd, commit = False): - value = self._cur.execute(cmd) + def _fetch_string(self, s): + return s.replace("'", "'") + + def _fetch_tuple(self, tup): + res = [] + for v in tup: + if type(v) == str: + v = self._fetch_string(v) + res.append(v) + + return res + + def _fetch_list(self, lst): + res = [] + for v in lst: + if type(v) == str: + v = self._fetch_string(v) + elif type(v) == tuple: + v = self._fetch_tuple(v) + res.append(v) + + return res + + def _fetch(self, res, one): + if one: + value = res.fetchone() + else: + value = res.fetchall() + res = value + + if type(value) == list: + res = self._fetch_list(value) + elif type(value) == tuple: + res = self._fetch_tuple(value) + + return res + + def _sql_format(self, value): + # Replace ''' by ''' to preserve SQL injection + if type(value) == str: + value = value.replace("'", "'") + return value + + def execute(self, cmd, fetch_one = True, commit = False): + res = self._cur.execute(cmd) + if commit: self._db.commit() + + value = self._fetch(res, fetch_one) return value def _create(self):