# SQL.py -- Pamhyr
# Copyright (C) 2023-2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import os
import logging
import sqlite3

from pathlib import Path

from tools import timer

logger = logging.getLogger()


class SQL(object):
    """Minimal base wrapper around a SQLite database file.

    Subclasses are expected to override the ``_create``, ``_update``,
    ``_save`` and ``_load`` hooks; the default implementations here only
    log a "TODO" warning.

    String values are written with single quotes escaped as ``&#39;``
    (see ``_db_format``) and unescaped again when rows are fetched
    (see ``_fetch_string``).
    """

    def _init_db_file(self, db):
        """Open (or create) the SQLite database stored at path ``db``.

        Missing parent directories are created first. A database file that
        did not exist before is initialised with ``_create`` then ``_save``;
        an existing one goes through ``_update`` then ``_load``.

        :param db: path of the SQLite database file
        """
        exists = Path(db).exists()

        # A bare file name (e.g. "data.db") has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError: only create the parent
        # directory when the path actually has one.
        directory = os.path.dirname(db)
        if directory:
            os.makedirs(directory, exist_ok=True)

        self._db = sqlite3.connect(db)
        self._cur = self._db.cursor()

        if not exists:
            self._create()      # Create db
            self._save()        # Save
        else:
            self._update()      # Update db scheme if necessary
            self._load()        # Load data

    def __init__(self, filename=None):
        """Create the wrapper, opening ``filename`` when one is given.

        :param filename: optional path of the SQLite database file; when
            None, no connection is opened
        """
        self._db = None
        self._cur = None

        if filename is not None:
            self._init_db_file(filename)

    def commit(self):
        """Commit the current transaction on the open connection."""
        logger.debug("SQL - commit")
        self._db.commit()

    def _close(self):
        """Commit pending changes, then close the connection."""
        self.commit()
        self._db.close()

    def _fetch_string(self, s):
        """Undo the quote escaping applied by ``_db_format``."""
        return s.replace("&#39;", "'")

    def _fetch_tuple(self, tup):
        """Return ``tup`` as a list with every ``str`` item unescaped."""
        res = []
        for v in tup:
            if type(v) is str:
                v = self._fetch_string(v)
            res.append(v)

        return res

    def _fetch_list(self, lst):
        """Unescape a list of fetched values.

        ``str`` items are unescaped and ``tuple`` items (rows) are turned
        into unescaped lists; other items are kept as-is.
        """
        res = []
        for v in lst:
            if type(v) is str:
                v = self._fetch_string(v)
            elif type(v) is tuple:
                v = self._fetch_tuple(v)
            res.append(v)

        return res

    def _fetch(self, res, one):
        """Fetch the result of an executed cursor and unescape strings.

        :param res: cursor returned by ``sqlite3.Cursor.execute``
        :param one: fetch a single row when True, every row otherwise
        :return: a list for a row, a list of lists for several rows, or
            None when ``fetchone`` finds no row
        """
        if one:
            value = res.fetchone()
        else:
            value = res.fetchall()
        res = value

        if type(value) is list:
            res = self._fetch_list(value)
        elif type(value) is tuple:
            res = self._fetch_tuple(value)

        return res

    def _db_format(self, value):
        """Escape single quotes in string values before embedding them.

        Replacing ``'`` by ``&#39;`` keeps a value from terminating a
        single-quoted SQL literal, which guards against SQL injection when
        commands are built by string formatting. Non-string values are
        returned unchanged.
        """
        if type(value) is str:
            value = value.replace("'", "&#39;")
        return value

    @timer
    def execute(self, cmd, fetch_one=True, commit=False, params=None):
        """Execute one SQL statement and return its (unescaped) result.

        :param cmd: SQL statement to run
        :param fetch_one: fetch a single row when True, all rows otherwise
        :param commit: commit the transaction right after execution
        :param params: optional sequence or mapping of values bound to the
            statement's placeholders (sqlite3 parameter binding); when None
            the statement is executed as-is
        :return: the fetched result (see ``_fetch``), or None if the
            statement failed; failures are logged, not raised
        """
        logger.debug(f"SQL - {cmd}")

        value = None
        try:
            if params is None:
                res = self._cur.execute(cmd)
            else:
                res = self._cur.execute(cmd, params)

            if commit:
                self._db.commit()

            value = self._fetch(res, fetch_one)
        except Exception as e:
            # Best effort: log the error with its traceback and let the
            # caller receive None. Returning outside of a `finally` clause
            # means BaseExceptions (KeyboardInterrupt, ...) still propagate.
            logger.exception(e)

        return value

    def _create(self):
        """Hook: create the database schema (to override in subclasses)."""
        logger.warning("TODO: Create")

    def _update(self):
        """Hook: upgrade an existing schema (to override in subclasses)."""
        logger.warning("TODO: Update")

    def _save(self):
        """Hook: write data to the database (to override in subclasses)."""
        logger.warning("TODO: Save")

    def _load(self):
        """Hook: read data from the database (to override in subclasses)."""
        logger.warning("TODO: LOAD")