Commit 55abed0c authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[ROBUSTNESS PLOT] add class single_scalar_plot, and an example of use in spatial_robustness.py

parent a3e040da
No related merge requests found
Showing with 169 additions and 20 deletions
+169 -20
import time
class AbstractEstimator(object): class AbstractEstimator(object):
def __init__(self): def __init__(self):
pass self.fit_duration = None
def timed_fit(self):
ts = time.time()
result = self.fit()
te = time.time()
log_time = int((te - ts) * 1000)
self.fit_duration = log_time
return result
def fit(self): def fit(self):
pass pass
def error(self): def error(self, true_max_stable_params: dict):
pass pass
\ No newline at end of file
...@@ -5,18 +5,20 @@ import numpy as np ...@@ -5,18 +5,20 @@ import numpy as np
class MaxStableEstimator(AbstractEstimator): class MaxStableEstimator(AbstractEstimator):
MAE_ERROR = 'mae'
def __init__(self, dataset: AbstractDataset, max_stable_model: AbstractMaxStableModel): def __init__(self, dataset: AbstractDataset, max_stable_model: AbstractMaxStableModel):
self.dataset = dataset self.dataset = dataset
self.max_stable_model = max_stable_model self.max_stable_model = max_stable_model
# Fit parameters
self.max_stable_params_fitted = None self.max_stable_params_fitted = None
def fit(self): def fit(self):
self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(maxima=self.dataset.maxima, coord=self.dataset.coord) self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(maxima=self.dataset.maxima,
coord=self.dataset.coord)
def error(self, true_max_stable_params: dict): def error(self, true_max_stable_params: dict):
absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name]) absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name])
for param_name, param_true_value in true_max_stable_params.items()} for param_name, param_true_value in true_max_stable_params.items()}
mean_absolute_error = np.mean(np.array(list(absolute_errors.values()))) mean_absolute_error = np.mean(np.array(list(absolute_errors.values())))
# return {**absolute_errors, **{'mae': mean_absolute_error}} return {**absolute_errors, **{self.MAE_ERROR: mean_absolute_error}}
return mean_absolute_error
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
class DisplayItem(object):
def __init__(self, argument_name, default_value, dislay_name=None):
self.argument_name = argument_name
self.default_value = default_value
self.dislay_name = dislay_name if dislay_name is not None else self.argument_name
def values_from_kwargs(self, **kwargs):
values = kwargs.get(self.argument_name, [self.default_value])
assert isinstance(values, list)
return values
def value_from_kwargs(self, **kwargs):
return kwargs.get(self.argument_name, self.default_value)
def update_kwargs_value(self, value, **kwargs):
updated_kwargs = kwargs.copy()
updated_kwargs.update({self.argument_name: value})
return updated_kwargs
class AbstractPlot(object):
COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan']
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item):
self.grid_row_item = grid_row_item # type: DisplayItem
self.grid_column_item = grid_column_item # type: DisplayItem
self.plot_row_item = plot_row_item # type: DisplayItem
self.plot_label_item = plot_label_item # type: DisplayItem
from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel
from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick
from extreme_estimator.estimator.msp_estimator import MaxStableEstimator
from extreme_estimator.robustness_plot.abstract_robustness_plot import DisplayItem
from extreme_estimator.robustness_plot.single_scalar_plot import SingleScalarPlot
from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset
from spatio_temporal_dataset.spatial_coordinates.alps_station_coordinates import AlpsStationCoordinates, \
AlpsStationCoordinatesBetweenZeroAndOne
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates
class MspSpatial(SingleScalarPlot):
MaxStableModelItem = DisplayItem('max_stable_model', Smith)
SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates)
SpatialParamsItem = DisplayItem('spatial_params', {"r": 1})
NbStationItem = DisplayItem('nb_station', 50)
NbObservationItem = DisplayItem('nb_obs', 60)
def single_scalar_from_all_params(self, **kwargs_single_point) -> float:
# Get the argument from kwargs
max_stable_model = self.MaxStableModelItem.value_from_kwargs(**kwargs_single_point) # type: AbstractMaxStableModel
spatial_coordinate_class = self.SpatialCoordinateClassItem.value_from_kwargs(**kwargs_single_point)
nb_station = self.NbStationItem.value_from_kwargs(**kwargs_single_point)
spatial_params = self.SpatialParamsItem.value_from_kwargs(**kwargs_single_point)
nb_obs = self.NbObservationItem.value_from_kwargs(**kwargs_single_point)
# Run the estimation
spatial_coordinate = spatial_coordinate_class.from_nb_points(nb_points=nb_station, **spatial_params)
dataset = SimulatedDataset.from_max_stable_sampling(nb_obs=nb_obs, max_stable_model=max_stable_model,
spatial_coordinates=spatial_coordinate)
estimator = MaxStableEstimator(dataset, max_stable_model)
estimator.timed_fit()
errors = estimator.error(max_stable_model.params_sample)
mae_error = errors[MaxStableEstimator.MAE_ERROR]
return mae_error
def spatial_robustness_alps():
spatial_robustness = MspSpatial(grid_row_item=MspSpatial.NbObservationItem,
grid_column_item=MspSpatial.SpatialCoordinateClassItem,
plot_row_item=MspSpatial.NbStationItem,
plot_label_item=MspSpatial.MaxStableModelItem)
# Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{
MspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:],
MspSpatial.NbObservationItem.argument_name: [10],
MspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:],
MspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates, AlpsStationCoordinatesBetweenZeroAndOne][:],
})
if __name__ == '__main__':
spatial_robustness_alps()
from extreme_estimator.robustness_plot.abstract_robustness_plot import AbstractPlot
class MultipleScalarPlot(AbstractPlot):
"""
In a Multiple Scalar plot, for each
Each scalar, will be display on a grid row (to ease visual comparison)
"""
def __init__(self, grid_column_item, plot_row_item, plot_label_item):
super().__init__(grid_row_item=None, grid_column_item=grid_column_item,
plot_row_item=plot_row_item, plot_label_item=plot_label_item)
from typing import List from extreme_estimator.robustness_plot.abstract_robustness_plot import AbstractPlot
from extreme_estimator.estimator.msp_estimator import MaxStableEstimator
from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import GaussianMSP, AbstractMaxStableModel
from itertools import product
from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from itertools import product
plt.style.use('seaborn-white')
class DisplayItem(object):
def __init__(self, argument_name, default_value, dislay_name=None):
self.argument_name = argument_name
self.default_value = default_value
self.dislay_name = dislay_name if dislay_name is not None else self.argument_name
def values_from_kwargs(self, **kwargs):
return kwargs.get(self.argument_name, [self.default_value])
def value_from_kwargs(self, **kwargs):
return kwargs.get(self.argument_name, self.default_value)
MaxStableModelItem = DisplayItem('max_stable_model', GaussianMSP)
SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates)
SpatialParamsItem = DisplayItem('spatial_params', {"r": 1})
NbStationItem = DisplayItem('nb_station', None)
NbObservationItem = DisplayItem('nb_obs', 50)
class AbstractRobustnessPlot(object):
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item): class SingleScalarPlot(AbstractPlot):
self.grid_row_item = grid_row_item # type: DisplayItem """
self.grid_column_item = grid_column_item # type: DisplayItem For a single scalar plot, for the combination of all the parameters of interest,
self.plot_row_item = plot_row_item # type: DisplayItem then the function
self.plot_label_item = plot_label_item # type: DisplayItem """
self.estimation_error = self.estimation_error_max_stable_unitary_frechet def single_scalar_from_all_params(self, **kwargs_single_point) -> float:
pass
def robustness_grid_plot(self, **kwargs): def robustness_grid_plot(self, **kwargs):
# Extract Grid row and columns values
grid_row_values = self.grid_row_item.values_from_kwargs(**kwargs) grid_row_values = self.grid_row_item.values_from_kwargs(**kwargs)
grid_column_values = self.grid_column_item.values_from_kwargs(**kwargs) grid_column_values = self.grid_column_item.values_from_kwargs(**kwargs)
nb_grid_rows, nb_grid_columns = len(grid_row_values), len(grid_column_values) nb_grid_rows, nb_grid_columns = len(grid_row_values), len(grid_column_values)
# Start the overall plot
fig = plt.figure() fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4) fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i, (grid_row_value, grid_column_value) in enumerate(product(grid_row_values, grid_column_values), 1): for i, (grid_row_value, grid_column_value) in enumerate(product(grid_row_values, grid_column_values), 1):
print('Grid plot: {}={} {}={}'.format(self.grid_row_item.dislay_name, grid_row_value, print('Grid plot: {}={} {}={}'.format(self.grid_row_item.dislay_name, grid_row_value,
self.grid_column_item.dislay_name, grid_column_value)) self.grid_column_item.dislay_name, grid_column_value))
ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i) ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i)
# Adapt the kwargs for the single plot
kwargs_single_plot = kwargs.copy() kwargs_single_plot = kwargs.copy()
kwargs_single_plot.update({self.grid_row_item.argument_name: grid_row_value, kwargs_single_plot.update({self.grid_row_item.argument_name: grid_row_value,
self.grid_column_item.argument_name: grid_column_value}) self.grid_column_item.argument_name: grid_column_value})
...@@ -61,37 +34,19 @@ class AbstractRobustnessPlot(object): ...@@ -61,37 +34,19 @@ class AbstractRobustnessPlot(object):
def robustness_single_plot(self, ax, **kwargs_single_plot): def robustness_single_plot(self, ax, **kwargs_single_plot):
plot_row_values = self.plot_row_item.values_from_kwargs(**kwargs_single_plot) plot_row_values = self.plot_row_item.values_from_kwargs(**kwargs_single_plot)
plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot) plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot)
colors = ['blue', 'red', 'green', 'black']
assert isinstance(plot_label_values, list), plot_label_values
assert isinstance(plot_row_values, list), plot_row_values
for j, plot_label_value in enumerate(plot_label_values): for j, plot_label_value in enumerate(plot_label_values):
# Compute
plot_row_value_to_error = {} plot_row_value_to_error = {}
# todo: do some parallzlization here # todo: do some parallzlization here (do the parallelization in the Asbtract class if possible)
for plot_row_value in plot_row_values: for plot_row_value in plot_row_values:
# Adapt the kwargs for the single value
kwargs_single_point = kwargs_single_plot.copy() kwargs_single_point = kwargs_single_plot.copy()
kwargs_single_point.update({self.plot_row_item.argument_name: plot_row_value, kwargs_single_point.update({self.plot_row_item.argument_name: plot_row_value,
self.plot_label_item.argument_name: plot_label_value}) self.plot_label_item.argument_name: plot_label_value})
plot_row_value_to_error[plot_row_value] = self.estimation_error(**kwargs_single_point) plot_row_value_to_error[plot_row_value] = self.single_scalar_from_all_params(**kwargs_single_point)
plot_column_values = [plot_row_value_to_error[plot_row_value] for plot_row_value in plot_row_values] plot_column_values = [plot_row_value_to_error[plot_row_value] for plot_row_value in plot_row_values]
ax.plot(plot_row_values, plot_column_values, color=colors[j % len(colors)], label=str(j)) ax.plot(plot_row_values, plot_column_values, color=self.COLORS[j % len(self.COLORS)], label=str(j))
ax.legend() ax.legend()
ax.set_xlabel(self.plot_row_item.dislay_name) ax.set_xlabel(self.plot_row_item.dislay_name)
ax.set_ylabel('Absolute error') ax.set_ylabel('Absolute error')
ax.set_title('Title (display all the other parameters)') ax.set_title('Title (display all the other parameters)')
@staticmethod
def estimation_error_max_stable_unitary_frechet(**kwargs_single_points):
# Get the argument from kwargs
max_stable_model = MaxStableModelItem.value_from_kwargs(**kwargs_single_points)
spatial_coordinate_class = SpatialCoordinateClassItem.value_from_kwargs(**kwargs_single_points)
nb_station = NbStationItem.value_from_kwargs(**kwargs_single_points)
spatial_params = SpatialParamsItem.value_from_kwargs(**kwargs_single_points)
nb_obs = NbObservationItem.value_from_kwargs(**kwargs_single_points)
# Run the estimation
spatial_coordinate = spatial_coordinate_class.from_nb_points(nb_points=nb_station, **spatial_params)
dataset = SimulatedDataset.from_max_stable_sampling(nb_obs=nb_obs, max_stable_model=max_stable_model,
spatial_coordinates=spatial_coordinate)
estimator = MaxStableEstimator(dataset, max_stable_model)
estimator.fit()
errors = estimator.error(max_stable_model.params_sample)
return errors
from extreme_estimator.robustness_plot.abstract_robustness import DisplayItem, AbstractRobustnessPlot, \
SpatialCoordinateClassItem, NbObservationItem, NbStationItem, MaxStableModelItem, SpatialParamsItem
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates
spatial_robustness = AbstractRobustnessPlot(grid_row_item=SpatialCoordinateClassItem,
grid_column_item=NbObservationItem,
plot_row_item=NbStationItem,
plot_label_item=MaxStableModelItem)
# Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{
NbStationItem.argument_name: [10, 30, 50, 100],
MaxStableModelItem.argument_name: [GaussianMSP(), BrownResick()][:],
NbObservationItem.argument_name: [10, 50, 100]
})
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment