# Window.py -- Pamhyr
# Copyright (C) 2023  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import os
import sys
import time
import pathlib
import logging

from contextlib import contextmanager
from copy import deepcopy
from tools import timer, trace, logger_exception

from PyQt5 import QtWidgets
from PyQt5.QtGui import (
    QKeySequence,
)
from PyQt5.QtCore import (
    QModelIndex, Qt, QSettings, pyqtSlot,
    QItemSelectionModel, QCoreApplication, QSize
)
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QFileDialog, QCheckBox,
    QUndoStack, QShortcut, QTableView, QHeaderView,
    QAction, QSlider, QPushButton, QVBoxLayout,
    QLabel, QAbstractItemView,
)

from Model.Except import ExternFileMissingError

from View.Tools.PamhyrWindow import PamhyrWindow
from View.Tools.Plot.PamhyrToolbar import PamhyrPlotToolbar
from View.Tools.Plot.PamhyrCanvas import MplCanvas

from Meshing.Mage import (
    MeshingWithMage, MeshingWithMageMailleurTT
)

from View.Geometry.Table import GeometryReachTableModel
from View.Geometry.PlotXY import PlotXY
from View.Geometry.PlotAC import PlotAC
from View.Geometry.PlotKPZ import PlotKPZ
from View.Geometry.MeshingDialog import MeshingDialog
from View.Geometry.Translate import GeometryTranslate
from View.Geometry.Profile.Window import ProfileWindow

_translate = QCoreApplication.translate

logger = logging.getLogger()


class GeometryWindow(PamhyrWindow):
    """Window used to edit the geometry (list of profiles) of one reach.

    It shows the reach profiles in a table, three plots (``PlotXY``,
    ``PlotKPZ`` and ``PlotAC``) and a slider used to walk through the
    profiles. Edits go through the table model, which owns the undo stack.
    """

    _pamhyr_ui = "GeometryReach"
    _pamhyr_name = "Geometry"

    def __init__(self, reach=None, study=None, config=None, parent=None):
        # Without an explicit reach, edit the study's current reach
        if reach is None:
            self._reach = study.river.current_reach().reach
        else:
            self._reach = reach

        name = f"{self._pamhyr_name} - {self._reach.name}"

        super(GeometryWindow, self).__init__(
            title=name,
            study=study,
            config=config,
            trad=GeometryTranslate(),
            parent=parent
        )

        # Add reach to hash computation data
        self._hash_data.append(self._reach)

        self._table = None
        self._profile_window = []

        self.setup_table()
        self.setup_plots()
        self.setup_statusbar()
        self.setup_connections()
        self.changed_slider_value()

    @contextmanager
    def _table_signals_blocked(self):
        """Block the table model's signals for the body of a ``with``.

        The previous blocking state is restored on exit, even when the body
        raises, so that an error cannot leave the model permanently blocked
        and nested uses do not unblock signals too early.
        """
        model = self.tableView.model()
        previous = model.blockSignals(True)
        try:
            yield
        finally:
            model.blockSignals(previous)

    def setup_table(self):
        """Create the table model over the reach and attach it to the view."""
        table_headers = self._trad.get_dict("table_headers")

        table = self.find(QTableView, "tableView")
        self._table = GeometryReachTableModel(
            table_view=table,
            table_headers=table_headers,
            editable_headers=["name", "kp"],
            data=self._reach,
            undo=self._undo_stack
        )
        table.setModel(self._table)
        table.setSelectionBehavior(QAbstractItemView.SelectRows)
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.setAlternatingRowColors(True)

    def _add_plot_area(self, layout_name, object_name, width):
        """Create a canvas and its toolbar and append both to a UI layout.

        :param layout_name: object name of the ``QVBoxLayout`` in the UI.
        :param object_name: Qt object name given to the new canvas.
        :param width: canvas width passed to ``MplCanvas``.
        :return: the ``(canvas, toolbar, layout)`` triple.
        """
        canvas = MplCanvas(width=width, height=4, dpi=100)
        canvas.setObjectName(object_name)
        toolbar = PamhyrPlotToolbar(
            canvas, self,
            items=["home", "zoom", "save", "iso", "back/forward", "move"]
        )
        layout = self.find(QVBoxLayout, layout_name)
        layout.addWidget(toolbar)
        layout.addWidget(canvas)
        return canvas, toolbar, layout

    def setup_plots(self):
        """Build the three plot areas and draw them.

        The XY plot is drawn first because the AC plot is linked to it.
        """
        (self._canvas_xy,
         self._toolbar_xy,
         self._plot_layout_xy) = self._add_plot_area(
            "verticalLayout", "canvas_xy", 3
        )
        self.plot_xy()

        (self._canvas_kpc,
         self._toolbar_kpc,
         self._plot_layout_kpc) = self._add_plot_area(
            "verticalLayout_2", "canvas_kpc", 6
        )
        self.plot_kpc()

        (self._canvas_ac,
         self._toolbar_ac,
         self._plot_layout_ac) = self._add_plot_area(
            "verticalLayout_3", "canvas_ac", 9
        )
        self.plot_ac()

    def _compute_status_label(self):
        """Return the rich-text status bar label for the selected profile.

        Requires a selected row (see ``index_selected_row``).
        """
        row = self.index_selected_row()
        profile = self._reach.profile(row)

        name = profile.name + " " + str(profile.kp)

        return (
            "<font color=\"Grey\">" +
            f"{self._trad['reach']}: {self._reach.name}" +
            " - " +
            f"{self._trad['cross_section']}:</font> {name}"
        )

    def setup_statusbar(self):
        """Add an (initially empty) permanent label to the status bar."""
        txt = ""
        self._status_label = QLabel(txt)
        self.statusbar.addPermanentWidget(self._status_label)

    def update_statusbar(self):
        """Refresh the status bar label from the current selection."""
        txt = self._compute_status_label()
        self._status_label.setText(txt)

    def setup_connections(self):
        """Connect menu actions, slider, buttons and table signals."""
        actions = {
            "action_import": self.import_from_file,
            "action_export": self.export_to_file,
            "action_sort_asc": self.sort_ascending,
            "action_sort_des": self.sort_descending,
            "action_up": self.move_up,
            "action_down": self.move_down,
            "action_add": self.add,
            "action_delete": self.delete,
            "action_edit": self.edit_profile,
            "action_meshing": self.edit_meshing,
        }

        for action_name, callback in actions.items():
            self.find(QAction, action_name).triggered.connect(callback)

        self.find(QSlider, "verticalSlider").valueChanged.connect(
            self.changed_slider_value)
        self.find(QPushButton, "pushButton_up").clicked.connect(
            self.decrement_value_slider)
        self.find(QPushButton, "pushButton_down").clicked.connect(
            self.increment_value_slider)

        # Profile selection when the selected line changes in the table
        self.find(QTableView, "tableView").selectionModel()\
                                          .selectionChanged\
                                          .connect(self.select_current_profile)

        self._table.layoutChanged.connect(self.update)

    def update(self):
        """Redraw every plot and resynchronise selection and slider.

        Connected to the table model's ``layoutChanged`` signal.
        NOTE(review): this shadows ``QWidget.update()``; Python callers of
        ``update()`` on this window get this refresh instead of a repaint.
        """
        self.update_profile_windows()
        self.plot_xy()
        self.plot_kpc()
        self.plot_ac()

        self.select_current_profile()
        self.changed_slider_value()

    def import_from_file(self):
        """Ask for a geometry file and import it at the top of the reach."""
        options = QFileDialog.Options()
        options |= QFileDialog.DontUseNativeDialog

        file_types = [
            self._trad["file_st"],
            self._trad["file_m"],
            self._trad["file_all"],
        ]

        filename, _ = QtWidgets.QFileDialog.getOpenFileName(
            self,
            self._trad["open_file"],
            "",
            ";; ".join(file_types),
            options=options
        )

        if filename != "":
            self._table.import_geometry(0, filename)

    def edit_profile(self):
        """Open a ProfileWindow for each selected row.

        Rows whose profile already has an open window are skipped.
        """
        with self._table_signals_blocked():
            rows = {index.row() for index in self.tableView.selectedIndexes()}

            for row in rows:
                profile = self._reach.profile(row)

                if self.sub_window_exists(
                    ProfileWindow,
                    data=[None, None, profile]
                ):
                    continue

                win = ProfileWindow(
                    profile=profile,
                    parent=self,
                )
                self._profile_window.append(win)
                win.show()

    def edit_meshing(self):
        """Show the meshing dialog and run the mesher on acceptance.

        Errors are logged rather than propagated, so a failing meshing does
        not escape the Qt slot; they used to be silently discarded, which
        also hid the ``ExternFileMissingError`` raised by ``_edit_meshing``.
        """
        try:
            dlg = MeshingDialog(
                reach=self._reach,
                parent=self
            )
            if dlg.exec():
                data = {
                    "step": dlg.space_step,
                    "lplan": dlg.lplan,
                    "linear": dlg.linear,
                }
                self._edit_meshing(data)
        except Exception as e:
            logger_exception(e)

    def _edit_meshing(self, data):
        """Run the MailleurTT mesher on the reach through the table model.

        :param data: meshing parameters (``step``, ``lplan``, ``linear``).
        :raises ExternFileMissingError: when meshing fails.
        NOTE(review): any failure is reported as a missing MailleurTT file,
        even if the cause is different.
        """
        try:
            mesher = MeshingWithMageMailleurTT()
            self._table.meshing(mesher, data)
        except Exception as e:
            logger_exception(e)
            raise ExternFileMissingError(
                module="mage",
                filename="MailleurTT",
                path=MeshingWithMageMailleurTT._path(),
                src_except=e
            ) from e

    @pyqtSlot(bool)
    def changed_profile_slot(self, status):
        """Slot storing the profile-changed status in ``update_view1``."""
        self.update_view1 = status

    def update_profile_windows(self):
        """Reset the bookkeeping lists ``list_second_window``/``list_row``."""
        self.list_second_window = []
        self.list_row = []

    def plot_xy(self):
        """(Re)create and draw the XY plot of the reach."""
        with self._table_signals_blocked():
            self._plot_xy = PlotXY(
                canvas=self._canvas_xy,
                data=self._reach,
                toolbar=self._toolbar_xy
            )
            self._plot_xy.draw()

    def update_plot_xy(self):
        """Update the existing XY plot."""
        with self._table_signals_blocked():
            self._plot_xy.update()

    def plot_kpc(self):
        """(Re)create and draw the KP/Z plot of the reach."""
        with self._table_signals_blocked():
            self._plot_kpc = PlotKPZ(
                canvas=self._canvas_kpc,
                data=self._reach,
                toolbar=self._toolbar_kpc
            )
            self._plot_kpc.draw()

    def update_plot_kpc(self):
        """Update the existing KP/Z plot."""
        with self._table_signals_blocked():
            self._plot_kpc.update()

    def plot_ac(self):
        """(Re)create and draw the AC plot, linked to the XY plot."""
        with self._table_signals_blocked():
            self._plot_ac = PlotAC(
                canvas=self._canvas_ac,
                data=self._reach,
                toolbar=self._toolbar_ac,
                plot_xy=self._plot_xy
            )
            self._plot_ac.draw()

    def update_plot_ac(self, ind: int):
        """Update the AC plot for the profile at row ``ind``."""
        with self._table_signals_blocked():
            self._plot_ac.update(ind=ind)

    def get_station(self, ind: int):
        """Return the station values of the profile at row ``ind``."""
        return self._reach.profile(ind).get_station()

    def get_elevation(self, ind: int):
        """Return the elevations of the profile at row ``ind``."""
        return self._reach.profile(ind).z()

    def select_plot_xy(self, ind: int):
        """Highlight profile ``ind`` in the XY plot."""
        with self._table_signals_blocked():
            self._plot_xy.update(ind=ind)

    def select_plot_kpc(self, ind: int):
        """Highlight profile ``ind`` in the KP/Z plot."""
        with self._table_signals_blocked():
            self._plot_kpc.update(ind=ind)

    def select_plot_ac(self, ind: int):
        """Show profile ``ind`` in the AC plot."""
        with self._table_signals_blocked():
            self._plot_ac.update(ind=ind)

    def select_row_profile_slider(self, ind: int = 0):
        """Select the whole table row ``ind`` and scroll it into view."""
        if self.tableView is not None:
            selection_model = self.tableView.selectionModel()
            index = self.tableView.model().index(ind, 0)

            # ClearAndSelect already implies Select
            selection_model.select(
                index,
                QItemSelectionModel.Rows |
                QItemSelectionModel.ClearAndSelect
            )

            self.tableView.scrollTo(index)

    def select_current_profile(self):
        """Synchronise slider and plots with the table's selected row."""
        with self._table_signals_blocked():
            if len(self.tableView.selectedIndexes()) > 0:
                row = self.index_selected_row()

                self.find(QSlider, "verticalSlider").setValue(row)
                self.select_plot_xy(row)
                self.select_plot_kpc(row)
                self.select_plot_ac(row)

    def changed_slider_value(self):
        """Follow the slider: select its row, update plots and status bar.

        Also clamps the slider range to the current number of rows. Does
        nothing when the table is empty.
        """
        with self._table_signals_blocked():
            if self._table.rowCount() != 0:
                slider = self.find(QSlider, "verticalSlider")
                slider.setMaximum(self._table.rowCount() - 1)

                slider_value = slider.value()

                self.select_plot_xy(slider_value)
                self.select_plot_kpc(slider_value)
                self.select_row_profile_slider(slider_value)
                self.update_statusbar()

    def increment_value_slider(self):
        """Move the slider to the next profile, if there is one."""
        slider = self.find(QSlider, "verticalSlider")
        if 0 <= slider.value() < self._table.rowCount() - 1:
            slider.setValue(slider.value() + 1)

    def decrement_value_slider(self):
        """Move the slider to the previous profile, if there is one."""
        slider = self.find(QSlider, "verticalSlider")
        if 0 < slider.value() < self._table.rowCount():
            slider.setValue(slider.value() - 1)

    def add(self):
        """Add a profile after the selected row, or at the end."""
        if not self.tableView.selectedIndexes():
            self._table.add(self._table.rowCount())
        else:
            row = self.index_selected_row()
            self._table.add(row + 1)

    def delete(self):
        """Delete every selected row (rows passed in ascending order)."""
        rows = sorted(
            {index.row() for index in self.tableView.selectedIndexes()}
        )

        if rows:
            self._table.delete(rows)

    def index_selected_row(self):
        """Return the first selected row.

        :raises IndexError: when no full row is selected.
        """
        return self.tableView\
                   .selectionModel()\
                   .selectedRows()[0]\
                   .row()

    def sort_ascending(self):
        """Sort profiles in ascending order and resynchronise the view."""
        self._table.sort_profiles(False)
        self.select_current_profile()
        self.changed_slider_value()

    def sort_descending(self):
        """Sort profiles in descending order and resynchronise the view."""
        self._table.sort_profiles(True)

        self.select_current_profile()
        self.changed_slider_value()

    def move_up(self):
        """Move the selected profile one row up (no-op without selection)."""
        if not self.tableView.selectionModel().selectedRows():
            return

        row = self.index_selected_row()
        self._table.move_up(row)
        self.select_current_profile()

    def move_down(self):
        """Move the selected profile one row down (no-op without selection)."""
        if not self.tableView.selectionModel().selectedRows():
            return

        row = self.index_selected_row()
        self._table.move_down(row)
        self.select_current_profile()

    def duplicate(self):
        """Duplicate every selected profile through the table model."""
        rows = [
            index.row()
            for index in self.tableView.selectionModel().selectedRows()
        ]
        profiles = [self._reach.profile(row) for row in rows]

        if not profiles:
            return

        self._table.duplicate(rows, profiles)
        self.select_current_profile()

    def _copy(self):
        """Copy ``[name, kp]`` of every selected profile to the clipboard."""
        rows = self.tableView.selectionModel().selectedRows()

        table = []
        for index in rows:
            profile = self._reach.profile(index.row())
            table.append([profile.name, profile.kp])

        self.copyTableIntoClipboard(table)

    def _paste(self):
        """Paste clipboard rows as profiles at the selected row.

        Each pasted line gets the reach and the river status appended
        before being handed to the table model. Errors (including an empty
        selection) are logged, not propagated.
        """
        header, data = self.parseClipboardTable()

        if len(data) + len(header) == 0:
            return

        # HACK: The CSV module detects the first line as a CSV header in
        # some particular cases... To avoid this we prepend the "headers"
        # to the data list. /!\ This hack fails if a real header exists
        # (a better solution is welcome).
        logger.debug(
            "Geometry: Paste: " +
            f"header = {header}, " +
            f"data = {data}"
        )
        if len(header) != 0:
            data = [header] + data

        try:
            for line in data:
                line.append(self._reach)
                line.append(self._study.river._status)

            row = self.index_selected_row()
            self._table.paste(row, [], data)
            self.select_current_profile()
        except Exception as e:
            logger_exception(e)

    def _undo(self):
        """Undo the last table edit and resynchronise the selection."""
        self._table.undo()
        self.select_current_profile()

    def _redo(self):
        """Redo the last undone table edit and resynchronise the selection."""
        self._table.redo()
        self.select_current_profile()

    def export_to_file(self):
        """Ask for a file name and export the reach geometry as ST.

        The dialog starts in the ``current_directory`` setting, falling back
        to the directory of the study file, when the study has a file name.
        """
        settings = QSettings(
            QSettings.IniFormat,
            QSettings.UserScope, 'MyOrg'
        )

        # The study file name may be empty or None: only derive a start
        # directory when there is an actual file name.
        current_dir = ""
        if self._study.filename:
            default_directory = os.path.dirname(self._study.filename)
            current_dir = settings.value(
                'current_directory',
                default_directory,
                type=str
            )

        options = QFileDialog.Options()
        options |= QFileDialog.DontUseNativeDialog

        filename, _ = QFileDialog.getSaveFileName(
            self,
            directory=current_dir,
            filter=(
                self._trad["file_st"] + ";; " +
                self._trad["file_all"]
            ),
            options=options
        )

        if filename != '':
            self._export_to_file_st(filename)

    def _export_to_file_st(self, filename):
        """Write the reach geometry to ``filename`` in ST format."""
        with open(filename, "w") as f:
            f.write("# Exported from Pamhyr2\n")
            self._export_to_file_st_reach(f, self._reach)

    def _export_to_file_st_reach(self, wfile, reach):
        """Write every profile of ``reach``, numbered from 0."""
        for pid, profile in enumerate(reach.profiles):
            self._export_to_file_st_profile(wfile, profile, pid)

    def _export_to_file_st_profile(self, wfile, profile, pid):
        """Write one profile: header line, its points, then end marker.

        Unnamed profiles are written as ``p`` followed by the zero-padded
        profile id.
        """
        num = f"{pid:>6}"
        c1 = f"{profile.code1:>6}"
        c2 = f"{profile.code2:>6}"
        t = f"{len(profile.points):>6}"
        kp = f"{profile.kp:>12f}"[0:12]
        pname = profile.name
        if profile.name == "":
            pname = f"p{profile.id:>3}".replace(" ", "0")

        wfile.write(f"{num}{c1}{c2}{t} {kp} {pname}\n")

        for point in profile.points:
            self._export_to_file_st_point(wfile, point)

        # End-of-profile marker line
        wfile.write("     999.9990     999.9990     999.9990\n")

    def _export_to_file_st_point(self, wfile, point):
        """Write one point as fixed-width x, y, z columns and its name."""
        x = f"{point.x:<12.4f}"[0:12]
        y = f"{point.y:<12.4f}"[0:12]
        z = f"{point.z:<12.4f}"[0:12]
        n = f"{point.name:<3}"

        wfile.write(f"{x} {y} {z} {n}\n")