# P3DST.py -- Pamhyr
# Copyright (C) 2023  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import sys
import logging

from matplotlib import pyplot as plt
from numpy import mean

from Scripts.AScript import AScript

from Model.Saved import SavedStatus
from Model.Geometry.Reach import Reach

# Module logger: this is the root logger (no name given), so messages from
# this script are handled by whatever configuration the application installs
# on the root logger.
logger = logging.getLogger(name=None)


class Script3DST(AScript):
    """Script that displays a 3D plot of a river reach read from an ST file.

    The ST file path is read from ``self._args[2]``. The reach geometry lines
    (from ``get_x``/``get_y``/``get_z``) are drawn in red and the guidelines
    (from ``get_guidelines_x``/``_y``/``_z``) in blue.
    """
    name = "3DST"
    description = "Display a 3D plot of a river reach from ST file"

    def usage(self):
        """Log the command-line usage of this script."""
        logger.info(f"Usage : {self._args[0]} 3DST <INPUT_ST_FILE>")

    def set_axes_equal(self, ax):
        """Make axes of 3D plot have equal scale

        Make axes of 3D plot have equal scale so that spheres appear
        as spheres, cubes as cubes, etc. This is one possible
        solution to Matplotlib's ax.set_aspect('equal') and
        ax.axis('equal') not working for 3D.

        Each axis is re-centred on its current middle and given the same
        half-width: half of the largest of the three current ranges.

        Args:
            ax: a matplotlib axis, e.g., as output from plt.gca()

        Returns:
            The input axis
        """
        x_limits = ax.get_xlim3d()
        y_limits = ax.get_ylim3d()
        z_limits = ax.get_zlim3d()

        x_range = abs(x_limits[1] - x_limits[0])
        x_middle = mean(x_limits)
        y_range = abs(y_limits[1] - y_limits[0])
        y_middle = mean(y_limits)
        z_range = abs(z_limits[1] - z_limits[0])
        z_middle = mean(z_limits)

        # The plot bounding box is a sphere in the sense of the infinity
        # norm, hence half the max range is used as the plot radius.
        plot_radius = 0.5 * max([x_range, y_range, z_range])

        ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
        ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
        ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])

        return ax

    def run(self):
        """Load the ST file given on the command line and show its 3D plot.

        Returns:
            0 on success; 1 if the ST file argument is missing, or if loading
            or plotting the reach raised an exception (the error is logged).
        """
        try:
            st_file = self._args[2]
        except (IndexError, TypeError) as e:
            # Missing argument (too few args) or no argument list at all.
            logger.error(f"Arguments parsing: {str(e)}")
            return 1

        logger.info(f"Use ST file: {st_file}")

        try:
            reach = self._load_reach(st_file)
            self._plot_reach(reach)
        except Exception as e:
            # Top-level boundary of the script: turn any failure into an
            # exit code, keeping the short message for the user and the
            # full traceback for debugging.
            logger.error(str(e))
            logger.debug("3DST script failed", exc_info=True)
            return 1

        return 0

    def _load_reach(self, st_file):
        """Build a Reach from ``st_file`` and compute its guidelines."""
        reach = Reach(status=SavedStatus())
        reach.import_geometry(st_file)
        reach.compute_guidelines()
        return reach

    def _plot_reach(self, reach):
        """Draw the reach lines (red) and guidelines (blue), then show them."""
        ax = plt.figure().add_subplot(projection="3d")

        xs = reach.get_x()
        # Raw coordinate dump: useful when debugging, too verbose for INFO.
        logger.debug("Reach X coordinates: %s", xs)
        for x, y, z in zip(xs, reach.get_y(), reach.get_z()):
            ax.plot(x, y, z, color='r', lw=1.)

        for x, y, z in zip(
                reach.get_guidelines_x(),
                reach.get_guidelines_y(),
                reach.get_guidelines_z()
        ):
            ax.plot(x, y, z, color='b', lw=1.)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        plt.tight_layout()
        self.set_axes_equal(ax)
        plt.show()