diff --git a/experiment/split/main_split.py b/experiment/split/main_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56b4420ed348ad51c62c1fcefc0677eacfb9dbe
--- /dev/null
+++ b/experiment/split/main_split.py
@@ -0,0 +1,62 @@
+from experiment.split.split_curve import SplitCurve, LocFunction
+from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel
+from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
+from extreme_estimator.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates
+
+from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
+from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer
+from spatio_temporal_dataset.slicer.spatio_temporal_slicer import SpatioTemporalSlicer
+
+
+def load_dataset():
+    """Build the simulated dataset used by the split experiment.
+
+    The coordinates come from ``LinSpaceCoordinates.from_nb_points``; they are
+    passed, with a ConstantMarginModel and a Smith max-stable model, to
+    ``FullSimulatedDataset.from_double_sampling``. Both calls receive the
+    same train split ratio.
+    """
+    nb_points = 50
+    nb_obs = 60
+    # Single source of truth for the ratio handed to the coordinates and to the dataset
+    train_split_ratio = 0.8
+    coordinates = LinSpaceCoordinates.from_nb_points(nb_points=nb_points, train_split_ratio=train_split_ratio)
+
+    # Sampling parameters handed to the ConstantMarginModel below
+    params_sample = {
+        (GevParams.GEV_LOC, 0): 10,
+        (GevParams.GEV_SHAPE, 0): 1.0,
+        (GevParams.GEV_SCALE, 0): 1.0,
+    }
+    margin_model = ConstantMarginModel(coordinates=coordinates, params_sample=params_sample)
+    max_stable_model = Smith()
+
+    return FullSimulatedDataset.from_double_sampling(nb_obs=nb_obs, margin_model=margin_model,
+                                                     coordinates=coordinates,
+                                                     max_stable_model=max_stable_model,
+                                                     train_split_ratio=train_split_ratio,
+                                                     slicer_class=SpatioTemporalSlicer)
+
+
+def full_estimator(dataset):
+    """Build, without fitting it, the estimator evaluated by the experiment.
+
+    A fresh ``Smith`` model and a ``ConstantMarginModel`` built on
+    ``dataset.coordinates`` are wrapped in a
+    ``FullEstimatorInASingleStepWithSmoothMargin``.
+    """
+    max_stable_model = Smith()
+    margin_model_for_estimator = ConstantMarginModel(dataset.coordinates)
+    # Returned directly: a local named like the function would shadow it
+    return FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model)
+
+
+if __name__ == '__main__':
+    dataset = load_dataset()
+    dataset.slicer.summary()
+    # Distinct name so the module-level full_estimator function is not overwritten
+    estimator = full_estimator(dataset)
+    curve = SplitCurve(dataset, estimator, margin_functions=[LocFunction()])
+    curve.visualize()
diff --git a/experiment/split/split_curve.py b/experiment/split/split_curve.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9bb6f70ff29a1a2090e7f477e67acc881d4cb6f
--- /dev/null
+++ b/experiment/split/split_curve.py
@@ -0,0 +1,97 @@
+import numpy as np
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from typing import Union, List
+
+from extreme_estimator.estimator.full_estimator import AbstractFullEstimator
+from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator
+from extreme_estimator.extreme_models.margin_model.margin_function.utils import error_dict_between_margin_functions
+from extreme_estimator.gev_params import GevParams
+from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
+from spatio_temporal_dataset.slicer.split import Split, ALL_SPLITS_EXCEPT_ALL
+
+
+class MarginFunction(object):
+    """Base class for objects mapping a GevParams to a single float."""
+
+    def margin_function(self, gev_param: GevParams) -> float:
+        """Return the quantity of interest computed from gev_param.
+
+        Subclasses must override this method.
+        """
+        # Failing loudly beats silently returning None where a float is promised
+        raise NotImplementedError
+
+
+class LocFunction(MarginFunction):
+    """Margin function that reads the location of a GevParams."""
+
+    def margin_function(self, gev_param: GevParams) -> float:
+        """Return gev_param.location."""
+        return gev_param.location
+
+
+class SplitCurve(object):
+    """Fit an estimator on a simulated dataset and plot fitted margins against the sampled ones."""
+
+    def __init__(self, dataset: FullSimulatedDataset, estimator: Union[AbstractFullEstimator, AbstractMarginEstimator],
+                 margin_functions: List[MarginFunction]):
+        """Fit the estimator and compute the error between sampled and fitted margin functions.
+
+        :param dataset: simulated dataset; it is read but never modified here
+        :param estimator: estimator to fit; ``fit()`` is called during construction
+        :param margin_functions: margin functions of interest, kept on the instance
+        """
+        # Dataset is already loaded and will not be modified
+        self.dataset = dataset
+        # Both split must be defined
+        assert not self.dataset.slicer.some_required_ind_are_not_defined, \
+            'some required split indices are not defined'
+        self.margin_function_sample = self.dataset.margin_model.margin_function_sample
+        # Previously accepted then dropped; kept so that later curves can use them
+        self.margin_functions = margin_functions
+
+        self.estimator = estimator
+        # Fit the estimator and get the margin_function
+        self.estimator.fit()
+        # todo: potentially I will do the fit several times, and retrieve the mean error
+        self.margin_function_fitted = self.estimator.margin_function_fitted
+
+        self.error_dict = error_dict_between_margin_functions(self.margin_function_sample, self.margin_function_fitted)
+        # todo: add quantile curve, additionally to the marginal curve
+
+    def visualize(self):
+        """Show one row of plots per GEV parameter: margin curve on the left, score on the right."""
+        # The grid height follows the parameter list instead of a hard-coded 3
+        gev_param_names = list(GevParams.GEV_PARAM_NAMES)
+        # squeeze=False keeps axes two-dimensional whatever the number of rows
+        fig, axes = plt.subplots(len(gev_param_names), 2, sharex='col', sharey='row', squeeze=False)
+        fig.subplots_adjust(hspace=0.4, wspace=0.4)
+        for i, gev_param_name in enumerate(gev_param_names):
+            self.margin_graph(axes[i, 0], gev_param_name)
+            self.score_graph(axes[i, 1], gev_param_name)
+        plt.show()
+
+    def margin_graph(self, ax, gev_param_name):
+        """Draw on ax the fitted curve of gev_param_name and the sampled train/test points."""
+        # Display the fitted curve
+        self.margin_function_fitted.visualize_single_param(gev_param_name, ax, show=False)
+        # Display train/test points
+        for split, marker in [(self.dataset.train_split, 'o'), (self.dataset.test_split, 'x')]:
+            self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker)
+            self.margin_function_sample.visualize_single_param(gev_param_name, ax, show=False)
+        ax.set_title(gev_param_name)
+
+    def score_graph(self, ax, gev_param_name):
+        """Draw on ax one kernel density estimate per split of the dataset.
+
+        todo: the plotted values are placeholders; self.error_dict[gev_param_name] is not displayed yet
+        """
+        # Loop invariants are prepared once rather than at every iteration
+        sns.set_style('whitegrid')
+        data = np.array([1.5] * 7 + [2.5] * 2 + [3.5] * 8 + [4.5] * 3 + [5.5] * 1 + [6.5] * 8)
+        for _ in self.dataset.splits:
+            # NOTE(review): recent seaborn releases deprecate bw in favour of bw_method - confirm the pinned version
+            sns.kdeplot(data, bw=0.5, ax=ax)