abstract_margin_function.py 10.14 KiB
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from experiment.meteo_france_data.scm_models_data.visualization.utils import create_adjusted_axes
from extreme_fit.distribution.gev.gev_params import GevParams
from experiment.meteo_france_data.scm_models_data.visualization.create_shifted_cmap import imshow_shifted
from extreme_fit.function.abstract_function import AbstractFunction
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.slicer.split import Split
from root_utils import cached_property


class AbstractMarginFunction(AbstractFunction):
    """
    AbstractMarginFunction maps points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)

    Subclasses must implement ``get_gev_params``. This base class builds on it to:
      * collect every GEV summary value (``GevParams.SUMMARY_NAMES``) over the coordinates
        (``gev_value_name_to_serie``);
      * plot those values for 1D spatial, 2D spatial, and 2D spatial + 1D temporal coordinates.
    """
    # Number of sample points per axis when evaluating the function on a regular plotting grid
    VISUALIZATION_RESOLUTION = 100
    # Number of time steps displayed when the coordinates have a temporal component
    VISUALIZATION_TEMPORAL_STEPS = 2

    def __init__(self, coordinates: AbstractCoordinates):
        self.coordinates = coordinates
        # Optional mask forwarded to imshow_shifted for 2D plots
        self.mask_2D = None

        # Visualization parameters
        self.visualization_axes = None
        self.datapoint_display = False
        self.spatio_temporal_split = Split.all
        self.datapoint_marker = 'o'
        self.color = 'skyblue'
        self.filter = None
        self.linewidth = 1
        self.subplot_space = 1.0

        # Cache of 2D grids, keyed by temporal step (None for purely spatial grids)
        self.temporal_step_to_grid_2D = {}
        self._grid_1D = None
        self.title = None
        # When True, temporal_steps also includes steps beyond the observed temporal range
        self.add_future_temporal_steps = False

        # Visualization limits (None means: use the min/max of the coordinates)
        self._visualization_x_limits = None
        self._visualization_y_limits = None

    @property
    def x(self):
        """X coordinates of ``self.coordinates``."""
        return self.coordinates.x_coordinates

    @property
    def y(self):
        """Y coordinates of ``self.coordinates``."""
        return self.coordinates.y_coordinates

    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
        """Main method that maps each coordinate to its GEV parameters"""
        raise NotImplementedError

    @property
    def gev_value_name_to_serie(self) -> Dict[str, pd.Series]:
        """Return, for each name in ``GevParams.SUMMARY_NAMES``, a Series of its values over the coordinates.

        Each Series is indexed by ``self.coordinates.index`` and holds
        ``get_gev_params(coordinate).summary_dict[name]`` for every coordinate.
        """
        # Load the gev_params
        gev_params = [self.get_gev_params(coordinate) for coordinate in self.coordinates.coordinates_values()]
        # Load the dictionary of values (distribution parameters + the quantiles)
        value_dicts = [gev_param.summary_dict for gev_param in gev_params]
        return {value_name: pd.Series(data=[d[value_name] for d in value_dicts], index=self.coordinates.index)
                for value_name in GevParams.SUMMARY_NAMES}

    # Visualization function

    def set_datapoint_display_parameters(self, spatio_temporal_split=Split.all, datapoint_marker=None, filter=None,
                                         color=None,
                                         linewidth=1, datapoint_display=False):
        """Set the parameters used by the 1D plots.

        Note that ``datapoint_marker`` and ``color`` default to None, which replaces the
        constructor defaults ('o' and 'skyblue') when they are not passed explicitly.

        :param spatio_temporal_split: split whose coordinates are used when ``datapoint_display`` is True
        :param datapoint_marker: matplotlib marker used when displaying datapoints
        :param filter: optional index applied to the datapoint coordinates
        :param color: matplotlib color of the curve
        :param linewidth: width of the curve (continuous display only)
        :param datapoint_display: if True, evaluate at the observed coordinates instead of a regular grid
        """
        self.datapoint_display = datapoint_display
        self.spatio_temporal_split = spatio_temporal_split
        self.datapoint_marker = datapoint_marker
        self.linewidth = linewidth
        self.filter = filter
        self.color = color

    def visualize_function(self, axes=None, show=True, dot_display=False, title=None):
        """Plot every GEV summary value, one subplot (or row of subplots) per summary name.

        :param axes: sequence of ``GevParams.NB_SUMMARY_NAMES`` axes; created if None
        :param show: call ``plt.show()`` at the end
        :param dot_display: stored into ``self.datapoint_display`` (used by the 1D plots)
        :param title: title used for every subplot; if None, each subplot is titled with its summary name
        :return: the axes used
        """
        self.title = title
        self.datapoint_display = dot_display
        if axes is None:
            if self.coordinates.has_temporal_coordinates:
                axes = create_adjusted_axes(GevParams.NB_SUMMARY_NAMES, self.VISUALIZATION_TEMPORAL_STEPS)
            else:
                axes = create_adjusted_axes(1, GevParams.NB_SUMMARY_NAMES, subplot_space=self.subplot_space)
        self.visualization_axes = axes
        assert len(axes) == GevParams.NB_SUMMARY_NAMES
        for ax, gev_value_name in zip(axes, GevParams.SUMMARY_NAMES):
            self.visualize_single_param(gev_value_name, ax, show=False)
            self.set_title(ax, gev_value_name)
        if show:
            plt.show()
        return axes

    def set_title(self, ax, gev_value_name):
        """Set the subplot title to ``self.title`` if defined, otherwise to ``gev_value_name``."""
        # In the temporal case, ``ax`` may be a row of axes (no set_title) rather than a single axis;
        # such objects are silently skipped.
        if hasattr(ax, 'set_title'):
            title_str = gev_value_name if self.title is None else self.title
            ax.set_title(title_str)

    def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True):
        """Plot one GEV summary value, dispatching on the dimensionality of the coordinates.

        :raises NotImplementedError: for coordinate layouts other than 1D spatial, 2D spatial
            or 2D spatial + temporal
        """
        assert gev_value_name in GevParams.SUMMARY_NAMES
        nb_coordinates_spatial = self.coordinates.nb_spatial_coordinates
        has_temporal_coordinates = self.coordinates.has_temporal_coordinates
        if nb_coordinates_spatial == 1 and not has_temporal_coordinates:
            self.visualize_1D(gev_value_name, ax, show)
        elif nb_coordinates_spatial == 2 and not has_temporal_coordinates:
            self.visualize_2D(gev_value_name, ax, show)
        elif nb_coordinates_spatial == 2 and has_temporal_coordinates:
            self.visualize_2D_spatial_1D_temporal(gev_value_name, ax, show)
        else:
            raise NotImplementedError('Other visualization not yet implemented')

    # Visualization 1D

    def visualize_1D(self, gev_value_name=GevParams.LOC, ax=None, show=True):
        """Plot one GEV summary value against the x coordinate (defaults to the current axis)."""
        grid, linspace = self.grid_1D(self.x)
        if ax is None:
            ax = plt.gca()
        if self.datapoint_display:
            ax.plot(linspace, grid[gev_value_name], marker=self.datapoint_marker, color=self.color)
        else:
            ax.plot(linspace, grid[gev_value_name], color=self.color, linewidth=self.linewidth)
        # X axis
        ax.set_xlabel('coordinate X')
        plt.setp(ax.get_xticklabels(), visible=True)
        ax.xaxis.set_tick_params(labelbottom=True)

        if show:
            plt.show()

    def grid_1D(self, x):
        """Return ``(grid, linspace)`` for the current ``spatio_temporal_split`` (not cached)."""
        return self.get_grid_values_1D(x, self.spatio_temporal_split)

    def get_grid_values_1D(self, x, spatio_temporal_split):
        """Evaluate the GEV summary values along a 1D set of points.

        If ``self.datapoint_display`` is True, the points are the first column of the coordinates
        of ``spatio_temporal_split`` (optionally indexed by ``self.filter``); otherwise they are
        ``VISUALIZATION_RESOLUTION`` evenly spaced values between ``x.min()`` and ``x.max()``.

        :return: ``(grid, linspace)`` where ``grid`` maps each summary name to the list of its values,
            aligned with ``linspace``
        """
        # TODO: to avoid getting the value several times, I could cache the results
        if self.datapoint_display:
            # todo: keep only the index of interest here
            linspace = self.coordinates.coordinates_values(spatio_temporal_split)[:, 0]
            if self.filter is not None:
                linspace = linspace[self.filter]
        else:
            linspace = np.linspace(x.min(), x.max(), self.VISUALIZATION_RESOLUTION)

        summary_dicts = []
        for xi in linspace:
            gev_param = self.get_gev_params(np.array([xi]))
            assert not gev_param.has_undefined_parameters, 'This case needs to be handled during display, ' \
                                                           'gev_parameter for xi={} is undefined'.format(xi)
            summary_dicts.append(gev_param.summary_dict)
        grid = {value_name: [d[value_name] for d in summary_dicts] for value_name in GevParams.SUMMARY_NAMES}
        return grid, linspace

    # Visualization 2D

    def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True, temporal_step=None):
        """Display one GEV summary value over the 2D spatial grid with ``imshow_shifted``.

        :param temporal_step: if not None, appended as third coordinate when evaluating the grid
        """
        if ax is None:
            ax = plt.gca()

        # Special display
        imshow_shifted(ax, gev_param_name, self.grid_2D(temporal_step)[gev_param_name], self.visualization_extend,
                       self.mask_2D)

        # X axis
        ax.set_xlabel('coordinate X')
        plt.setp(ax.get_xticklabels(), visible=True)
        ax.xaxis.set_tick_params(labelbottom=True)
        # Y axis (y tick labels are on the left side)
        ax.set_ylabel('coordinate Y')
        plt.setp(ax.get_yticklabels(), visible=True)
        ax.yaxis.set_tick_params(labelleft=True)
        # todo: add dot display in 2D
        if show:
            plt.show()

    @property
    def visualization_x_limits(self):
        """(min, max) of the x range to display; defaults to the range of the x coordinates."""
        if self._visualization_x_limits is None:
            return self.x.min(), self.x.max()
        else:
            return self._visualization_x_limits

    @property
    def visualization_y_limits(self):
        """(min, max) of the y range to display; defaults to the range of the y coordinates."""
        if self._visualization_y_limits is None:
            return self.y.min(), self.y.max()
        else:
            return self._visualization_y_limits

    @property
    def visualization_extend(self):
        """(x_min, x_max, y_min, y_max) extent passed to ``imshow_shifted``."""
        # tuple() so that user-provided limits given as lists can be concatenated with tuples
        return tuple(self.visualization_x_limits) + tuple(self.visualization_y_limits)

    def grid_2D(self, temporal_step=None):
        """Return the 2D grid for ``temporal_step``, computing it once and caching it per step."""
        # Cache the results
        if temporal_step not in self.temporal_step_to_grid_2D:
            self.temporal_step_to_grid_2D[temporal_step] = self._grid_2D(temporal_step)
        return self.temporal_step_to_grid_2D[temporal_step]

    def _grid_2D(self, temporal_step=None):
        """Evaluate every GEV summary value on a regular grid over the visualization limits.

        :return: dict mapping each summary name to an array of shape
            ``(VISUALIZATION_RESOLUTION, VISUALIZATION_RESOLUTION)`` where entry ``[i, j]``
            is the value at ``(x_i, y_j)`` (first axis follows x, second axis follows y).
            NOTE(review): confirm imshow_shifted expects this (x-first) orientation.
        """
        xs = np.linspace(*self.visualization_x_limits, self.VISUALIZATION_RESOLUTION)
        ys = np.linspace(*self.visualization_y_limits, self.VISUALIZATION_RESOLUTION)
        summary_dicts = []
        for xi in xs:
            for yj in ys:
                # Build spatio temporal coordinate
                coordinate = [xi, yj]
                if temporal_step is not None:
                    coordinate.append(temporal_step)
                summary_dicts.append(self.get_gev_params(np.array(coordinate)).summary_dict)
        shape = [self.VISUALIZATION_RESOLUTION, self.VISUALIZATION_RESOLUTION]
        return {value_name: np.array([d[value_name] for d in summary_dicts]).reshape(shape)
                for value_name in GevParams.SUMMARY_NAMES}

    # Visualization 3D

    def visualize_2D_spatial_1D_temporal(self, gev_param_name=GevParams.LOC, axes=None, show=True):
        """Display one GEV summary value as a 2D map for each of ``self.temporal_steps``.

        :param axes: sequence of ``VISUALIZATION_TEMPORAL_STEPS`` axes; created if None
        """
        if axes is None:
            axes = create_adjusted_axes(self.VISUALIZATION_TEMPORAL_STEPS, 1)
        assert len(axes) == self.VISUALIZATION_TEMPORAL_STEPS

        # Build temporal_steps a list of time steps
        assert len(self.temporal_steps) == self.VISUALIZATION_TEMPORAL_STEPS
        for ax, temporal_step in zip(axes, self.temporal_steps):
            self.visualize_2D(gev_param_name, ax, show=False, temporal_step=temporal_step)
            self.set_title(ax, gev_param_name)

        if show:
            plt.show()

    @cached_property
    def temporal_steps(self) -> List[int]:
        """Time steps to display, computed once.

        Past steps are evenly spaced (truncated to int) over ``coordinates.df_temporal_range()``.
        If ``add_future_temporal_steps`` is True, the last two slots are instead
        ``stop + 10`` and ``stop + 100``.
        """
        future_temporal_steps = [10, 100] if self.add_future_temporal_steps else []
        nb_past_temporal_step = self.VISUALIZATION_TEMPORAL_STEPS - len(future_temporal_steps)
        start, stop = self.coordinates.df_temporal_range()
        temporal_steps = [int(step) for step in np.linspace(start, stop, num=nb_past_temporal_step)]
        temporal_steps += [stop + step for step in future_temporal_steps]
        return temporal_steps