# RubarBE.py -- Pamhyr
# Copyright (C) 2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import os
import logging
import numpy as np

from tools import timer, trace, old_pamhyr_date_to_timestamp

from Solver.CommandLine import CommandLineSolver

from Model.Results.Results import Results
from Model.Results.River.River import River, Reach, Profile

logger = logging.getLogger()


class RubarBE(CommandLineSolver):
    _type = "rubarbe"

    def __init__(self, name):
        """Create a RubarBE solver called `name`.

        Sets the solver type tag and the command-line templates. Only the
        solver command is non-empty; its ``@path``, ``@input`` and
        ``@output`` placeholders are presumably substituted by
        ``CommandLineSolver`` (not visible here — TODO confirm).
        """
        super().__init__(name)

        self._type = "rubarbe"

        # Input pre-processing, solver run and output post-processing
        # command templates (in that order)
        self._cmd_input, self._cmd_solver, self._cmd_output = (
            "",
            "@path @input -o @output",
            "",
        )

    @classmethod
    def default_parameters(cls):
        """Return the default parameters as a list of ``(name, value)``.

        The parent class defaults are extended with the general RubarBE
        parameters (those not containing ``rubarbe_sediment_`` are written
        to the DONNEE file) and then with the ``rubarbe_sediment_*``
        parameters (written to the TS file). Every value is a string.
        """
        params = super().default_parameters()

        rubarbe_params = [
            ("rubarbe_cfl", "0.50000E+00"),
            ("rubarbe_condam", "1"),
            ("rubarbe_condav", "3"),
            ("rubarbe_regime", "0"),
            ("rubarbe_iodev", "n"),
            ("rubarbe_iodebord", ""),
            ("rubarbe_iostockage", ""),
            ("rubarbe_iopdt", "y"),
            ("rubarbe_iovis", "n"),
            ("rubarbe_rep", "n"),
            ("rubarbe_tinit", "000:00:00:00"),
            ("rubarbe_tmax", "999:99:99:00"),
            ("rubarbe_tiopdt", "000:00:00:00"),
            ("rubarbe_dt", "3000.0"),
            ("rubarbe_ts", "999:99:99:00"),
            ("rubarbe_dtsauv", "999:99:99:00"),
            ("rubarbe_psave", "999:99:99:00"),
            ("rubarbe_fdeb1", "1"),
            ("rubarbe_fdeb2", "10"),
            ("rubarbe_fdeb3", "100"),
            ("rubarbe_tf_1", "y"),
            ("rubarbe_tf_2", "y"),
            ("rubarbe_tf_3", "y"),
            ("rubarbe_tf_4", "y"),
            ("rubarbe_tf_5", "y"),
            ("rubarbe_tf_6", "n"),
            ("rubarbe_trased", "y"),
            ("rubarbe_optfpc", "0"),
            ("rubarbe_ros", "2650.0"),
            ("rubarbe_dm", "0.1"),
            ("rubarbe_segma", "1.0"),
        ]

        sediment_params = [
            ("rubarbe_sediment_ros", "2650.0"),
            ("rubarbe_sediment_por", "0.4"),
            ("rubarbe_sediment_dcharg", "0.0"),
            ("rubarbe_sediment_halfa", "1.0"),
            ("rubarbe_sediment_mult_1", "1.0"),
            ("rubarbe_sediment_mult_2", ""),
            ("rubarbe_sediment_mult_3", ""),
            ("rubarbe_sediment_mult_4", ""),
            ("rubarbe_sediment_mult_5", ""),
            ("rubarbe_sediment_visc", "0.047"),
            ("rubarbe_sediment_opts", "6"),
            ("rubarbe_sediment_odchar", "0"),
            ("rubarbe_sediment_unisol", "1"),
            ("rubarbe_sediment_typdef", "3"),
            ("rubarbe_sediment_depot", "2"),
            ("rubarbe_sediment_choixc", "2"),
            ("rubarbe_sediment_option", "2"),
            ("rubarbe_sediment_capsol", "1"),
            ("rubarbe_sediment_bmiu", "0.85"),
            ("rubarbe_sediment_demix", "0"),
            ("rubarbe_sediment_defond", "1"),
            ("rubarbe_sediment_varcons", "1"),
            ("rubarbe_sediment_dchard", "0.0"),
            ("rubarbe_sediment_dchars", "0.0"),
        ]

        # In-place extension, as the parent list is extended in place
        params += rubarbe_params + sediment_params

        return params

    @classmethod
    def checkers(cls):
        """Return the study checkers of this solver (none are defined)."""
        return []

    ##########
    # Export #
    ##########

    def cmd_args(self, study):
        """Return the command-line arguments for `study`.

        No RubarBE-specific argument is added: the parent class result is
        returned unchanged.
        """
        return super().cmd_args(study)

    def input_param(self):
        """Return the solver input argument, ``<study name>.REP``.

        Uses the study stored on ``self._study`` by ``export``.
        """
        study_name = self._study.name
        return f"{study_name}.REP"

    def output_param(self):
        """Return the solver output argument, ``<study name>.BIN``.

        Uses the study stored on ``self._study`` by ``export``.
        """
        study_name = self._study.name
        return f"{study_name}.BIN"

    def log_file(self):
        """Return the solver log file name, ``<study name>.TRA``.

        Uses the study stored on ``self._study`` by ``export``.
        """
        study_name = self._study.name
        return f"{study_name}.TRA"

    def export(self, study, repertory, qlog=None):
        """Write every RubarBE input file of `study` into `repertory`.

        The study is kept on ``self._study`` (read by ``input_param``,
        ``output_param`` and ``log_file``). Spaces in the study name are
        replaced by underscores to build the file name suffix.

        :param study: Study to export.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        """
        self._study = study
        suffix = study.name.replace(" ", "_")

        # Files are written in this order
        exporters = (
            self._export_donnee,
            self._export_ts,
            self._export_geomac_i,
            self._export_mail,
            self._export_condin,
            self._export_stricklers,
        )
        for exporter in exporters:
            exporter(study, repertory, qlog, name=suffix)

    def _export_donnee(self, study, repertory, qlog, name="0"):
        """Write the ``donnee.<name>`` file of general RubarBE parameters.

        At most 29 parameter slots are consumed, in ``get_params`` order,
        from the ``rubarbe_*`` parameters that are not sediment ones. Each
        non-empty parameter writes one line: its name left-justified on 50
        columns followed by its formatted value. Empty parameters write
        nothing but still consume a slot. The ``rubarbe_iodev`` line also
        carries, unchanged, the values of the next two parameters
        (``rubarbe_iodebord`` and ``rubarbe_iostockage`` in the default
        ordering).

        Value formatting:

        - values with three ``:`` (e.g. ``000:00:00:00``) are converted by
          ``old_pamhyr_date_to_timestamp`` and written as ``12.5E``;
        - values with one decimal point are written as ``12.5E``;
        - ``y``/``n`` become ``O``/``N``;
        - anything else is written unchanged.

        If fewer parameters than slots are available, a warning is logged
        and the file stops there. Previously, a bare ``StopIteration``
        escaped in that case.

        :param study: Study whose river parameters are exported.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        :param name: File name suffix.
        """
        if qlog is not None:
            qlog.put("Export DONNEE file")

        # With the default parameters, 29 slots (plus the 2 values read by
        # the iodev line) consume exactly the 31 non-sediment 'rubarbe_*'
        # parameters.
        n_slots = 29

        with open(
                os.path.join(
                    repertory, f"donnee.{name}"
                ), "w+"
        ) as f:
            # Keep only this solver's own non-sediment parameters: any
            # other parameter (e.g. generic ones a parent class might
            # define) would shift every following slot.
            params = filter(
                lambda p: (
                    p.name.startswith("rubarbe_")
                    and "rubarbe_sediment_" not in p.name
                ),
                study.river.get_params(self._type).parameters
            )
            it = iter(params)

            for _ in range(n_slots):
                param = next(it, None)
                if param is None:
                    logger.warning(
                        "RubarBE: not enough parameters to fill "
                        "the DONNEE file"
                    )
                    break

                value = param.value
                if value == "":
                    # No line written, but the slot is consumed
                    continue

                # Value format
                if value.count(':') == 3:
                    value = old_pamhyr_date_to_timestamp(value)
                    value = f"{value:>12.5e}".upper()
                elif value.count('.') == 1:
                    value = f"{float(value):>12.5e}".upper()
                elif value in ("y", "n"):
                    value = "O" if value == "y" else "N"

                line = f"{param.name:<50}{value}"

                # The iodev line also carries the raw values of the two
                # following parameters ('rubarbe_iodebord' and
                # 'rubarbe_iostockage' in the default ordering)
                if param.name == "rubarbe_iodev":
                    extras = [next(it, None) for _ in range(2)]
                    line += "".join(
                        f"{p.value}" for p in extras if p is not None
                    )

                f.write(f"{line}\n")

    def _export_ts(self, study, repertory, qlog, name="0"):
        """Write the ``ts.<name>`` file of sediment parameters.

        At most 20 parameter slots are consumed, in ``get_params`` order,
        from the ``rubarbe_sediment_*`` parameters. Each non-empty
        parameter writes one line: its name left-justified on 50 columns
        followed by its value right-justified on 10 columns. The
        ``rubarbe_sediment_mult_1`` line also carries the next four
        parameters (``mult_2`` .. ``mult_5`` in the default ordering).
        Those without a decimal point, e.g. empty, are omitted. Empty
        parameters write nothing but still consume a slot.

        If fewer parameters than slots are available, a warning is logged
        and the file stops there. Previously, a bare ``StopIteration``
        escaped in that case.

        :param study: Study whose river parameters are exported.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        :param name: File name suffix.
        """
        if qlog is not None:
            qlog.put("Export TS file")

        # With the default parameters, 20 slots (plus the 4 multipliers
        # read by the mult_1 line) consume all 24 sediment parameters.
        n_slots = 20

        with open(
                os.path.join(
                    repertory, f"ts.{name}"
                ), "w+"
        ) as f:
            def float_format(string):
                """Right-justify the decimal `string` on 10 columns.

                Returns "" when `string` has no decimal point (e.g. an
                unset multiplier). The value is written with its shortest
                repr. If that exceeds 10 columns, it is reduced to 4
                significant digits. The former ``:>10.0f`` rounded every
                value to an integer (porosity 0.4 -> 0).
                """
                if "." not in string:
                    return ""

                number = float(string)
                text = f"{number}"
                if len(text) > 10:
                    text = f"{number:.4g}"

                return f"{text:>10}"

            params = filter(
                lambda p: "rubarbe_sediment_" in p.name,
                study.river.get_params(self._type).parameters
            )
            it = iter(params)

            for _ in range(n_slots):
                param = next(it, None)
                if param is None:
                    logger.warning(
                        "RubarBE: not enough parameters to fill "
                        "the TS file"
                    )
                    break

                value = param.value
                if value == "":
                    # No line written, but the slot is consumed
                    continue

                # Value format
                if value.count('.') == 1:
                    value = float_format(value)
                else:
                    value = f"{value:>10}"

                line = f"{param.name:<50}{value}"

                # The mult_1 line also carries the following four
                # multipliers (mult_2 .. mult_5 in the default ordering)
                if param.name == "rubarbe_sediment_mult_1":
                    extras = [next(it, None) for _ in range(4)]
                    line += "".join(
                        float_format(p.value)
                        for p in extras if p is not None
                    )

                f.write(f"{line}\n")

    def _export_geomac_i(self, study, repertory, qlog, name="0"):
        """Write the ``geomac-i.<name>`` geometry file.

        For each enabled reach:

        - a header ``<profile count> <time>``, where time is always 0.0
          here;
        - for each profile, a line ``<index> <rk> <point count>``
          (1-based index);
        - for each point of that profile, a line
          ``<label> <y> <z> <dcs> <scs> <tmcs>``, where the last three are
          the same constants for every point.

        The label is the point name, lower-cased and reduced to one
        character. For names starting with ``r`` and at least two
        characters long, the upper-cased second character is used
        (e.g. ``rg`` -> ``G``); otherwise the first character. Unnamed
        points get a blank label, so every point line keeps the same
        column layout.

        :param study: Study whose river geometry is exported.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        :param name: File name suffix.
        """
        if qlog is not None:
            qlog.put("Export GEOMAC-i file")

        def point_label(point):
            # One-character label derived from the point name ("" if none)
            label = point.name.lower()
            if label == "":
                return ""
            if label[0] == "r" and len(label) > 1:
                return label[1].upper()
            return label[0]

        # Constant per-point values
        # NOTE(review): physical meaning of dcs/scs/tmcs is not visible
        # here — confirm against the RubarBE format documentation
        dcs = 0.001
        scs = 1.0
        tmcs = 0.0

        with open(
                os.path.join(
                    repertory, f"geomac-i.{name}"
                ), "w+"
        ) as f:
            for edge in study.river.enable_edges():
                reach = edge.reach
                n_profiles = len(reach)
                time = 0.0

                f.write(f"{n_profiles:>5} {time:>11.3f}\n")

                for ind, profile in enumerate(reach.profiles, start=1):
                    rk = profile.rk
                    n_points = len(profile)

                    f.write(f"{ind:>4} {rk:>11.3f} {n_points:>4}\n")

                    for point in profile.points:
                        # Pad to one column so unnamed points stay aligned
                        label = point_label(point)
                        y = point.y
                        z = point.z

                        f.write(
                            f"{label:<1} {y:>11.5f}" +
                            f"{z:>13.5f}{dcs:>15.10f}" +
                            f"{scs:>15.10f}{tmcs:>15.5f}" +
                            "\n"
                        )

    def _export_mail(self, study, repertory, qlog, name="0"):
        """Write the ``mail.<name>`` mesh file.

        For each enabled reach, writes the mail count (``len(reach) + 1``)
        right-justified on 13 columns. It then writes the values of
        ``reach.inter_profiles_rk()`` followed by those of
        ``reach.get_rk()``, each list starting on a new line, three
        ``15.3f`` values per line.

        :param study: Study whose river is exported.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        :param name: File name suffix.
        """
        if qlog is not None:
            qlog.put("Export MAIL file")

        per_line = 3
        path = os.path.join(repertory, f"mail.{name}")

        with open(path, "w+") as out:
            for edge in study.river.enable_edges():
                reach = edge.reach
                out.write(f"{len(reach) + 1:>13}\n")

                lists = (reach.inter_profiles_rk(), reach.get_rk())
                for values in lists:
                    count = 0
                    for count, value in enumerate(values, start=1):
                        out.write(f"{value:15.3f}")
                        if count % per_line == 0:
                            out.write("\n")

                    # Terminate a partially filled last line
                    if count % per_line != 0:
                        out.write("\n")

    def _export_stricklers(self, study, repertory, qlog, name="0"):
        """Write both Strickler files through ``_export_frot``.

        ``frot.<name>`` uses ``version=""`` and ``frot2.<name>`` uses
        ``version="2"``, written in that order.
        """
        for version in ("", "2"):
            self._export_frot(
                study, repertory, qlog, name=name, version=version
            )

    def _export_frot(self, study, repertory, qlog, name="0", version=""):
        """Write a ``frot<version>.<name>`` Strickler coefficient file.

        For each enabled reach, writes a header with ``len(reach) + 1`` on
        6 columns. It then writes one line ``<index> <coefficient>`` per
        value of ``reach.inter_profiles_rk()`` (1-based index). The
        coefficient is the begin Strickler of the first friction of the
        edge that contains the rk: its ``medium`` value when
        ``version == "2"``, its ``minor`` value otherwise.

        :param study: Study whose river frictions are exported.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        :param name: File name suffix.
        :param version: File variant, ``""`` or ``"2"``.
        :raises ValueError: if no friction of an edge contains one of its
            inter-profile rks. A bare ``StopIteration`` used to escape
            here.
        """
        if qlog is not None:
            qlog.put(f"Export FROT{version} file")

        def strickler_coef(friction):
            # Coefficient selected by the file variant
            strickler = friction.begin_strickler
            return strickler.medium if version == "2" else strickler.minor

        with open(
                os.path.join(
                    repertory, f"frot{version}.{name}"
                ), "w+"
        ) as f:
            for edge in study.river.enable_edges():
                reach = edge.reach
                lm = len(reach) + 1
                f.write(f"{lm:>6}\n")

                frictions = edge.frictions.lst

                for ind, rk in enumerate(reach.inter_profiles_rk(), start=1):
                    # First friction zone containing this rk
                    friction = next(
                        (fr for fr in frictions if rk in fr), None
                    )
                    if friction is None:
                        raise ValueError(
                            f"RubarBE: no friction defined at rk {rk}"
                        )

                    coef = strickler_coef(friction)
                    f.write(f"{ind:>6} {coef:>12.5f}\n")

    def _export_condin(self, study, repertory, qlog, name="0"):
        """Write the ``condin.<name>`` initial condition file.

        For each enabled reach with n profiles, writes a ``0.0`` line,
        then lines ``<mail index> <height> <speed>``:

        - mail 1: values at the first profile;
        - mail i (2..n): mean of the values at profiles i-1 and i;
        - mail n+1: values at the last profile.

        Height and speed come from ``_export_condin_profile_height_speed``.
        A profile without initial condition produces no line for its mail,
        and the previous known values are used for the next mean.

        If a reach has no profile, or its first/last profile has no
        initial condition, an error is logged and the export stops.

        Previously, the first line had no newline, and the last mail
        repeated the first profile's values without a newline.

        :param study: Study whose initial conditions are exported.
        :param repertory: Output directory.
        :param qlog: Optional queue receiving progress messages.
        :param name: File name suffix.
        """
        if qlog is not None:
            qlog.put("Export CONDIN file")

        with open(
                os.path.join(
                    repertory, f"condin.{name}"
                ), "w+"
        ) as f:
            for edge in study.river.enable_edges():
                reach = edge.reach

                f.write("0.0\n")

                ics = study.river.initial_conditions.get(edge)
                data = self._export_condin_init_data(ics)

                profiles = reach.profiles
                if len(profiles) == 0:
                    logger.error(
                        "Reach has no profile: cannot export its "
                        "initial condition"
                    )
                    return

                first = profiles[0]
                last = profiles[-1]

                if first.rk not in data or last.rk not in data:
                    logger.error(
                        "Study initial condition is not fully defined"
                    )
                    return

                first_h, first_s = self._export_condin_profile_height_speed(
                    first, data
                )
                last_h, last_s = self._export_condin_profile_height_speed(
                    last, data
                )

                # First mail: values at the first profile
                f.write(f"{1:>5} {first_h} {first_s}\n")

                # Inner mails: mean of the two surrounding profiles
                ind = 2
                it = iter(profiles)
                next(it)  # The first profile is already handled

                prev_h, prev_s = first_h, first_s
                for profile in it:
                    if profile.rk not in data:
                        ind += 1
                        continue

                    cur_h, cur_s = self._export_condin_profile_height_speed(
                        profile, data
                    )

                    # Mean of height and speed
                    h = (prev_h + cur_h) / 2
                    s = (prev_s + cur_s) / 2

                    f.write(f"{ind:>5} {h} {s}\n")

                    prev_h, prev_s = cur_h, cur_s
                    ind += 1

                # Last mail: values at the last profile
                f.write(f"{ind:>5} {last_h} {last_s}\n")

    def _export_condin_init_data(self, ics):
        """Index the initial condition rows of `ics` by rk.

        :param ics: Object whose ``data`` iterable yields mappings with
            ``'rk'``, ``'elevation'`` and ``'discharge'`` keys.
        :return: Dict ``{rk: (elevation, discharge)}``. When several rows
            share an rk, the last one wins.
        """
        return {
            row['rk']: (row['elevation'], row['discharge'])
            for row in ics.data
        }

    def _export_condin_profile_height_speed(self, profile, data):
        """Return ``(height, speed)`` at `profile` from initial conditions.

        :param profile: Profile providing ``rk``, ``z_min()`` and
            ``speed(discharge, elevation)``.
        :param data: Mapping ``rk -> (elevation, discharge)``.
        :return: ``(elevation - profile.z_min(),
            profile.speed(discharge, elevation))``.
        """
        entry = data[profile.rk]
        elevation = entry[0]
        discharge = entry[1]

        return (
            elevation - profile.z_min(),
            profile.speed(discharge, elevation),
        )