# InitialConditions.py -- Pamhyr
# Copyright (C) 2023-2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import logging

from copy import copy, deepcopy
from tools import trace, timer
from functools import reduce

from Model.Tools.PamhyrDB import SQLSubModel

# Module-wide logger.  getLogger() is called without a name, so this is
# the root logger rather than a logger dedicated to this module.
logger = logging.getLogger()


class Data(SQLSubModel):
    """One initial-condition record attached to a reach.

    A record holds a name, a comment, a position ``rk``, a discharge,
    a speed, a water elevation and a water height.  Elevation and
    height are kept consistent with each other through the ``z_min()``
    of the first profile the reach geometry returns for ``rk``; when
    no profile is returned (or no reach is set) 0.0 is used instead.

    Values are read and written through item access (``d["rk"]``,
    ``d["height"] = 1.5``).  Every write notifies the status object,
    when one was given.
    """

    # Item key accepted by ``__getitem__`` -> private attribute name.
    _ITEM_ATTRIBUTES = {
        "name": "_name",
        "comment": "_comment",
        "rk": "_rk",
        "speed": "_speed",
        "discharge": "_discharge",
        "elevation": "_elevation",
        "height": "_height",
    }

    def __init__(self, name: str = "",
                 comment: str = "", reach=None,
                 rk: float = 0.0, discharge: float = 0.0,
                 height: float = 0.0,
                 status=None):
        """Create a record and derive its elevation from ``height``.

        Args:
            name: Record name.
            comment: Free text comment.
            reach: Object owning the geometry, read as ``reach.reach``,
                and the database id, read as ``reach.id``.
            rk: Position used to look up the profile on the reach.
            discharge: Discharge value.
            height: Water height above the profile lowest point.
            status: Object whose ``modified()`` method is called on
                every item assignment.
        """
        super(Data, self).__init__()

        self._status = status
        self._reach = reach

        self._name = name
        self._comment = comment

        self._rk = rk
        self._discharge = discharge
        self._speed = 0.0
        self._elevation = 0.0
        self._height = height

        # Derived values are only refreshed for non default inputs.
        if self._rk != 0.0:
            self._update_from_rk()
        if self._height != 0.0:
            self._update_from_height()
        if self._discharge != 0.0:
            self._update_from_discharge()

    @classmethod
    def _db_create(cls, execute):
        """Create the ``initial_conditions`` table, then the submodels."""
        execute("""
          CREATE TABLE initial_conditions(
            id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
            ind INTEGER NOT NULL,
            name TEXT NOT NULL,
            comment TEXT NOT NULL,
            reach INTEGER,
            rk REAL NOT NULL,
            discharge REAL NOT NULL,
            height REAL NOT NULL,
            FOREIGN KEY(reach) REFERENCES river_reach(id)
          )
        """)

        return cls._create_submodel(execute)

    @classmethod
    def _db_update(cls, execute, version):
        """Migrate the table from the schema of ``version``.

        Args:
            execute: Callable running one SQL statement.
            version: Schema version as a ``"major.minor.release"``
                string.
        """
        major, minor, release = version.strip().split(".")
        if major == minor == "0":
            # Before 0.0.11 the ``rk`` column was named ``kp``.
            if int(release) < 11:
                execute(
                    "ALTER TABLE initial_conditions RENAME COLUMN kp TO rk"
                )

        return cls._update_submodel(execute, version)

    @classmethod
    def _db_load(cls, execute, data=None):
        """Load every record of ``data["reach"]`` ordered by ``ind``.

        Args:
            execute: Callable running one SQL statement and returning
                the selected rows.
            data: Mapping providing the ``"reach"`` and ``"status"``
                entries handed to each new record.

        Returns:
            The list of loaded records, sorted by their stored index.
        """
        reach = data["reach"]
        rows = execute(
            "SELECT ind, name, comment, rk, discharge, height " +
            "FROM initial_conditions " +
            f"WHERE reach = {reach.id}"
        )

        # Materialise the result once so that a one-shot iterator is
        # not exhausted before the records are built, then order by
        # ``ind`` instead of writing at ``new[ind]``: gaps in the
        # stored indexes no longer raise IndexError.
        rows = [] if rows is None else list(rows)
        rows.sort(key=lambda row: row[0])

        return [
            cls(
                reach=reach,
                status=data["status"],
                name=name,
                comment=comment,
                rk=rk,
                discharge=discharge,
                height=height,
            )
            for _, name, comment, rk, discharge, height in rows
        ]

    def _db_save(self, execute, data=None):
        """Insert this record, stored at position ``data["ind"]``.

        Returns:
            Always True.
        """
        ind = data["ind"]

        execute(
            "INSERT INTO " +
            "initial_conditions(ind, name, comment, rk, " +
            "discharge, height, reach) " +
            "VALUES (" +
            f"{ind}, '{self._db_format(self.name)}', " +
            f"'{self._db_format(self._comment)}', " +
            f"{self._rk}, {self._discharge}, {self._height}, " +
            f"{self._reach.id}" +
            ")"
        )

        return True

    def copy(self):
        """Return a new record with the same stored values.

        The reach and status objects are shared with the original, and
        the derived values are recomputed by the constructor.
        """
        return Data(
            name=self.name,
            comment=self._comment,
            rk=self._rk,
            discharge=self._discharge,
            height=self._height,
            reach=self._reach,
            status=self._status,
        )

    @property
    def name(self):
        """Record name."""
        return self._name

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        Known keys are "name", "comment", "rk", "speed", "discharge",
        "elevation" and "height"; any other key yields None.
        """
        attribute = self._ITEM_ATTRIBUTES.get(key)
        if attribute is None:
            return None

        return getattr(self, attribute)

    def _update_get_min(self):
        """Return ``z_min()`` of the first profile found at ``rk``.

        Returns 0.0 when the reach returns no profile for ``rk`` or
        when the record has no reach.
        """
        if self._reach is None:
            return 0.0

        profiles = self._reach.reach.get_profiles_from_rk(self._rk)
        if len(profiles) > 0:
            return profiles[0].z_min()

        return 0.0

    def _update_from_rk(self):
        """Recompute the elevation after ``rk`` changed."""
        self._elevation = self._update_get_min() + self._height

    def _update_from_elevation(self):
        """Recompute the height after the elevation changed."""
        self._height = self._elevation - self._update_get_min()

    def _update_from_height(self):
        """Recompute the elevation after the height changed."""
        self._elevation = self._height + self._update_get_min()

    def _update_from_discharge(self):
        """Hook called after the discharge changed.

        Nothing is derived from the discharge yet, so this is a
        placeholder that leaves every value untouched.
        """
        # TODO: derive the dependent values from the discharge.

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and refresh derived values.

        "name" and "comment" are converted with ``str``, every other
        known key with ``float``.  Unknown keys are ignored.  The
        status object, if any, is notified in all cases.
        """
        if key == "name":
            self._name = str(value)
        elif key == "comment":
            self._comment = str(value)
        elif key == "rk":
            self._rk = float(value)
            self._update_from_rk()
        elif key == "speed":
            # Not supposed to be modified
            self._speed = float(value)
        elif key == "discharge":
            self._discharge = float(value)
            self._update_from_discharge()
        elif key == "elevation":
            self._elevation = float(value)
            self._update_from_elevation()
        elif key == "height":
            self._height = float(value)
            self._update_from_height()

        # ``status`` defaults to None in the constructor.
        if self._status is not None:
            self._status.modified()


