diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py index 49a323b94381685897d4416aa982a073e5aa4354..32027aaca4ad4590af6b700cdd19f263b7f89f4a 100644 --- a/extreme_estimator/estimator/abstract_estimator.py +++ b/extreme_estimator/estimator/abstract_estimator.py @@ -1,11 +1,14 @@ import time +from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset + class AbstractEstimator(object): DURATION = 'Duration' MAE_ERROR = 'Mean Average Error' - def __init__(self) -> None: + def __init__(self, dataset: AbstractDataset): + self.dataset = dataset self.additional_information = dict() def fit(self): diff --git a/extreme_estimator/estimator/full_msp_estimator.py b/extreme_estimator/estimator/full_msp_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..d62f9acd52aff5f0b7d5813c0a2dfb60fe46494c --- /dev/null +++ b/extreme_estimator/estimator/full_msp_estimator.py @@ -0,0 +1,17 @@ +from extreme_estimator.estimator.abstract_estimator import AbstractEstimator + + +class FullEstimatorInASingleStep(AbstractEstimator): + pass + + +class PointwiseAndThenUnitaryMsp(AbstractEstimator): + pass + + +class SmoothMarginalsThenUnitaryMsp(AbstractEstimator): + pass + + +class StochasticExpectationMaximization(AbstractEstimator): + pass diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..00aaad385385a8d3df58417283e48f8a0b4b69e1 --- /dev/null +++ b/extreme_estimator/estimator/margin_estimator.py @@ -0,0 +1,10 @@ +from extreme_estimator.estimator.abstract_estimator import AbstractEstimator + + +class PointWiseMarginEstimator(AbstractEstimator): + pass + + +class SmoothMarginEstimator(AbstractEstimator): + # with different type of marginals: cosntant, linear.... + pass diff --git a/extreme_estimator/estimator/msp_estimator.py b/extreme_estimator/estimator/unitary_msp_estimator.py similarity index 95% rename from extreme_estimator/estimator/msp_estimator.py rename to extreme_estimator/estimator/unitary_msp_estimator.py index 35661c6dc5d6940f45d7e31319a9ed55c5c1c13e..b66982efbf61b851f8f791f901c27ae09bbfba67 100644 --- a/extreme_estimator/estimator/msp_estimator.py +++ b/extreme_estimator/estimator/unitary_msp_estimator.py @@ -7,8 +7,7 @@ import numpy as np class MaxStableEstimator(AbstractEstimator): def __init__(self, dataset: AbstractDataset, max_stable_model: AbstractMaxStableModel): - super().__init__() - self.dataset = dataset + super().__init__(dataset=dataset) self.max_stable_model = max_stable_model # Fit parameters self.max_stable_params_fitted = None diff --git a/extreme_estimator/robustness_plot/display_item.py b/extreme_estimator/robustness_plot/display_item.py index 4dbc2c010a8c61b2e295c96b74b9ac659a8e5efd..48fda88de70a51247914940037be22ceed638fa4 100644 --- a/extreme_estimator/robustness_plot/display_item.py +++ b/extreme_estimator/robustness_plot/display_item.py @@ -1,20 +1,22 @@ class DisplayItem(object): - def __init__(self, argument_name, default_value, dislay_name=None): - self.argument_name = argument_name + def __init__(self, name, default_value): + self.name = name self.default_value = default_value - self.dislay_name = dislay_name if dislay_name is not None else self.argument_name + + def display_name_from_value(self, value) -> str: + return '' def values_from_kwargs(self, **kwargs): - values = kwargs.get(self.argument_name, [self.default_value]) + values = kwargs.get(self.name, [self.default_value]) assert isinstance(values, list) return values def value_from_kwargs(self, **kwargs): - return kwargs.get(self.argument_name, self.default_value) + return kwargs.get(self.name, self.default_value) def update_kwargs_value(self, value, **kwargs): updated_kwargs = kwargs.copy() - updated_kwargs.update({self.argument_name: value}) + updated_kwargs.update({self.name: value}) return updated_kwargs \ No newline at end of file diff --git a/extreme_estimator/robustness_plot/msp_robustness.py b/extreme_estimator/robustness_plot/msp_robustness.py index 7c64f199505a22a7727c16a540bcb130b39e1b42..10c892063311975548a3c0aecfce3e9e42c84e66 100644 --- a/extreme_estimator/robustness_plot/msp_robustness.py +++ b/extreme_estimator/robustness_plot/msp_robustness.py @@ -1,7 +1,7 @@ from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick from extreme_estimator.estimator.abstract_estimator import AbstractEstimator -from extreme_estimator.estimator.msp_estimator import MaxStableEstimator +from extreme_estimator.estimator.unitary_msp_estimator import MaxStableEstimator from extreme_estimator.robustness_plot.multiple_plot import MultiplePlot from extreme_estimator.robustness_plot.single_plot import SinglePlot from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset @@ -10,11 +10,17 @@ from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import Ci from extreme_estimator.robustness_plot.display_item import DisplayItem +class MaxStableDisplayItem(DisplayItem): + + def display_name_from_value(self, value: AbstractMaxStableModel): + return value.cov_mod + + class MspSpatial(object): - MaxStableModelItem = DisplayItem('max_stable_model', Smith) + MaxStableModelItem = MaxStableDisplayItem('max_stable_model', Smith) SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates) SpatialParamsItem = DisplayItem('spatial_params', {"r": 1}) - NbStationItem = DisplayItem('nb_station', 50) + NbStationItem = DisplayItem('Number of stations', 50) NbObservationItem = DisplayItem('nb_obs', 60) def msp_spatial_ordinates(self, **kwargs_single_point) -> dict: @@ -43,6 +49,7 @@ class SingleMspSpatial(SinglePlot, MspSpatial): class MultipleMspSpatial(MultiplePlot, MspSpatial): def compute_value_from_kwargs_single_point(self, **kwargs_single_point): + print('here') return self.msp_spatial_ordinates(**kwargs_single_point) @@ -53,11 +60,11 @@ def single_spatial_robustness_alps(): plot_label_item=SingleMspSpatial.MaxStableModelItem) # Put only the parameter that will vary spatial_robustness.robustness_grid_plot(**{ - SingleMspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:], - SingleMspSpatial.NbObservationItem.argument_name: [10], - SingleMspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:], - SingleMspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates, - AlpsStationCoordinatesBetweenZeroAndOne][:], + SingleMspSpatial.NbStationItem.name: [10, 30, 50, 70, 86][:], + SingleMspSpatial.NbObservationItem.name: [10], + SingleMspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:], + SingleMspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinates, + AlpsStationCoordinatesBetweenZeroAndOne][:], }) @@ -65,15 +72,16 @@ def multiple_spatial_robustness_alps(): spatial_robustness = MultipleMspSpatial( grid_column_item=MspSpatial.MaxStableModelItem, plot_row_item=MspSpatial.NbStationItem, - plot_label_item=MspSpatial.SpatialCoordinateClassItem) + plot_label_item=MspSpatial.SpatialCoordinateClassItem, + nb_samples=10) # Put only the parameter that will vary spatial_robustness.robustness_grid_plot(**{ - SinglePlot.OrdinateItem.argument_name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION], - MspSpatial.NbStationItem.argument_name: [10, 30, 50, 70, 86][:2], - MspSpatial.NbObservationItem.argument_name: 10, - MspSpatial.MaxStableModelItem.argument_name: [Smith(), BrownResnick()][:], - MspSpatial.SpatialCoordinateClassItem.argument_name: [CircleCoordinates, - AlpsStationCoordinatesBetweenZeroAndOne][:], + SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION], + MspSpatial.NbStationItem.name: [10, 20, 30, 50, 70, 86][:3], + MspSpatial.NbObservationItem.name: 10, + MspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:], + MspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinates, + AlpsStationCoordinatesBetweenZeroAndOne][:], }) diff --git a/extreme_estimator/robustness_plot/multiple_plot.py b/extreme_estimator/robustness_plot/multiple_plot.py index 5383ce4e5eb34bc9f54c177abadd38230928749b..facb521ab5d4524c6d711ba2017a050f022eeaf1 100644 --- a/extreme_estimator/robustness_plot/multiple_plot.py +++ b/extreme_estimator/robustness_plot/multiple_plot.py @@ -8,30 +8,28 @@ class MultiplePlot(SinglePlot): Each scalar, will be display on a grid row (to ease visual comparison) """ - def __init__(self, grid_column_item, plot_row_item, plot_label_item): + def __init__(self, grid_column_item, plot_row_item, plot_label_item, nb_samples=1): super().__init__(grid_row_item=self.OrdinateItem, grid_column_item=grid_column_item, - plot_row_item=plot_row_item, plot_label_item=plot_label_item) + plot_row_item=plot_row_item, plot_label_item=plot_label_item, + nb_samples=nb_samples) self.kwargs_single_point_to_errors = {} - def compute_value_from_kwargs_single_point(self, **kwargs_single_point): + def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point): + print('here1') # Compute hash hash_from_kwargs_single_point = self.hash_from_kwargs_single_point(kwargs_single_point) # Either compute the errors or Reload them from cached results if hash_from_kwargs_single_point in self.kwargs_single_point_to_errors: + print('Load') errors = self.kwargs_single_point_to_errors[hash_from_kwargs_single_point] else: - errors = self.multiple_scalar_from_all_params(**kwargs_single_point) + errors = self.compute_value_from_kwargs_single_point(**kwargs_single_point) self.kwargs_single_point_to_errors[hash_from_kwargs_single_point] = errors - assert isinstance(errors, dict) - # Get the item of interest - error = errors[self.OrdinateItem.value_from_kwargs(**kwargs_single_point)] - return error + return errors def hash_from_kwargs_single_point(self, kwargs_single_point): - items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.argument_name] + items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.name] ordered_dict_items_str = str(sorted(items_except_error, key=lambda x: x[0])) + print(ordered_dict_items_str) hash_from_kwargs_single_point = hash(ordered_dict_items_str) return hash_from_kwargs_single_point - - def multiple_scalar_from_all_params(self, **kwargs_single_point) -> dict: - pass diff --git a/extreme_estimator/robustness_plot/single_plot.py b/extreme_estimator/robustness_plot/single_plot.py index c93dcabc527ec4fff8e0e780868ced06168e0e42..4daf7de4aed2d8c142e852486d09115aed4c7dc0 100644 --- a/extreme_estimator/robustness_plot/single_plot.py +++ b/extreme_estimator/robustness_plot/single_plot.py @@ -1,4 +1,5 @@ import matplotlib.pyplot as plt +import numpy as np from itertools import product from extreme_estimator.estimator.abstract_estimator import AbstractEstimator @@ -11,11 +12,12 @@ class SinglePlot(object): COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan'] OrdinateItem = DisplayItem('ordinate', AbstractEstimator.MAE_ERROR) - def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item): + def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item, nb_samples=1): self.grid_row_item = grid_row_item # type: DisplayItem self.grid_column_item = grid_column_item # type: DisplayItem self.plot_row_item = plot_row_item # type: DisplayItem self.plot_label_item = plot_label_item # type: DisplayItem + self.nb_samples = nb_samples def robustness_grid_plot(self, **kwargs): # Extract Grid row and columns values @@ -26,42 +28,72 @@ class SinglePlot(object): fig = plt.figure() fig.subplots_adjust(hspace=0.4, wspace=0.4) for i, (grid_row_value, grid_column_value) in enumerate(product(grid_row_values, grid_column_values), 1): - print('Grid plot: {}={} {}={}'.format(self.grid_row_item.dislay_name, grid_row_value, - self.grid_column_item.dislay_name, grid_column_value)) + print('Grid plot: {}={} {}={}'.format(self.grid_row_item.name, grid_row_value, + self.grid_column_item.name, grid_column_value)) ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i) # Adapt the kwargs for the single plot kwargs_single_plot = kwargs.copy() - kwargs_single_plot.update({self.grid_row_item.argument_name: grid_row_value, - self.grid_column_item.argument_name: grid_column_value}) + kwargs_single_plot[self.grid_row_item.name] = grid_row_value + kwargs_single_plot[self.grid_column_item.name] = grid_column_value self.robustness_single_plot(ax, **kwargs_single_plot) + self.add_title(ax, grid_column_value, grid_row_value) plt.show() def robustness_single_plot(self, ax, **kwargs_single_plot): plot_row_values = self.plot_row_item.values_from_kwargs(**kwargs_single_plot) plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot) + ordinate_name = self.OrdinateItem.value_from_kwargs(**kwargs_single_plot) for j, plot_label_value in enumerate(plot_label_values): - # Compute - plot_row_value_to_error = {} - # todo: do some parallzlization here - for plot_row_value in plot_row_values: - # Adapt the kwargs for the single value - kwargs_single_point = kwargs_single_plot.copy() - kwargs_single_point.update({self.plot_row_item.argument_name: plot_row_value, - self.plot_label_item.argument_name: plot_label_value}) - # The kwargs should not contain list of values - for k, v in kwargs_single_point.items(): - assert not isinstance(v, list), '"{}" argument is a list'.format(k) - # Compute ordinate values - ordinates = self.compute_value_from_kwargs_single_point(**kwargs_single_point) - # Extract the ordinate value of interest - ordinate_name = self.OrdinateItem.value_from_kwargs(**kwargs_single_point) - plot_row_value_to_error[plot_row_value] = ordinates[ordinate_name] - plot_column_values = [plot_row_value_to_error[plot_row_value] for plot_row_value in plot_row_values] - ax.plot(plot_row_values, plot_column_values, color=self.COLORS[j % len(self.COLORS)], label=str(j)) + mean_values, std_values = self.compute_mean_and_std_ordinate_values(kwargs_single_plot, ordinate_name, + plot_label_value, plot_row_values) + ax.errorbar(plot_row_values, mean_values, std_values, + # linestyle='None', marker='^', + linewidth = 0.5, + color=self.COLORS[j % len(self.COLORS)], + label=self.plot_label_item.display_name_from_value(plot_label_value)) ax.legend() - ax.set_xlabel(self.plot_row_item.dislay_name) - ax.set_ylabel('Absolute error') - ax.set_title('Title (display all the other parameters)') + ax.set_xlabel(self.plot_row_item.name) + ax.set_ylabel(ordinate_name) - def compute_value_from_kwargs_single_point(self, **kwargs_single_point): + def compute_mean_and_std_ordinate_values(self, kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values): + all_ordinate_values = [] + for nb_sample in range(self.nb_samples): + # Important to add the nb_sample argument, to differentiate the different experiments + kwargs_single_plot['nb_sample'] = nb_sample + ordinate_values = self.compute_ordinate_values(kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values) + all_ordinate_values.append(ordinate_values) + all_ordinate_values = np.array(all_ordinate_values) + return np.mean(all_ordinate_values, axis=0), np.std(all_ordinate_values, axis=0) + + def compute_ordinate_values(self, kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values): + # Compute + plot_row_value_to_ordinate_value = {} + # todo: do some parallzlization here + for plot_row_value in plot_row_values: + # Adapt the kwargs for the single value + kwargs_single_point = kwargs_single_plot.copy() + kwargs_single_point.update({self.plot_row_item.name: plot_row_value, + self.plot_label_item.name: plot_label_value}) + # The kwargs should not contain list of values + for k, v in kwargs_single_point.items(): + assert not isinstance(v, list), '"{}" argument is a list'.format(k) + # Compute ordinate values + ordinate_name_to_ordinate_value = self.cached_compute_value_from_kwargs_single_point(**kwargs_single_point) + print(ordinate_name, plot_row_value) + plot_row_value_to_ordinate_value[plot_row_value] = ordinate_name_to_ordinate_value[ordinate_name] + # Plot the figure + plot_ordinate_values = [plot_row_value_to_ordinate_value[plot_row_value] for plot_row_value in + plot_row_values] + return plot_ordinate_values + + def compute_value_from_kwargs_single_point(self, **kwargs_single_point) -> dict: pass + + def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point) -> dict: + return self.compute_value_from_kwargs_single_point(**kwargs_single_point) + + def add_title(self, ax, grid_column_value, grid_row_value): + title_str = self.grid_row_item.display_name_from_value(grid_row_value) + title_str += ' ' if len(title_str) > 0 else '' + title_str += self.grid_column_item.display_name_from_value(grid_column_value) + ax.set_title(title_str)