From 30f2543ff1bb5873a54c002c688545c258a35a81 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Thu, 15 Nov 2018 17:22:05 +0100 Subject: [PATCH] [RETURN LEVEL PLOT] add return level plot and modify test_margin_estimators accordingly --- .../estimator/margin_estimator.py | 11 ++++--- .../abstract_margin_function.py | 19 ++++++++---- .../margin_model/smooth_margin_model.py | 4 +-- .../return_level_plot/spatial_2D_plot.py | 30 +++++++++++++++++++ .../test_estimator/test_full_estimators.py | 4 +-- .../test_estimator/test_margin_estimators.py | 27 ++++++++++------- 6 files changed, 70 insertions(+), 25 deletions(-) create mode 100644 extreme_estimator/return_level_plot/spatial_2D_plot.py diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py index 30bb6836..82fdc39f 100644 --- a/extreme_estimator/estimator/margin_estimator.py +++ b/extreme_estimator/estimator/margin_estimator.py @@ -1,5 +1,8 @@ from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel from extreme_estimator.estimator.abstract_estimator import AbstractEstimator +from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ + AbstractMarginFunction +from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset @@ -11,8 +14,8 @@ class AbstractMarginEstimator(AbstractEstimator): self._margin_function_fitted = None @property - def margin_function_fitted(self): - assert self._margin_function_fitted is not None, 'Error: estimator has not been not fitted yet' + def margin_function_fitted(self) -> AbstractMarginFunction: + assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted' return self._margin_function_fitted @@ -23,9 +26,9 @@ class PointWiseMarginEstimator(AbstractMarginEstimator): class SmoothMarginEstimator(AbstractMarginEstimator): """# with different type of marginals: cosntant, linear....""" - def __init__(self, dataset: AbstractDataset, margin_model: AbstractMarginModel): + def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel): super().__init__(dataset) - assert isinstance(margin_model, AbstractMarginModel) + assert isinstance(margin_model, LinearMarginModel) self.margin_model = margin_model def _fit(self): diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py index 6db45309..e5a593b4 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py @@ -18,16 +18,23 @@ class AbstractMarginFunction(object): def get_gev_params(self, coordinate: np.ndarray) -> GevParams: pass - def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, show=False): + def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, ax=None, show=False): x = self.spatial_coordinates.x_coordinates y = self.spatial_coordinates.y_coordinates + grid = self.get_grid_2D(x, y) + gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name) + if ax is None: + ax = plt.gca() + imshow_method = ax.imshow + imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()), + interpolation='nearest', cmap=cm.gist_rainbow) + if show: + plt.show() + + def get_grid_2D(self, x, y): resolution = 100 grid = np.zeros([resolution, resolution, 3]) for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)): for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)): grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array() - gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name) - plt.imshow(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()), - interpolation='nearest', cmap=cm.gist_rainbow) - if show: - plt.show() + return grid diff --git a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py index 8d9febab..6a33d939 100644 --- a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py +++ b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py @@ -33,8 +33,8 @@ class LinearShapeAxis0MarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None): super().load_margin_functions({GevParams.GEV_SHAPE: 0}) - # def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction: - # return self.margin_function_start_fit + def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction: + return self.margin_function_start_fit if __name__ == '__main__': diff --git a/extreme_estimator/return_level_plot/spatial_2D_plot.py b/extreme_estimator/return_level_plot/spatial_2D_plot.py new file mode 100644 index 00000000..8347b0e9 --- /dev/null +++ b/extreme_estimator/return_level_plot/spatial_2D_plot.py @@ -0,0 +1,30 @@ +from itertools import product +from typing import List, Dict + +import matplotlib.pyplot as plt + +from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ + AbstractMarginFunction +from extreme_estimator.gev_params import GevParams + +plt.style.use('seaborn-white') + + +class Spatial2DPlot(object): + + def __init__(self, name_to_margin_function: Dict[str, AbstractMarginFunction]): + self.name_to_margin_function = name_to_margin_function # type: Dict[str, AbstractMarginFunction] + self.grid_columns = GevParams.GEV_PARAM_NAMES + + def plot(self): + nb_grid_rows, nb_grid_columns = len(self.name_to_margin_function), len(self.grid_columns) + fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row') + fig.subplots_adjust(hspace=0.4, wspace=0.4) + margin_function: AbstractMarginFunction + for i, (name, margin_function) in enumerate(self.name_to_margin_function.items()): + for j, param_name in enumerate(self.grid_columns): + ax = axes[i, j] if nb_grid_rows > 1 else axes[j] + margin_function.visualize_2D(gev_param_name=param_name, ax=ax) + ax.set_title("{} for {}".format(param_name, name)) + fig.suptitle('Spatial2DPlot') + plt.show() diff --git a/test/test_extreme_estimator/test_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py index 5e03f8fd..49df744f 100644 --- a/test/test_extreme_estimator/test_estimator/test_full_estimators.py +++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py @@ -4,7 +4,7 @@ from itertools import product from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1 -from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestMarginEstimators +from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestSmoothMarginEstimator from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators @@ -16,7 +16,7 @@ class TestFullEstimators(unittest.TestCase): super().setUp() self.spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1) self.max_stable_models = TestMaxStableEstimators.load_max_stable_models() - self.margin_models = TestMarginEstimators.load_margin_models(spatial_coordinates=self.spatial_coordinates) + self.margin_models = TestSmoothMarginEstimator.load_margin_models(spatial_coordinates=self.spatial_coordinates) def test_full_estimators(self): for margin_model, max_stable_model in product(self.margin_models, self.max_stable_models): diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py index 480ebba2..9dcea3dc 100644 --- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py +++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py @@ -1,15 +1,17 @@ import unittest from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel -from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel +from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \ + LinearShapeAxis0MarginModel from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator +from extreme_estimator.return_level_plot.spatial_2D_plot import Spatial2DPlot from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1 -class TestMarginEstimators(unittest.TestCase): +class TestSmoothMarginEstimator(unittest.TestCase): DISPLAY = False - MARGIN_TYPES = [ConstantMarginModel] + MARGIN_TYPES = [ConstantMarginModel, LinearShapeAxis0MarginModel][1:] MARGIN_ESTIMATORS = [SmoothMarginEstimator] def setUp(self): @@ -25,14 +27,17 @@ class TestMarginEstimators(unittest.TestCase): for margin_model in self.margin_models: dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model, spatial_coordinates=self.spatial_coordinates) - - for estimator_class in self.MARGIN_ESTIMATORS: - estimator = estimator_class(dataset=dataset, margin_model=margin_model) - estimator.fit() - if self.DISPLAY: - print(type(margin_model)) - print(dataset.df_dataset.head()) - print(estimator.additional_information) + # Fit estimator + estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model) + estimator.fit() + # Map name to their margin functions + name_to_margin_function = { + 'Ground truth margin function': dataset.margin_model.margin_function_sample, + # 'Estimated margin function': estimator.margin_function_fitted, + } + # Spatial Plot + if self.DISPLAY: + Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot() self.assertTrue(True) -- GitLab