abstract_quantile_function.py 2.25 KiB
import numpy as np

from extreme_fit.function.abstract_function import AbstractFunction
from extreme_fit.function.margin_function.abstract_margin_function import AbstractMarginFunction
import matplotlib.pyplot as plt

from extreme_fit.function.param_function.param_function import AbstractParamFunction
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates


class AbstractQuantileFunction(AbstractFunction):
    """Base class for functions that map a coordinate to a quantile value.

    Subclasses implement ``_get_quantile``. This base class handles the
    optional coordinate transformation and provides simple plotting helpers.
    """

    def get_quantile(self, coordinate: np.ndarray, is_transformed: bool = True) -> float:
        """Return the quantile value at ``coordinate``.

        :param coordinate: coordinate at which the function is evaluated.
        :param is_transformed: when False, ``coordinate`` is first passed
            through ``self.transform``; when True it is used as given.
        :return: the value computed by ``_get_quantile``.
        """
        transformed_coordinate = coordinate if is_transformed else self.transform(coordinate)
        return self._get_quantile(transformed_coordinate)

    def _get_quantile(self, coordinate: np.ndarray) -> float:
        """Evaluate the quantile at an already-transformed coordinate.

        Subclass hook; the base implementation raises NotImplementedError.
        """
        raise NotImplementedError

    def visualize(self, show=True):
        """Plot the function according to the dimension of ``self.coordinates``.

        One-dimensional coordinates are plotted via ``visualize_1D``;
        two-dimensional coordinates are delegated to ``visualize_2D``.
        Any other dimension is deliberately a silent no-op.

        :param show: forwarded to ``visualize_1D`` to call ``plt.show()``.
        """
        nb_coordinates = self.coordinates.nb_coordinates
        if nb_coordinates == 1:
            self.visualize_1D(show=show)
        elif nb_coordinates == 2:
            # ``show`` is not forwarded: visualize_2D takes no such argument.
            self.visualize_2D()
        # Other dimensions: nothing is plotted (intentional best-effort).

    def visualize_1D(self, ax=None, show=True, resolution=100):
        """Plot the quantile over the range spanned by ``self.coordinates``.

        :param ax: matplotlib axes to draw on; defaults to ``plt.gca()``.
        :param show: if True, call ``plt.show()`` after plotting.
        :param resolution: number of evenly spaced sample points between the
            minimum and maximum coordinate values (default 100).
        :return: the axes the curve was drawn on.
        """
        if ax is None:
            ax = plt.gca()
        coordinate_values = self.coordinates.coordinates_values()
        x = np.linspace(coordinate_values.min(), coordinate_values.max(), resolution)
        # Each sample point is wrapped as a length-1 coordinate array and
        # evaluated with the default ``is_transformed=True``.
        y = [self.get_quantile(np.array([e])) for e in x]
        ax.plot(x, y)
        if show:
            plt.show()
        return ax

    def visualize_2D(self):
        """Placeholder for two-dimensional plotting; currently does nothing."""
        return


class QuantileFunctionFromParamFunction(AbstractQuantileFunction):
    """Quantile function whose value is read directly from a param function.

    The quantile at a coordinate is whatever ``param_function.get_param_value``
    returns for that coordinate.
    """

    def __init__(self, coordinates: AbstractCoordinates, param_function: AbstractParamFunction):
        """
        :param coordinates: coordinates handed to the base class.
        :param param_function: object whose ``get_param_value`` supplies the
            quantile values.
        """
        super().__init__(coordinates)
        self.param_function = param_function

    def _get_quantile(self, coordinate: np.ndarray) -> float:
        """Delegate evaluation to ``self.param_function`` at ``coordinate``."""
        param_value = self.param_function.get_param_value(coordinate)
        return param_value


class QuantileFunctionFromMarginFunction(AbstractQuantileFunction):
    """Quantile function derived from a margin function at a fixed level.

    At each coordinate, the margin function's parameters are fetched and
    their ``quantile`` method is evaluated at ``self.quantile``.
    """

    def __init__(self, coordinates: AbstractCoordinates, margin_function: AbstractMarginFunction, quantile: float):
        """
        :param coordinates: coordinates handed to the base class.
        :param margin_function: provides per-coordinate parameters through
            ``get_params``.
        :param quantile: level passed to the parameters' ``quantile`` method
            (presumably a probability in (0, 1) -- not validated here).
        """
        super().__init__(coordinates)
        self.margin_function = margin_function
        self.quantile = quantile

    def _get_quantile(self, coordinate: np.ndarray) -> float:
        """Return the ``self.quantile`` level of the parameters at ``coordinate``."""
        params_at_coordinate = self.margin_function.get_params(coordinate)
        level = self.quantile
        return params_at_coordinate.quantile(level)