class InitialConditions(SQLSubModel):
    """Ordered list of ``Data`` records for one reach.

    Besides list-like editing (insert, delete, sort), the class can
    regenerate the whole list from the reach profiles, either from a
    constant water height or from a constant discharge.
    """

    _sub_classes = [
        Data
    ]

    # Strickler coefficient used when no friction covers a profile.
    _DEFAULT_STRICKLER = 25.0

    # Starting value of the running maximum elevation, chosen lower
    # than any elevation expected from a profile.
    _LOWEST_ELEVATION = -99999.99

    def __init__(self, reach=None, status=None):
        """Create an empty list of initial conditions.

        Args:
            reach: Object giving access to the geometry
                (``reach.reach``) and the frictions
                (``reach.frictions``).
            status: Object whose ``modified()`` method is called when
                the list changes.
        """
        super(InitialConditions, self).__init__()

        self._status = status

        self._reach = reach
        self._data = []

    @classmethod
    def _db_create(cls, execute):
        """Create the tables of the submodels."""
        return cls._create_submodel(execute)

    @classmethod
    def _db_update(cls, execute, version):
        """Migrate the tables of the submodels."""
        return cls._update_submodel(execute, version)

    @classmethod
    def _db_load(cls, execute, data=None):
        """Load the records of ``data["reach"]``.

        Returns:
            The new instance, or None when ``Data._db_load`` returned
            None.
        """
        new = cls(
            reach=data["reach"],
            status=data["status"]
        )

        new._data = Data._db_load(
            execute,
            data=data
        )

        if new._data is None:
            return None

        return new

    def _db_save(self, execute, data=None):
        """Save every record with its position in the list.

        The position is handed to each record through ``data["ind"]``.

        Returns:
            True when every record reported a successful save.
        """
        ok = True

        for ind, d in enumerate(self._data):
            data["ind"] = ind
            ok &= d._db_save(execute, data)

        return ok

    def __len__(self):
        return len(self._data)

    def lst(self):
        """Return the internal list of records (not a copy)."""
        return self._data

    @property
    def reach(self):
        """Reach these initial conditions belong to."""
        return self._reach

    @reach.setter
    def reach(self, new):
        # Assign the parameter: ``reach`` is not a name visible from
        # inside this function and raised NameError.
        self._reach = new
        self._status.modified()

    @property
    def data(self):
        """Shallow copy of the list of records."""
        return self._data.copy()

    @data.setter
    def data(self, data):
        self._data = data

    def get(self, index):
        """Return the record at ``index``."""
        return self._data[index]

    def set(self, index, data):
        """Insert ``data`` at ``index``.

        NOTE(review): despite its name this inserts and does not
        replace the record already at ``index`` -- confirm against
        the callers before changing it.
        """
        self._data.insert(index, data)
        self._status.modified()

    def new(self, index):
        """Insert a default record at ``index``."""
        n = Data(reach=self._reach, status=self._status)
        self._data.insert(index, n)
        self._status.modified()

    def new_from_data(self, rk, discharge, elevation):
        """Build a record from the given values and return it.

        The record is not added to the list.
        """
        n = Data(reach=self._reach, status=self._status)

        n['rk'] = rk
        n['discharge'] = discharge
        n['elevation'] = elevation

        return n

    def insert(self, index, data):
        """Insert ``data`` at ``index``."""
        self._data.insert(index, data)
        self._status.modified()

    def delete(self, data):
        """Remove every record contained in ``data``."""
        self._data = [d for d in self._data if d not in data]
        self._status.modified()

    def delete_i(self, indexes):
        """Remove the records whose position is in ``indexes``."""
        selected = [
            d for i, d in enumerate(self._data)
            if i in indexes
        ]
        self.delete(selected)

    def sort(self, reverse=False, key=None):
        """Sort the records in place, see ``list.sort``."""
        self._data.sort(reverse=reverse, key=key)
        self._status.modified()

    def _data_get(self, key):
        """Return the value stored under ``key`` for every record."""
        return [d[key] for d in self._data]

    def get_rk(self):
        return self._data_get("rk")

    def get_elevation(self):
        return self._data_get("elevation")

    def get_discharge(self):
        return self._data_get("discharge")

    def _sort_by_z_and_rk(self, profiles):
        """Sort ``profiles`` in place by ``rk``.

        The order is ascending ``rk``, unless ``z()`` of the first
        profile is then greater than ``z()`` of the last one, in which
        case the order is descending ``rk``.  An empty list is left
        untouched.
        """
        if len(profiles) == 0:
            return

        profiles.sort(
            reverse=False,
            key=lambda p: p.rk
        )

        first_z = profiles[0].z()
        last_z = profiles[-1].z()

        if first_z > last_z:
            profiles.sort(
                reverse=True,
                key=lambda p: p.rk
            )

    def _previous_values(self, key, profiles):
        """Map each ``rk`` to the value currently stored under ``key``.

        When the list is empty, every profile ``rk`` is mapped to 0.0.
        """
        if len(self._data) == 0:
            return {profile.rk: 0.0 for profile in profiles}

        return {d["rk"]: d[key] for d in self._data}

    def _strickler_at(self, rk):
        """Return the Strickler coefficient applying at ``rk``.

        The last friction containing ``rk`` wins; the default value is
        returned when none contains it.
        """
        strickler = None
        for f in self._reach.frictions.frictions:
            if f.contains_rk(rk):
                strickler = f.get_friction(rk)[0]

        if strickler is None:
            strickler = self._DEFAULT_STRICKLER

        return strickler

    def generate_growing_constante_height(self, height: float,
                                          compute_discharge: bool):
        """Regenerate one record per profile from a constant height.

        The elevation of each record is the profile ``z_min()`` plus
        ``height``, never lower than the elevation of the previously
        generated record.

        Args:
            height: Water height applied to every profile.
            compute_discharge: When True the discharge is computed
                from the wet width, the Strickler coefficient, the
                height and the reach incline.  When False the
                discharge already stored for the same ``rk`` is kept
                (0.0 when there is none).

        Nothing is changed when the reach has no profile.
        """
        profiles = self._reach.reach.profiles.copy()
        if len(profiles) == 0:
            logger.warning("no profile, initial conditions not generated")
            return

        self._sort_by_z_and_rk(profiles)

        previous_elevation = self._LOWEST_ELEVATION

        data_discharge = {}
        if not compute_discharge:
            data_discharge = self._previous_values("discharge", profiles)

        incline = self._reach.reach.get_incline_median_mean()
        logger.debug(f"incline = {incline}")
        self._data = []
        for profile in profiles:
            width = profile.wet_width(profile.z_min() + height)
            strickler = self._strickler_at(profile.rk)

            if not compute_discharge:
                # Existing data may not cover every profile.
                discharge = data_discharge.get(profile.rk, 0.0)
            else:
                discharge = (
                    ((width * 0.8)
                     * strickler
                     * (height ** (5/3))
                     * (abs(incline) ** (0.5)))
                )

            elevation = max(
                profile.z_min() + height,
                previous_elevation
            )

            logger.debug(f"({profile.rk}):")
            logger.debug(f"  width  = {width}")
            logger.debug(f"  strickler = {strickler}")
            logger.debug(f"  discharge = {discharge}")

            new = Data(reach=self._reach, status=self._status)
            new["rk"] = profile.rk
            new["discharge"] = discharge
            new["elevation"] = elevation

            previous_elevation = elevation
            self._data.append(new)

        self._generate_resort_data(profiles)

    def generate_discharge(self, discharge: float, compute_height: bool):
        """Regenerate one record per profile from a constant discharge.

        The elevation of each record is the profile ``z_min()`` plus
        the height, never lower than the elevation of the previously
        generated record.

        Args:
            discharge: Discharge stored in every record.
            compute_height: When True the height is computed from the
                discharge, the approximated width, the Strickler
                coefficient and the reach incline; it is 0.0 when
                that product is not strictly positive.  When False
                the height already stored for the same ``rk`` is kept
                (0.0 when there is none).

        Nothing is changed when the reach has no profile.
        """
        profiles = self._reach.reach.profiles.copy()
        if len(profiles) == 0:
            logger.warning("no profile, initial conditions not generated")
            return

        self._sort_by_z_and_rk(profiles)

        previous_elevation = self._LOWEST_ELEVATION

        data_height = {}
        if not compute_height:
            data_height = self._previous_values("height", profiles)

        incline = self._reach.reach.get_incline_median_mean()
        logger.debug(f"incline = {incline}")
        self._data = []
        for profile in profiles:
            width = profile.width_approximation()
            strickler = self._strickler_at(profile.rk)

            if not compute_height:
                # Existing data may not cover every profile.
                height = data_height.get(profile.rk, 0.0)
            else:
                denominator = (
                    (width * 0.8) * strickler * (abs(incline) ** (0.5))
                )
                # A null width, coefficient or incline would divide
                # by zero: no height can be derived in that case.
                if denominator > 0:
                    height = (discharge / denominator) ** (0.6)
                else:
                    height = 0.0

            elevation = max(
                profile.z_min() + height,
                previous_elevation
            )

            logger.debug(f"({profile.rk}):")
            logger.debug(f"  width  = {width}")
            logger.debug(f"  strickler = {strickler}")
            logger.debug(f"  height = {height}")

            new = Data(reach=self._reach, status=self._status)
            new["rk"] = profile.rk
            new["discharge"] = discharge
            new["elevation"] = elevation

            previous_elevation = elevation
            self._data.append(new)

        self._generate_resort_data(profiles)

    def _generate_resort_data(self, profiles):
        """Sort the records by ``rk`` after a generation.

        Records are sorted in ascending ``rk`` when ``profiles`` is in
        descending ``rk`` order, and in descending ``rk`` otherwise.
        Nothing is done for an empty ``profiles`` list.
        """
        if len(profiles) == 0:
            return

        is_reverse = profiles[0].rk > profiles[-1].rk

        self._data.sort(
            reverse=not is_reverse,
            key=lambda d: d['rk']
        )