Commit 8650cc55 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[ROBUSTNESS PLOT] refactor and clean plots

parent 68311fbd
No related merge requests found
Showing with 132 additions and 63 deletions
+132 -63
import time
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
class AbstractEstimator(object):
DURATION = 'Duration'
MAE_ERROR = 'Mean Average Error'
def __init__(self) -> None:
def __init__(self, dataset: AbstractDataset):
self.dataset = dataset
self.additional_information = dict()
def fit(self):
......
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
class FullEstimatorInASingleStep(AbstractEstimator):
pass
class PointwiseAndThenUnitaryMsp(AbstractEstimator):
pass
class SmoothMarginalsThenUnitaryMsp(AbstractEstimator):
pass
class StochasticExpectationMaximization(AbstractEstimator):
pass
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
class PointWiseMarginEstimator(AbstractEstimator):
pass
class SmoothMarginEstimator(AbstractEstimator):
# with different type of marginals: cosntant, linear....
pass
......@@ -7,8 +7,7 @@ import numpy as np
class MaxStableEstimator(AbstractEstimator):
def __init__(self, dataset: AbstractDataset, max_stable_model: AbstractMaxStableModel):
super().__init__()
self.dataset = dataset
super().__init__(dataset=dataset)
self.max_stable_model = max_stable_model
# Fit parameters
self.max_stable_params_fitted = None
......
class DisplayItem(object):
def __init__(self, argument_name, default_value, dislay_name=None):
self.argument_name = argument_name
def __init__(self, name, default_value):
self.name = name
self.default_value = default_value
self.dislay_name = dislay_name if dislay_name is not None else self.argument_name
def display_name_from_value(self, value) -> str:
return ''
def values_from_kwargs(self, **kwargs):
values = kwargs.get(self.argument_name, [self.default_value])
values = kwargs.get(self.name, [self.default_value])
assert isinstance(values, list)
return values
def value_from_kwargs(self, **kwargs):
return kwargs.get(self.argument_name, self.default_value)
return kwargs.get(self.name, self.default_value)
def update_kwargs_value(self, value, **kwargs):
updated_kwargs = kwargs.copy()
updated_kwargs.update({self.argument_name: value})
updated_kwargs.update({self.name: value})
return updated_kwargs
\ No newline at end of file
from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel
from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.estimator.msp_estimator import MaxStableEstimator
from extreme_estimator.estimator.unitary_msp_estimator import MaxStableEstimator
from extreme_estimator.robustness_plot.multiple_plot import MultiplePlot
from extreme_estimator.robustness_plot.single_plot import SinglePlot
from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset
......@@ -10,11 +10,17 @@ from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import Ci
from extreme_estimator.robustness_plot.display_item import DisplayItem
class MaxStableDisplayItem(DisplayItem):
def display_name_from_value(self, value: AbstractMaxStableModel):
return value.cov_mod
class MspSpatial(object):
MaxStableModelItem = DisplayItem('max_stable_model', Smith)
MaxStableModelItem = MaxStableDisplayItem('max_stable_model', Smith)
SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates)
SpatialParamsItem = DisplayItem('spatial_params', {"r": 1})
NbStationItem = DisplayItem('nb_station', 50)
NbStationItem = DisplayItem('Number of stations', 50)
NbObservationItem = DisplayItem('nb_obs', 60)
def msp_spatial_ordinates(self, **kwargs_single_point) -> dict:
......@@ -43,6 +49,7 @@ class SingleMspSpatial(SinglePlot, MspSpatial):
class MultipleMspSpatial(MultiplePlot, MspSpatial):
def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
print('here')
return self.msp_spatial_ordinates(**kwargs_single_point)
......@@ -53,11 +60,11 @@ def single_spatial_robustness_alps():
plot_label_item=SingleMspSpatial.MaxStableModelItem)
# Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{
SingleMspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:],
SingleMspSpatial.NbObservationItem.argument_name: [10],
SingleMspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:],
SingleMspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates,
AlpsStationCoordinatesBetweenZeroAndOne][:],
SingleMspSpatial.NbStationItem.name: [10, 30, 50, 70, 86][:],
SingleMspSpatial.NbObservationItem.name: [10],
SingleMspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:],
SingleMspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinates,
AlpsStationCoordinatesBetweenZeroAndOne][:],
})
......@@ -65,15 +72,16 @@ def multiple_spatial_robustness_alps():
spatial_robustness = MultipleMspSpatial(
grid_column_item=MspSpatial.MaxStableModelItem,
plot_row_item=MspSpatial.NbStationItem,
plot_label_item=MspSpatial.SpatialCoordinateClassItem)
plot_label_item=MspSpatial.SpatialCoordinateClassItem,
nb_samples=10)
# Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{
SinglePlot.OrdinateItem.argument_name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
MspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:2],
MspSpatial.NbObservationItem.argument_name: 10,
MspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:],
MspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates,
AlpsStationCoordinatesBetweenZeroAndOne][:],
SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
MspSpatial.NbStationItem.name: [10, 20, 30, 50, 70, 86][:3],
MspSpatial.NbObservationItem.name: 10,
MspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:],
MspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinates,
AlpsStationCoordinatesBetweenZeroAndOne][:],
})
......
......@@ -8,30 +8,28 @@ class MultiplePlot(SinglePlot):
Each scalar, will be display on a grid row (to ease visual comparison)
"""
def __init__(self, grid_column_item, plot_row_item, plot_label_item):
def __init__(self, grid_column_item, plot_row_item, plot_label_item, nb_samples=1):
super().__init__(grid_row_item=self.OrdinateItem, grid_column_item=grid_column_item,
plot_row_item=plot_row_item, plot_label_item=plot_label_item)
plot_row_item=plot_row_item, plot_label_item=plot_label_item,
nb_samples=nb_samples)
self.kwargs_single_point_to_errors = {}
def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point):
print('here1')
# Compute hash
hash_from_kwargs_single_point = self.hash_from_kwargs_single_point(kwargs_single_point)
# Either compute the errors or Reload them from cached results
if hash_from_kwargs_single_point in self.kwargs_single_point_to_errors:
print('Load')
errors = self.kwargs_single_point_to_errors[hash_from_kwargs_single_point]
else:
errors = self.multiple_scalar_from_all_params(**kwargs_single_point)
errors = self.compute_value_from_kwargs_single_point(**kwargs_single_point)
self.kwargs_single_point_to_errors[hash_from_kwargs_single_point] = errors
assert isinstance(errors, dict)
# Get the item of interest
error = errors[self.OrdinateItem.value_from_kwargs(**kwargs_single_point)]
return error
return errors
def hash_from_kwargs_single_point(self, kwargs_single_point):
items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.argument_name]
items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.name]
ordered_dict_items_str = str(sorted(items_except_error, key=lambda x: x[0]))
print(ordered_dict_items_str)
hash_from_kwargs_single_point = hash(ordered_dict_items_str)
return hash_from_kwargs_single_point
def multiple_scalar_from_all_params(self, **kwargs_single_point) -> dict:
pass
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
......@@ -11,11 +12,12 @@ class SinglePlot(object):
COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan']
OrdinateItem = DisplayItem('ordinate', AbstractEstimator.MAE_ERROR)
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item):
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item, nb_samples=1):
self.grid_row_item = grid_row_item # type: DisplayItem
self.grid_column_item = grid_column_item # type: DisplayItem
self.plot_row_item = plot_row_item # type: DisplayItem
self.plot_label_item = plot_label_item # type: DisplayItem
self.nb_samples = nb_samples
def robustness_grid_plot(self, **kwargs):
# Extract Grid row and columns values
......@@ -26,42 +28,72 @@ class SinglePlot(object):
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i, (grid_row_value, grid_column_value) in enumerate(product(grid_row_values, grid_column_values), 1):
print('Grid plot: {}={} {}={}'.format(self.grid_row_item.dislay_name, grid_row_value,
self.grid_column_item.dislay_name, grid_column_value))
print('Grid plot: {}={} {}={}'.format(self.grid_row_item.name, grid_row_value,
self.grid_column_item.name, grid_column_value))
ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i)
# Adapt the kwargs for the single plot
kwargs_single_plot = kwargs.copy()
kwargs_single_plot.update({self.grid_row_item.argument_name: grid_row_value,
self.grid_column_item.argument_name: grid_column_value})
kwargs_single_plot[self.grid_row_item.name] = grid_row_value
kwargs_single_plot[self.grid_column_item.name] = grid_column_value
self.robustness_single_plot(ax, **kwargs_single_plot)
self.add_title(ax, grid_column_value, grid_row_value)
plt.show()
def robustness_single_plot(self, ax, **kwargs_single_plot):
plot_row_values = self.plot_row_item.values_from_kwargs(**kwargs_single_plot)
plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot)
ordinate_name = self.OrdinateItem.value_from_kwargs(**kwargs_single_plot)
for j, plot_label_value in enumerate(plot_label_values):
# Compute
plot_row_value_to_error = {}
# todo: do some parallzlization here
for plot_row_value in plot_row_values:
# Adapt the kwargs for the single value
kwargs_single_point = kwargs_single_plot.copy()
kwargs_single_point.update({self.plot_row_item.argument_name: plot_row_value,
self.plot_label_item.argument_name: plot_label_value})
# The kwargs should not contain list of values
for k, v in kwargs_single_point.items():
assert not isinstance(v, list), '"{}" argument is a list'.format(k)
# Compute ordinate values
ordinates = self.compute_value_from_kwargs_single_point(**kwargs_single_point)
# Extract the ordinate value of interest
ordinate_name = self.OrdinateItem.value_from_kwargs(**kwargs_single_point)
plot_row_value_to_error[plot_row_value] = ordinates[ordinate_name]
plot_column_values = [plot_row_value_to_error[plot_row_value] for plot_row_value in plot_row_values]
ax.plot(plot_row_values, plot_column_values, color=self.COLORS[j % len(self.COLORS)], label=str(j))
mean_values, std_values = self.compute_mean_and_std_ordinate_values(kwargs_single_plot, ordinate_name,
plot_label_value, plot_row_values)
ax.errorbar(plot_row_values, mean_values, std_values,
# linestyle='None', marker='^',
linewidth = 0.5,
color=self.COLORS[j % len(self.COLORS)],
label=self.plot_label_item.display_name_from_value(plot_label_value))
ax.legend()
ax.set_xlabel(self.plot_row_item.dislay_name)
ax.set_ylabel('Absolute error')
ax.set_title('Title (display all the other parameters)')
ax.set_xlabel(self.plot_row_item.name)
ax.set_ylabel(ordinate_name)
def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
def compute_mean_and_std_ordinate_values(self, kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values):
all_ordinate_values = []
for nb_sample in range(self.nb_samples):
# Important to add the nb_sample argument, to differentiate the different experiments
kwargs_single_plot['nb_sample'] = nb_sample
ordinate_values = self.compute_ordinate_values(kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values)
all_ordinate_values.append(ordinate_values)
all_ordinate_values = np.array(all_ordinate_values)
return np.mean(all_ordinate_values, axis=0), np.std(all_ordinate_values, axis=0)
def compute_ordinate_values(self, kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values):
# Compute
plot_row_value_to_ordinate_value = {}
# todo: do some parallzlization here
for plot_row_value in plot_row_values:
# Adapt the kwargs for the single value
kwargs_single_point = kwargs_single_plot.copy()
kwargs_single_point.update({self.plot_row_item.name: plot_row_value,
self.plot_label_item.name: plot_label_value})
# The kwargs should not contain list of values
for k, v in kwargs_single_point.items():
assert not isinstance(v, list), '"{}" argument is a list'.format(k)
# Compute ordinate values
ordinate_name_to_ordinate_value = self.cached_compute_value_from_kwargs_single_point(**kwargs_single_point)
print(ordinate_name, plot_row_value)
plot_row_value_to_ordinate_value[plot_row_value] = ordinate_name_to_ordinate_value[ordinate_name]
# Plot the figure
plot_ordinate_values = [plot_row_value_to_ordinate_value[plot_row_value] for plot_row_value in
plot_row_values]
return plot_ordinate_values
def compute_value_from_kwargs_single_point(self, **kwargs_single_point) -> dict:
pass
def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point) -> dict:
return self.compute_value_from_kwargs_single_point(**kwargs_single_point)
def add_title(self, ax, grid_column_value, grid_row_value):
title_str = self.grid_row_item.display_name_from_value(grid_row_value)
title_str += ' ' if len(title_str) > 0 else ''
title_str += self.grid_column_item.display_name_from_value(grid_column_value)
ax.set_title(title_str)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment