diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py index 30bb6836f6494ab6dc5fabe4077ae0d026d2b8fe..82fdc39fbdf040d24a41d22cf992e4d294acf3f0 100644 --- a/extreme_estimator/estimator/margin_estimator.py +++ b/extreme_estimator/estimator/margin_estimator.py @@ -1,5 +1,8 @@ from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel from extreme_estimator.estimator.abstract_estimator import AbstractEstimator +from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ + AbstractMarginFunction +from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset @@ -11,8 +14,8 @@ class AbstractMarginEstimator(AbstractEstimator): self._margin_function_fitted = None @property - def margin_function_fitted(self): - assert self._margin_function_fitted is not None, 'Error: estimator has not been not fitted yet' + def margin_function_fitted(self) -> AbstractMarginFunction: + assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted' return self._margin_function_fitted @@ -23,9 +26,9 @@ class PointWiseMarginEstimator(AbstractMarginEstimator): class SmoothMarginEstimator(AbstractMarginEstimator): """# with different type of marginals: cosntant, linear....""" - def __init__(self, dataset: AbstractDataset, margin_model: AbstractMarginModel): + def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel): super().__init__(dataset) - assert isinstance(margin_model, AbstractMarginModel) + assert isinstance(margin_model, LinearMarginModel) self.margin_model = margin_model def _fit(self): diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py index 6db453091f16ea9c8890c8ba1fac0bc1641d6f29..e5a593b45c1f1241c22fd4e1f075a2dd0c05e33f 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py @@ -18,16 +18,23 @@ class AbstractMarginFunction(object): def get_gev_params(self, coordinate: np.ndarray) -> GevParams: pass - def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, show=False): + def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, ax=None, show=False): x = self.spatial_coordinates.x_coordinates y = self.spatial_coordinates.y_coordinates + grid = self.get_grid_2D(x, y) + gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name) + if ax is None: + ax = plt.gca() + imshow_method = ax.imshow + imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()), + interpolation='nearest', cmap=cm.gist_rainbow) + if show: + plt.show() + + def get_grid_2D(self, x, y): resolution = 100 grid = np.zeros([resolution, resolution, 3]) for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)): for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)): grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array() - gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name) - plt.imshow(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()), - interpolation='nearest', cmap=cm.gist_rainbow) - if show: - plt.show() + return grid diff --git a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py index 8d9febabb2a1d9bee04fb38a1e70af3e9daedca2..6a33d939a567af114610bfc45b2fba0653d5c9ff 100644 --- a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py +++ b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py @@ -33,8 +33,8 @@ class LinearShapeAxis0MarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None): super().load_margin_functions({GevParams.GEV_SHAPE: 0}) - # def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction: - # return self.margin_function_start_fit + def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction: + return self.margin_function_start_fit if __name__ == '__main__': diff --git a/extreme_estimator/return_level_plot/spatial_2D_plot.py b/extreme_estimator/return_level_plot/spatial_2D_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..8347b0e90cf475097612577e6dc9296a8d698542 --- /dev/null +++ b/extreme_estimator/return_level_plot/spatial_2D_plot.py @@ -0,0 +1,30 @@ +from itertools import product +from typing import List, Dict + +import matplotlib.pyplot as plt + +from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ + AbstractMarginFunction +from extreme_estimator.gev_params import GevParams + +plt.style.use('seaborn-white') + + +class Spatial2DPlot(object): + + def __init__(self, name_to_margin_function: Dict[str, AbstractMarginFunction]): + self.name_to_margin_function = name_to_margin_function # type: Dict[str, AbstractMarginFunction] + self.grid_columns = GevParams.GEV_PARAM_NAMES + + def plot(self): + nb_grid_rows, nb_grid_columns = len(self.name_to_margin_function), len(self.grid_columns) + fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row') + fig.subplots_adjust(hspace=0.4, wspace=0.4) + margin_function: AbstractMarginFunction + for i, (name, margin_function) in enumerate(self.name_to_margin_function.items()): + for j, param_name in enumerate(self.grid_columns): + ax = axes[i, j] if nb_grid_rows > 1 else axes[j] + margin_function.visualize_2D(gev_param_name=param_name, ax=ax) + ax.set_title("{} for {}".format(param_name, name)) + fig.suptitle('Spatial2DPlot') + plt.show() diff --git a/test/test_extreme_estimator/test_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py index 5e03f8fd0d6a070d239cba3705fab66c00d271ba..49df744f0ded4bbefef5f84e2825a2e02bf2e4d4 100644 --- a/test/test_extreme_estimator/test_estimator/test_full_estimators.py +++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py @@ -4,7 +4,7 @@ from itertools import product from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1 -from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestMarginEstimators +from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestSmoothMarginEstimator from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators @@ -16,7 +16,7 @@ class TestFullEstimators(unittest.TestCase): super().setUp() self.spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1) self.max_stable_models = TestMaxStableEstimators.load_max_stable_models() - self.margin_models = TestMarginEstimators.load_margin_models(spatial_coordinates=self.spatial_coordinates) + self.margin_models = TestSmoothMarginEstimator.load_margin_models(spatial_coordinates=self.spatial_coordinates) def test_full_estimators(self): for margin_model, max_stable_model in product(self.margin_models, self.max_stable_models): diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py index 480ebba2df2b33ad4a51c4d51fe9b073946b5d11..9dcea3dcaef2861970b3cf0e4244f8bc0f70b94c 100644 --- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py +++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py @@ -1,15 +1,17 @@ import unittest from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel -from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel +from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \ + LinearShapeAxis0MarginModel from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator +from extreme_estimator.return_level_plot.spatial_2D_plot import Spatial2DPlot from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1 -class TestMarginEstimators(unittest.TestCase): +class TestSmoothMarginEstimator(unittest.TestCase): DISPLAY = False - MARGIN_TYPES = [ConstantMarginModel] + MARGIN_TYPES = [ConstantMarginModel, LinearShapeAxis0MarginModel][1:] MARGIN_ESTIMATORS = [SmoothMarginEstimator] def setUp(self): @@ -25,14 +27,17 @@ class TestMarginEstimators(unittest.TestCase): for margin_model in self.margin_models: dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model, spatial_coordinates=self.spatial_coordinates) - - for estimator_class in self.MARGIN_ESTIMATORS: - estimator = estimator_class(dataset=dataset, margin_model=margin_model) - estimator.fit() - if self.DISPLAY: - print(type(margin_model)) - print(dataset.df_dataset.head()) - print(estimator.additional_information) + # Fit estimator + estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model) + estimator.fit() + # Map name to their margin functions + name_to_margin_function = { + 'Ground truth margin function': dataset.margin_model.margin_function_sample, + # 'Estimated margin function': estimator.margin_function_fitted, + } + # Spatial Plot + if self.DISPLAY: + Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot() self.assertTrue(True)