Commit 68311fbd authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[ROBUSTNESS PLOT] refactor and add class multiple_plot

parent 55abed0c
No related merge requests found
Showing with 164 additions and 53 deletions
+164 -53
...@@ -2,20 +2,26 @@ import time ...@@ -2,20 +2,26 @@ import time
class AbstractEstimator(object): class AbstractEstimator(object):
DURATION = 'Duration'
MAE_ERROR = 'Mean Average Error'
def __init__(self): def __init__(self) -> None:
self.fit_duration = None self.additional_information = dict()
def timed_fit(self): def fit(self):
ts = time.time() ts = time.time()
result = self.fit() self._fit()
te = time.time() te = time.time()
log_time = int((te - ts) * 1000) self.additional_information[self.DURATION] = int((te - ts) * 1000)
self.fit_duration = log_time
return result
def fit(self): def scalars(self, true_max_stable_params: dict):
error = self._error(true_max_stable_params)
return {**error, **self.additional_information}
# Methods to override in the child class
def _fit(self):
pass pass
def error(self, true_max_stable_params: dict): def _error(self, true_max_stable_params: dict):
pass pass
\ No newline at end of file
...@@ -5,19 +5,19 @@ import numpy as np ...@@ -5,19 +5,19 @@ import numpy as np
class MaxStableEstimator(AbstractEstimator): class MaxStableEstimator(AbstractEstimator):
MAE_ERROR = 'mae'
def __init__(self, dataset: AbstractDataset, max_stable_model: AbstractMaxStableModel): def __init__(self, dataset: AbstractDataset, max_stable_model: AbstractMaxStableModel):
super().__init__()
self.dataset = dataset self.dataset = dataset
self.max_stable_model = max_stable_model self.max_stable_model = max_stable_model
# Fit parameters # Fit parameters
self.max_stable_params_fitted = None self.max_stable_params_fitted = None
def fit(self): def _fit(self):
self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(maxima=self.dataset.maxima, self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(maxima=self.dataset.maxima,
coord=self.dataset.coord) coord=self.dataset.coord)
def error(self, true_max_stable_params: dict): def _error(self, true_max_stable_params: dict):
absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name]) absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name])
for param_name, param_true_value in true_max_stable_params.items()} for param_name, param_true_value in true_max_stable_params.items()}
mean_absolute_error = np.mean(np.array(list(absolute_errors.values()))) mean_absolute_error = np.mean(np.array(list(absolute_errors.values())))
......
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
class DisplayItem(object): class DisplayItem(object):
...@@ -21,15 +17,4 @@ class DisplayItem(object): ...@@ -21,15 +17,4 @@ class DisplayItem(object):
def update_kwargs_value(self, value, **kwargs): def update_kwargs_value(self, value, **kwargs):
updated_kwargs = kwargs.copy() updated_kwargs = kwargs.copy()
updated_kwargs.update({self.argument_name: value}) updated_kwargs.update({self.argument_name: value})
return updated_kwargs return updated_kwargs
\ No newline at end of file
class AbstractPlot(object):
COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan']
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item):
self.grid_row_item = grid_row_item # type: DisplayItem
self.grid_column_item = grid_column_item # type: DisplayItem
self.plot_row_item = plot_row_item # type: DisplayItem
self.plot_label_item = plot_label_item # type: DisplayItem
from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel
from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.estimator.msp_estimator import MaxStableEstimator from extreme_estimator.estimator.msp_estimator import MaxStableEstimator
from extreme_estimator.robustness_plot.abstract_robustness_plot import DisplayItem from extreme_estimator.robustness_plot.multiple_plot import MultiplePlot
from extreme_estimator.robustness_plot.single_scalar_plot import SingleScalarPlot from extreme_estimator.robustness_plot.single_plot import SinglePlot
from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset
from spatio_temporal_dataset.spatial_coordinates.alps_station_coordinates import AlpsStationCoordinates, \ from spatio_temporal_dataset.spatial_coordinates.alps_station_coordinates import AlpsStationCoordinatesBetweenZeroAndOne
AlpsStationCoordinatesBetweenZeroAndOne
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates
from extreme_estimator.robustness_plot.display_item import DisplayItem
class MspSpatial(SingleScalarPlot): class MspSpatial(object):
MaxStableModelItem = DisplayItem('max_stable_model', Smith) MaxStableModelItem = DisplayItem('max_stable_model', Smith)
SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates) SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates)
SpatialParamsItem = DisplayItem('spatial_params', {"r": 1}) SpatialParamsItem = DisplayItem('spatial_params', {"r": 1})
NbStationItem = DisplayItem('nb_station', 50) NbStationItem = DisplayItem('nb_station', 50)
NbObservationItem = DisplayItem('nb_obs', 60) NbObservationItem = DisplayItem('nb_obs', 60)
def single_scalar_from_all_params(self, **kwargs_single_point) -> float: def msp_spatial_ordinates(self, **kwargs_single_point) -> dict:
# Get the argument from kwargs # Get the argument from kwargs
max_stable_model = self.MaxStableModelItem.value_from_kwargs(**kwargs_single_point) # type: AbstractMaxStableModel max_stable_model = self.MaxStableModelItem.value_from_kwargs(
**kwargs_single_point) # type: AbstractMaxStableModel
spatial_coordinate_class = self.SpatialCoordinateClassItem.value_from_kwargs(**kwargs_single_point) spatial_coordinate_class = self.SpatialCoordinateClassItem.value_from_kwargs(**kwargs_single_point)
nb_station = self.NbStationItem.value_from_kwargs(**kwargs_single_point) nb_station = self.NbStationItem.value_from_kwargs(**kwargs_single_point)
spatial_params = self.SpatialParamsItem.value_from_kwargs(**kwargs_single_point) spatial_params = self.SpatialParamsItem.value_from_kwargs(**kwargs_single_point)
...@@ -28,25 +30,53 @@ class MspSpatial(SingleScalarPlot): ...@@ -28,25 +30,53 @@ class MspSpatial(SingleScalarPlot):
dataset = SimulatedDataset.from_max_stable_sampling(nb_obs=nb_obs, max_stable_model=max_stable_model, dataset = SimulatedDataset.from_max_stable_sampling(nb_obs=nb_obs, max_stable_model=max_stable_model,
spatial_coordinates=spatial_coordinate) spatial_coordinates=spatial_coordinate)
estimator = MaxStableEstimator(dataset, max_stable_model) estimator = MaxStableEstimator(dataset, max_stable_model)
estimator.timed_fit() estimator.fit()
errors = estimator.error(max_stable_model.params_sample) return estimator.scalars(max_stable_model.params_sample)
mae_error = errors[MaxStableEstimator.MAE_ERROR]
return mae_error
def spatial_robustness_alps(): class SingleMspSpatial(SinglePlot, MspSpatial):
spatial_robustness = MspSpatial(grid_row_item=MspSpatial.NbObservationItem,
grid_column_item=MspSpatial.SpatialCoordinateClassItem, def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
plot_row_item=MspSpatial.NbStationItem, return self.msp_spatial_ordinates(**kwargs_single_point)
plot_label_item=MspSpatial.MaxStableModelItem)
class MultipleMspSpatial(MultiplePlot, MspSpatial):
def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
return self.msp_spatial_ordinates(**kwargs_single_point)
def single_spatial_robustness_alps():
spatial_robustness = SingleMspSpatial(grid_row_item=SingleMspSpatial.NbObservationItem,
grid_column_item=SingleMspSpatial.SpatialCoordinateClassItem,
plot_row_item=SingleMspSpatial.NbStationItem,
plot_label_item=SingleMspSpatial.MaxStableModelItem)
# Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{
SingleMspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:],
SingleMspSpatial.NbObservationItem.argument_name: [10],
SingleMspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:],
SingleMspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates,
AlpsStationCoordinatesBetweenZeroAndOne][:],
})
def multiple_spatial_robustness_alps():
spatial_robustness = MultipleMspSpatial(
grid_column_item=MspSpatial.MaxStableModelItem,
plot_row_item=MspSpatial.NbStationItem,
plot_label_item=MspSpatial.SpatialCoordinateClassItem)
# Put only the parameter that will vary # Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{ spatial_robustness.robustness_grid_plot(**{
MspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:], SinglePlot.OrdinateItem.argument_name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
MspSpatial.NbObservationItem.argument_name: [10], MspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:2],
MspSpatial.NbObservationItem.argument_name: 10,
MspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:], MspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:],
MspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates, AlpsStationCoordinatesBetweenZeroAndOne][:], MspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates,
AlpsStationCoordinatesBetweenZeroAndOne][:],
}) })
if __name__ == '__main__': if __name__ == '__main__':
spatial_robustness_alps() # single_spatial_robustness_alps()
multiple_spatial_robustness_alps()
from extreme_estimator.robustness_plot.single_plot import SinglePlot
class MultiplePlot(SinglePlot):
"""
In a Multiple Scalar plot, for each
Each scalar, will be display on a grid row (to ease visual comparison)
"""
def __init__(self, grid_column_item, plot_row_item, plot_label_item):
super().__init__(grid_row_item=self.OrdinateItem, grid_column_item=grid_column_item,
plot_row_item=plot_row_item, plot_label_item=plot_label_item)
self.kwargs_single_point_to_errors = {}
def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
# Compute hash
hash_from_kwargs_single_point = self.hash_from_kwargs_single_point(kwargs_single_point)
# Either compute the errors or Reload them from cached results
if hash_from_kwargs_single_point in self.kwargs_single_point_to_errors:
errors = self.kwargs_single_point_to_errors[hash_from_kwargs_single_point]
else:
errors = self.multiple_scalar_from_all_params(**kwargs_single_point)
self.kwargs_single_point_to_errors[hash_from_kwargs_single_point] = errors
assert isinstance(errors, dict)
# Get the item of interest
error = errors[self.OrdinateItem.value_from_kwargs(**kwargs_single_point)]
return error
def hash_from_kwargs_single_point(self, kwargs_single_point):
items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.argument_name]
ordered_dict_items_str = str(sorted(items_except_error, key=lambda x: x[0]))
hash_from_kwargs_single_point = hash(ordered_dict_items_str)
return hash_from_kwargs_single_point
def multiple_scalar_from_all_params(self, **kwargs_single_point) -> dict:
pass
from extreme_estimator.robustness_plot.abstract_robustness_plot import AbstractPlot
class MultipleScalarPlot(AbstractPlot):
"""
In a Multiple Scalar plot, for each
Each scalar, will be display on a grid row (to ease visual comparison)
"""
def __init__(self, grid_column_item, plot_row_item, plot_label_item):
super().__init__(grid_row_item=None, grid_column_item=grid_column_item,
plot_row_item=plot_row_item, plot_label_item=plot_label_item)
from extreme_estimator.robustness_plot.abstract_robustness_plot import AbstractPlot
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from itertools import product from itertools import product
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.robustness_plot.display_item import DisplayItem
class SingleScalarPlot(AbstractPlot): plt.style.use('seaborn-white')
"""
For a single scalar plot, for the combination of all the parameters of interest,
then the function
"""
def single_scalar_from_all_params(self, **kwargs_single_point) -> float:
pass class SinglePlot(object):
COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan']
OrdinateItem = DisplayItem('ordinate', AbstractEstimator.MAE_ERROR)
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item):
self.grid_row_item = grid_row_item # type: DisplayItem
self.grid_column_item = grid_column_item # type: DisplayItem
self.plot_row_item = plot_row_item # type: DisplayItem
self.plot_label_item = plot_label_item # type: DisplayItem
def robustness_grid_plot(self, **kwargs): def robustness_grid_plot(self, **kwargs):
# Extract Grid row and columns values # Extract Grid row and columns values
...@@ -37,16 +42,26 @@ class SingleScalarPlot(AbstractPlot): ...@@ -37,16 +42,26 @@ class SingleScalarPlot(AbstractPlot):
for j, plot_label_value in enumerate(plot_label_values): for j, plot_label_value in enumerate(plot_label_values):
# Compute # Compute
plot_row_value_to_error = {} plot_row_value_to_error = {}
# todo: do some parallzlization here (do the parallelization in the Asbtract class if possible) # todo: do some parallzlization here
for plot_row_value in plot_row_values: for plot_row_value in plot_row_values:
# Adapt the kwargs for the single value # Adapt the kwargs for the single value
kwargs_single_point = kwargs_single_plot.copy() kwargs_single_point = kwargs_single_plot.copy()
kwargs_single_point.update({self.plot_row_item.argument_name: plot_row_value, kwargs_single_point.update({self.plot_row_item.argument_name: plot_row_value,
self.plot_label_item.argument_name: plot_label_value}) self.plot_label_item.argument_name: plot_label_value})
plot_row_value_to_error[plot_row_value] = self.single_scalar_from_all_params(**kwargs_single_point) # The kwargs should not contain list of values
for k, v in kwargs_single_point.items():
assert not isinstance(v, list), '"{}" argument is a list'.format(k)
# Compute ordinate values
ordinates = self.compute_value_from_kwargs_single_point(**kwargs_single_point)
# Extract the ordinate value of interest
ordinate_name = self.OrdinateItem.value_from_kwargs(**kwargs_single_point)
plot_row_value_to_error[plot_row_value] = ordinates[ordinate_name]
plot_column_values = [plot_row_value_to_error[plot_row_value] for plot_row_value in plot_row_values] plot_column_values = [plot_row_value_to_error[plot_row_value] for plot_row_value in plot_row_values]
ax.plot(plot_row_values, plot_column_values, color=self.COLORS[j % len(self.COLORS)], label=str(j)) ax.plot(plot_row_values, plot_column_values, color=self.COLORS[j % len(self.COLORS)], label=str(j))
ax.legend() ax.legend()
ax.set_xlabel(self.plot_row_item.dislay_name) ax.set_xlabel(self.plot_row_item.dislay_name)
ax.set_ylabel('Absolute error') ax.set_ylabel('Absolute error')
ax.set_title('Title (display all the other parameters)') ax.set_title('Title (display all the other parameters)')
def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment