Commit 30f2543f authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[RETURN LEVEL PLOT] add return level plot and modify test_margin_estimators accordingly

parent 05a80ecc
No related merge requests found
Showing with 70 additions and 25 deletions
+70 -25
from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
AbstractMarginFunction
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
......@@ -11,8 +14,8 @@ class AbstractMarginEstimator(AbstractEstimator):
self._margin_function_fitted = None
@property
def margin_function_fitted(self):
assert self._margin_function_fitted is not None, 'Error: estimator has not been not fitted yet'
def margin_function_fitted(self) -> AbstractMarginFunction:
assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted'
return self._margin_function_fitted
......@@ -23,9 +26,9 @@ class PointWiseMarginEstimator(AbstractMarginEstimator):
class SmoothMarginEstimator(AbstractMarginEstimator):
"""# with different type of marginals: cosntant, linear...."""
def __init__(self, dataset: AbstractDataset, margin_model: AbstractMarginModel):
def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel):
super().__init__(dataset)
assert isinstance(margin_model, AbstractMarginModel)
assert isinstance(margin_model, LinearMarginModel)
self.margin_model = margin_model
def _fit(self):
......
......@@ -18,16 +18,23 @@ class AbstractMarginFunction(object):
def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
pass
def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, show=False):
def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, ax=None, show=False):
x = self.spatial_coordinates.x_coordinates
y = self.spatial_coordinates.y_coordinates
grid = self.get_grid_2D(x, y)
gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
if ax is None:
ax = plt.gca()
imshow_method = ax.imshow
imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
interpolation='nearest', cmap=cm.gist_rainbow)
if show:
plt.show()
def get_grid_2D(self, x, y):
resolution = 100
grid = np.zeros([resolution, resolution, 3])
for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)):
grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array()
gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
plt.imshow(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
interpolation='nearest', cmap=cm.gist_rainbow)
if show:
plt.show()
return grid
......@@ -33,8 +33,8 @@ class LinearShapeAxis0MarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
super().load_margin_functions({GevParams.GEV_SHAPE: 0})
# def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
# return self.margin_function_start_fit
def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
return self.margin_function_start_fit
if __name__ == '__main__':
......
from itertools import product
from typing import List, Dict
import matplotlib.pyplot as plt
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
AbstractMarginFunction
from extreme_estimator.gev_params import GevParams
plt.style.use('seaborn-white')
class Spatial2DPlot(object):
def __init__(self, name_to_margin_function: Dict[str, AbstractMarginFunction]):
self.name_to_margin_function = name_to_margin_function # type: Dict[str, AbstractMarginFunction]
self.grid_columns = GevParams.GEV_PARAM_NAMES
def plot(self):
nb_grid_rows, nb_grid_columns = len(self.name_to_margin_function), len(self.grid_columns)
fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row')
fig.subplots_adjust(hspace=0.4, wspace=0.4)
margin_function: AbstractMarginFunction
for i, (name, margin_function) in enumerate(self.name_to_margin_function.items()):
for j, param_name in enumerate(self.grid_columns):
ax = axes[i, j] if nb_grid_rows > 1 else axes[j]
margin_function.visualize_2D(gev_param_name=param_name, ax=ax)
ax.set_title("{} for {}".format(param_name, name))
fig.suptitle('Spatial2DPlot')
plt.show()
......@@ -4,7 +4,7 @@ from itertools import product
from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp
from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestMarginEstimators
from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestSmoothMarginEstimator
from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators
......@@ -16,7 +16,7 @@ class TestFullEstimators(unittest.TestCase):
super().setUp()
self.spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
self.max_stable_models = TestMaxStableEstimators.load_max_stable_models()
self.margin_models = TestMarginEstimators.load_margin_models(spatial_coordinates=self.spatial_coordinates)
self.margin_models = TestSmoothMarginEstimator.load_margin_models(spatial_coordinates=self.spatial_coordinates)
def test_full_estimators(self):
for margin_model, max_stable_model in product(self.margin_models, self.max_stable_models):
......
import unittest
from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \
LinearShapeAxis0MarginModel
from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
from extreme_estimator.return_level_plot.spatial_2D_plot import Spatial2DPlot
from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
class TestMarginEstimators(unittest.TestCase):
class TestSmoothMarginEstimator(unittest.TestCase):
DISPLAY = False
MARGIN_TYPES = [ConstantMarginModel]
MARGIN_TYPES = [ConstantMarginModel, LinearShapeAxis0MarginModel][1:]
MARGIN_ESTIMATORS = [SmoothMarginEstimator]
def setUp(self):
......@@ -25,14 +27,17 @@ class TestMarginEstimators(unittest.TestCase):
for margin_model in self.margin_models:
dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model,
spatial_coordinates=self.spatial_coordinates)
for estimator_class in self.MARGIN_ESTIMATORS:
estimator = estimator_class(dataset=dataset, margin_model=margin_model)
estimator.fit()
if self.DISPLAY:
print(type(margin_model))
print(dataset.df_dataset.head())
print(estimator.additional_information)
# Fit estimator
estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model)
estimator.fit()
# Map name to their margin functions
name_to_margin_function = {
'Ground truth margin function': dataset.margin_model.margin_function_sample,
# 'Estimated margin function': estimator.margin_function_fitted,
}
# Spatial Plot
if self.DISPLAY:
Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot()
self.assertTrue(True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment