# SQL.py -- Pamhyr
# Copyright (C) 2023-2024 INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import os
import logging
import sqlite3

from pathlib import Path

from tools import timer

logger = logging.getLogger()


class SQL(object):
    """Base class for objects persisted in a SQLite database file.

    Subclasses override the ``_create``, ``_update``, ``_save`` and
    ``_load`` hooks; in this base class each hook only logs a TODO
    warning.

    String values are escaped with ``_db_format`` before being embedded
    in SQL text (``'`` is stored as ``&#39;``) and unescaped again by
    ``_fetch_string`` when results are read back through ``execute``.
    """

    # Stand-in stored in the database for a single quote, so string
    # values can never close the surrounding SQL literal.
    _QUOTE_ENTITY = "&#39;"

    def _init_db_file(self, db):
        """Open (creating it if needed) the SQLite file *db* and sync it.

        A file that did not exist yet goes through ``_create`` then
        ``_save``; an existing file goes through ``_update`` then
        ``_load``.

        :param db: path of the SQLite database file
        """
        exists = Path(db).exists()

        # os.makedirs("") raises FileNotFoundError, so only create the
        # parent directory when *db* actually has one (e.g. not for a
        # bare "data.db" or ":memory:").
        directory = os.path.dirname(db)
        if directory:
            os.makedirs(directory, exist_ok=True)

        self._db = sqlite3.connect(db)
        self._cur = self._db.cursor()

        if not exists:
            self._create()      # Create db
            self._save()        # Save
        else:
            self._update()      # Update db scheme if necessary
            self._load()        # Load data

    def __init__(self, filename=None):
        """Create the object, opening *filename* as database if given.

        :param filename: path of the SQLite file, or None for no database
        """
        self._db = None
        self._cur = None

        if filename is not None:
            self._init_db_file(filename)

    def commit(self):
        """Commit the current transaction of the open connection."""
        logger.debug("SQL - commit")
        self._db.commit()

    def _close(self):
        """Commit pending changes, then close the connection."""
        self.commit()
        self._db.close()

    def _fetch_string(self, s):
        """Undo ``_db_format``: turn stored ``&#39;`` back into ``'``."""
        return s.replace(self._QUOTE_ENTITY, "'")

    def _fetch_tuple(self, tup):
        """Return *tup* as a list with every str element unescaped."""
        return [
            self._fetch_string(v) if isinstance(v, str) else v
            for v in tup
        ]

    def _fetch_list(self, lst):
        """Return *lst* as a list with str elements unescaped and tuple
        elements (result rows) converted by ``_fetch_tuple``."""
        res = []
        for v in lst:
            if isinstance(v, str):
                v = self._fetch_string(v)
            elif isinstance(v, tuple):
                v = self._fetch_tuple(v)
            res.append(v)
        return res

    def _fetch(self, res, one):
        """Fetch rows from cursor *res* and unescape their strings.

        :param res: cursor returned by ``execute``
        :param one: use ``fetchone`` instead of ``fetchall``
        :return: with the default row factory, a list for a single row,
                 a list of lists for all rows, or None when ``fetchone``
                 finds no row
        """
        value = res.fetchone() if one else res.fetchall()

        if isinstance(value, list):
            return self._fetch_list(value)
        if isinstance(value, tuple):
            return self._fetch_tuple(value)
        return value

    def _db_format(self, value):
        """Escape *value* for embedding inside a single-quoted SQL literal.

        Single quotes in strings are replaced by ``&#39;`` so the value
        cannot terminate the literal (SQL injection); ``_fetch_string``
        reverses this on read. Non-string values are returned unchanged.
        """
        if isinstance(value, str):
            value = value.replace("'", self._QUOTE_ENTITY)
        return value

    @timer
    def execute(self, cmd, fetch_one=True, commit=False):
        """Run the SQL statement *cmd* and return its unescaped result.

        :param cmd: complete SQL text; string values must have been
                    passed through ``_db_format`` first
        :param fetch_one: return a single row (``fetchone``) rather than
                          all rows (``fetchall``)
        :param commit: commit the transaction after executing *cmd*
        :return: the result of ``_fetch``, or None if the statement
                 failed (the error is logged, not raised)
        """
        logger.debug(f"SQL - {cmd}")

        value = None
        try:
            res = self._cur.execute(cmd)
            if commit:
                self._db.commit()
            value = self._fetch(res, fetch_one)
        except Exception:
            # Best-effort API: callers get None on failure, but the
            # error and its traceback are logged. (`logger_exception`
            # is not imported in this module, and a `return` inside
            # `finally` would also swallow KeyboardInterrupt.)
            logger.exception(f"SQL - failed: {cmd}")

        return value

    def _create(self):
        """Hook: create the database schema (override in subclasses)."""
        logger.warning("TODO: Create")

    def _update(self):
        """Hook: upgrade an existing schema (override in subclasses)."""
        logger.warning("TODO: Update")

    def _save(self):
        """Hook: write the object's data (override in subclasses)."""
        logger.warning("TODO: Save")

    def _load(self):
        """Hook: read the object's data (override in subclasses)."""
        logger.warning("TODO: LOAD")