# PamhyrDB.py 5.89 KiB
# PamhyrDB.py -- Pamhyr abstract model database classes
# Copyright (C) 2023  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import os
import sqlite3
import logging

from pathlib import Path
from functools import reduce

from tools import SQL
from Model.Except import NotImplementedMethodeError

logger = logging.getLogger()

# Top level model class


class SQLModel(SQL):
    """Abstract top-level model stored in an SQLite database file.

    Concrete models are expected to override ``_create``, ``_update``,
    ``_save``, ``_count`` and ``_load``; every one of them raises
    ``NotImplementedMethodeError`` here. The ``*_submodel`` helpers
    forward the work to sub-model classes or objects, handing them a
    callback that runs one SQL request through ``self.execute``.
    """

    # Sub-model classes whose ``_db_create`` / ``_db_update`` are called
    # by ``_create_submodel`` / ``_update_submodel``
    _sub_classes = []

    def _init_db_file(self, db, is_new=True):
        """Connect to the database file, then create or update its scheme

        Args:
            db: Path of the SQLite database file
            is_new: If true, an already existing file is removed and
                ``_create`` is called; otherwise ``_update`` is called

        Returns:
            Nothing
        """
        exists = Path(db).exists()

        # A new study must start from an empty file
        if exists and is_new:
            os.remove(db)

        self._db = sqlite3.connect(db)
        self._cur = self._db.cursor()

        if is_new:
            logger.info("Create database")
            self._create()      # Create db
        else:
            self._update()      # Update db scheme if necessary

    def __init__(self, filename=None):
        """Initialize the model without opening any database

        NOTE: the base class constructor is not called here, and
        ``filename`` is not used by this constructor.

        Args:
            filename: Accepted for interface compatibility, ignored
        """
        self._db = None

    def _create_submodel(self):
        """Create the data base scheme of each sub model class

        Each request is executed without commit, a single commit is
        done once every sub model class has been processed.

        Returns:
            Always true
        """
        def fn(sql):
            return self.execute(
                sql,
                fetch_one=False,
                commit=False
            )

        for cls in self._sub_classes:
            cls._db_create(fn)

        self.commit()
        return True

    def _create(self):
        """Create the data base scheme (to be overridden)"""
        raise NotImplementedMethodeError(self, self._create)

    def _update_submodel(self, version):
        """Update the data base scheme of each sub model class

        Args:
            version: Current database version, forwarded to each
                sub model ``_db_update``

        Returns:
            The ``&`` combination of the values returned by each sub
            model ``_db_update`` (true if there is no sub model class)
        """
        def fn(sql):
            return self.execute(
                sql,
                fetch_one=False,
                commit=False
            )

        ok = True
        for cls in self._sub_classes:
            ok &= cls._db_update(fn, version)

        self.commit()
        return ok

    def _update(self):
        """Update the data base scheme (to be overridden)"""
        raise NotImplementedMethodeError(self, self._update)

    def _save_submodel(self, objs, data=None):
        """Save each object of ``objs`` to the data base

        Args:
            objs: Iterable of objects providing ``_db_save(execute)``
            data: Optional callable without argument, called after
                each executed SQL request (e.g. to report progress)

        Returns:
            The ``&`` combination of the values returned by each
            ``_db_save`` (true if ``objs`` is empty)
        """
        progress = data if data is not None else lambda: None

        def fn(sql):
            res = self.execute(
                sql,
                fetch_one=False,
                commit=False
            )
            progress()
            return res

        ok = True
        for obj in objs:
            ok &= obj._db_save(fn)

        # Single commit for the whole batch of requests
        self.commit()
        return ok

    def _save(self, progress=None):
        """Save the model to the data base (to be overridden)"""
        raise NotImplementedMethodeError(self, self._save)

    def _count(self):
        """Count the model save requests (to be overridden)"""
        raise NotImplementedMethodeError(self, self._count)

    def _save_count(self, objs, data=None):
        """Count the SQL requests issued when saving ``objs``

        Each object ``_db_save`` is called with a callback that only
        classifies the request text and executes nothing: the data
        base is neither modified nor committed.

        Args:
            objs: Iterable of objects providing ``_db_save(execute)``
            data: Unused

        Returns:
            Total number of SQL requests issued by the objects
        """
        counter = {
            "insert": 0,
            "update": 0,
            "delete": 0,
            "other": 0,
        }

        def fn(sql):
            # Classify on the first matching keyword, in this order
            lowered = sql.lower()
            for kind in ("insert", "update", "delete"):
                if kind in lowered:
                    counter[kind] += 1
                    break
            else:
                counter["other"] += 1
            return []

        for obj in objs:
            obj._db_save(fn)

        logger.debug(counter)

        return sum(counter.values())

    @classmethod
    def _load(cls, filename=None):
        """Load a model from a data base file (to be overridden)"""
        raise NotImplementedMethodeError(cls, cls._load)

# Sub model class


class SQLSubModel(object):
    """Abstract base class of model parts saved through SQL requests.

    The ``_db_*`` methods raise ``NotImplementedMethodeError`` here and
    are meant to be overridden; the ``*_submodel`` helpers forward a
    call to every entry of ``_sub_classes`` (or to given objects).
    """

    # Nested sub model classes handled by the ``*_submodel`` helpers
    _sub_classes = []

    def _db_format(self, value):
        """Format a value before inserting it in an SQL request

        Args:
            value: Value to format

        Returns:
            The string with each "'" replaced by "&#39;" if value is
            exactly a str, 'TRUE' or 'FALSE' if value is exactly a
            bool, otherwise the value unchanged
        """
        kind = type(value)

        if kind is bool:
            return 'TRUE' if value else 'FALSE'

        if kind is str:
            # Replace "'" by "&#39;" to prevent SQL injection
            return value.replace("'", "&#39;")

        return value

    @classmethod
    def _create_submodel(cls, execute):
        """Call ``_db_create`` on every class of ``_sub_classes``

        Args:
            execute: Function to exec SQL request
        """
        for sub_class in cls._sub_classes:
            sub_class._db_create(execute)

    @classmethod
    def _db_create(cls, execute):
        """Create the data base scheme of this class

        Args:
            execute: Function to exec SQL request

        Returns:
            True, or false if an issue appears
        """
        raise NotImplementedMethodeError(cls, cls._db_create)

    @classmethod
    def _update_submodel(cls, execute, version):
        """Call ``_db_update`` on every class of ``_sub_classes``

        Args:
            execute: Function to exec SQL request
            version: Current database version
        """
        for sub_class in cls._sub_classes:
            sub_class._db_update(execute, version)

    @classmethod
    def _db_update(cls, execute, version):
        """Update the data base scheme of this class

        Args:
            execute: Function to exec SQL request
            version: Current database version

        Returns:
            True, or false if an issue appears
        """
        raise NotImplementedMethodeError(cls, cls._db_update)

    @classmethod
    def _db_load(cls, execute, data=None):
        """Load instance(s) of this class from the SQL data base

        Args:
            execute: Function to exec SQL request
            data: Optional data for the class constructor

        Returns:
            New instance of the class
        """
        raise NotImplementedMethodeError(cls, cls._db_load)

    def _save_submodel(self, execute, objs, data=None):
        """Call ``_db_save`` on every object of ``objs``

        Args:
            execute: Function to exec SQL request
            objs: Iterable of objects to save
            data: Optional additional information, forwarded to each
                ``_db_save`` as the ``data`` keyword argument
        """
        for item in objs:
            item._db_save(execute, data=data)

    def _db_save(self, execute, data=None):
        """Save the data of this object to the data base

        Args:
            execute: Function to exec SQL request
            data: Optional additional information for save

        Returns:
            True, or false if an issue appears during save
        """
        raise NotImplementedMethodeError(self, self._db_save)