diff --git a/experiment/fit_diagnosis/main_split.py b/experiment/fit_diagnosis/main_split.py deleted file mode 100644 index c6c326198163b33665738e62b0b5704f5ff45bef..0000000000000000000000000000000000000000 --- a/experiment/fit_diagnosis/main_split.py +++ /dev/null @@ -1,51 +0,0 @@ -import random - -from experiment.fit_diagnosis.split_curve import SplitCurve, LocFunction -from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin -from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator -from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \ - LinearAllParametersAllDimsMarginModel -from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith -from extreme_estimator.gev_params import GevParams -from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates - -from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset -from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer -from spatio_temporal_dataset.slicer.spatio_temporal_slicer import SpatioTemporalSlicer - -random.seed(42) - - -def load_dataset(): - nb_points = 50 - nb_obs = 60 - coordinates = LinSpaceCoordinates.from_nb_points(nb_points=nb_points, train_split_ratio=0.8) - - # MarginModel Linear with respect to the shape (from 0.01 to 0.02) - params_sample = { - (GevParams.GEV_LOC, 0): 10, - (GevParams.GEV_SHAPE, 0): 1.0, - (GevParams.GEV_SCALE, 0): 1.0, - } - margin_model = ConstantMarginModel(coordinates=coordinates, params_sample=params_sample) - max_stable_model = Smith() - - return FullSimulatedDataset.from_double_sampling(nb_obs=nb_obs, margin_model=margin_model, - coordinates=coordinates, - max_stable_model=max_stable_model) - - -def full_estimator(dataset): - max_stable_model = Smith() - margin_model_for_estimator = LinearAllParametersAllDimsMarginModel(dataset.coordinates) - # full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model) - fast_estimator = SmoothMarginEstimator(dataset, margin_model_for_estimator) - return fast_estimator - - -if __name__ == '__main__': - dataset = load_dataset() - dataset.slicer.summary() - full_estimator = full_estimator(dataset) - curve = SplitCurve(dataset, full_estimator, margin_functions=[LocFunction()]) - curve.visualize() diff --git a/experiment/fit_diagnosis/split_curve.py b/experiment/fit_diagnosis/split_curve.py index bc408c5ea02254aae864730ca6faf807833e10c2..a7511791734f5ea4503025f602a2608b0342e377 100644 --- a/experiment/fit_diagnosis/split_curve.py +++ b/experiment/fit_diagnosis/split_curve.py @@ -1,4 +1,5 @@ import numpy as np +import matplotlib.cm as cm import matplotlib.pyplot as plt import seaborn as sns @@ -7,42 +8,53 @@ from typing import Union, List from extreme_estimator.estimator.full_estimator import AbstractFullEstimator from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator +from extreme_estimator.extreme_models.margin_model.margin_function.combined_margin_function import \ + CombinedMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.utils import error_dict_between_margin_functions from extreme_estimator.gev_params import GevParams from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset from spatio_temporal_dataset.slicer.split import Split, ALL_SPLITS_EXCEPT_ALL -class MarginFunction(object): +class SplitCurve(object): - def margin_function(self, gev_param: GevParams) -> float: - pass + def __init__(self, nb_fit: int = 1): + self.nb_fit = nb_fit + self.margin_function_fitted_all = None + def fit(self, show=True): + self.margin_function_fitted_all = [] -class LocFunction(MarginFunction): + for i in range(self.nb_fit): + # A new dataset with the same margin, but just the observations are resampled + self.dataset = self.load_dataset() + # Both split must be defined + assert not self.dataset.slicer.some_required_ind_are_not_defined + self.margin_function_sample = self.dataset.margin_model.margin_function_sample - def margin_function(self, gev_param: GevParams) -> float: - return gev_param.location + print('Fitting {}/{}...'.format(i + 1, self.nb_fit)) + self.estimator = self.load_estimator(self.dataset) + # Fit the estimator and get the margin_function + self.estimator.fit() + self.margin_function_fitted_all.append(self.estimator.margin_function_fitted) + # Individual error dict + self.error_dict_all = [error_dict_between_margin_functions(self.margin_function_sample, m) + for m in self.margin_function_fitted_all] -class SplitCurve(object): + # Mean margin + self.mean_margin_function_fitted = CombinedMarginFunction.from_margin_functions(self.margin_function_fitted_all) + self.mean_error_dict = error_dict_between_margin_functions(self.margin_function_sample, + self.mean_margin_function_fitted) - def __init__(self, dataset: FullSimulatedDataset, estimator: Union[AbstractFullEstimator, AbstractMarginEstimator], - margin_functions: List[MarginFunction]): - # Dataset is already loaded and will not be modified - self.dataset = dataset - # Both split must be defined - assert not self.dataset.slicer.some_required_ind_are_not_defined - self.margin_function_sample = self.dataset.margin_model.margin_function_sample + if show: + self.visualize() - self.estimator = estimator - # Fit the estimator and get the margin_function - self.estimator.fit() - # todo: potentially I will do the fit several times, and retrieve the mean error - # there is a big variablility so it would be really interesting to average, to make real - self.margin_function_fitted = estimator.margin_function_fitted + def load_dataset(self): + pass - self.error_dict = error_dict_between_margin_functions(self.margin_function_sample, self.margin_function_fitted) + def load_estimator(self, dataset): + pass @property def main_title(self): @@ -58,17 +70,39 @@ class SplitCurve(object): plt.show() def margin_graph(self, ax, gev_value_name): - # Display the fitted curve - self.margin_function_fitted.visualize_single_param(gev_value_name, ax, show=False) + # Create bins of data, each with an associated color corresponding to its error + + data = self.mean_error_dict[gev_value_name].values + nb_bins = 10 + limits = np.linspace(data.min(), data.max(), num=nb_bins + 1) + limits[-1] += 0.01 + colors = cm.binary(limits) + # Display train/test points for split, marker in [(self.dataset.train_split, 'o'), (self.dataset.test_split, 'x')]: - self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker) - self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False) + for left_limit, right_limit, color in zip(limits[:-1], limits[1:], colors): + # Find for the split the index + data_ind = self.mean_error_dict[gev_value_name].loc[ + self.dataset.coordinates.coordinates_index(split)].values + data_filter = np.logical_and(left_limit <= data_ind, data_ind < right_limit) + + self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker, + filter=data_filter, color=color) + self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False) + + # Display the individual fitted curve + self.mean_margin_function_fitted.color = 'lightskyblue' + for m in self.margin_function_fitted_all: + m.visualize_single_param(gev_value_name, ax, show=False) + # Display the mean fitted curve + self.mean_margin_function_fitted.color = 'blue' + self.mean_margin_function_fitted.visualize_single_param(gev_value_name, ax, show=False) def score_graph(self, ax, gev_value_name): # todo: for the moment only the train/test are interresting (the spatio temporal isn"t working yet) + sns.set_style('whitegrid') - s = self.error_dict[gev_value_name] + s = self.mean_error_dict[gev_value_name] for split in self.dataset.splits: ind = self.dataset.coordinates_index(split) data = s.loc[ind].values diff --git a/experiment/fit_diagnosis/split_curve_example.py b/experiment/fit_diagnosis/split_curve_example.py new file mode 100644 index 0000000000000000000000000000000000000000..790ca10e730cc88b85d9fc4f18c6881f2159234a --- /dev/null +++ b/experiment/fit_diagnosis/split_curve_example.py @@ -0,0 +1,55 @@ +from typing import Union + +from experiment.fit_diagnosis.split_curve import SplitCurve +from extreme_estimator.estimator.full_estimator import AbstractFullEstimator +from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator, ConstantMarginEstimator +from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset + +import random + +from experiment.fit_diagnosis.split_curve import SplitCurve +from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin +from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator +from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \ + LinearAllParametersAllDimsMarginModel +from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith +from extreme_estimator.gev_params import GevParams +from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates +from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset + + +class SplitCurveExample(SplitCurve): + + def __init__(self, nb_fit: int = 1): + super().__init__(nb_fit) + self.nb_points = 50 + self.nb_obs = 60 + self.coordinates = LinSpaceCoordinates.from_nb_points(nb_points=self.nb_points, train_split_ratio=0.8) + # MarginModel Linear with respect to the shape (from 0.01 to 0.02) + params_sample = { + (GevParams.GEV_LOC, 0): 10, + (GevParams.GEV_SHAPE, 0): 1.0, + (GevParams.GEV_SCALE, 0): 1.0, + } + self.margin_model = ConstantMarginModel(coordinates=self.coordinates, params_sample=params_sample) + self.max_stable_model = Smith() + + def load_dataset(self): + return FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs, margin_model=self.margin_model, + coordinates=self.coordinates, + max_stable_model=self.max_stable_model) + + def load_estimator(self, dataset): + max_stable_model = Smith() + margin_model_for_estimator = LinearAllParametersAllDimsMarginModel(dataset.coordinates) + estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model) + # estimator = SmoothMarginEstimator(dataset, margin_model_for_estimator) + return estimator + + + + + +if __name__ == '__main__': + curve = SplitCurveExample(nb_fit=2) + curve.fit() diff --git a/experiment/return_level_plot/__init__.py b/experiment/return_level_plot/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/experiment/return_level_plot/spatial_2D_plot.py b/experiment/return_level_plot/spatial_2D_plot.py deleted file mode 100644 index 8347b0e90cf475097612577e6dc9296a8d698542..0000000000000000000000000000000000000000 --- a/experiment/return_level_plot/spatial_2D_plot.py +++ /dev/null @@ -1,30 +0,0 @@ -from itertools import product -from typing import List, Dict - -import matplotlib.pyplot as plt - -from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ - AbstractMarginFunction -from extreme_estimator.gev_params import GevParams - -plt.style.use('seaborn-white') - - -class Spatial2DPlot(object): - - def __init__(self, name_to_margin_function: Dict[str, AbstractMarginFunction]): - self.name_to_margin_function = name_to_margin_function # type: Dict[str, AbstractMarginFunction] - self.grid_columns = GevParams.GEV_PARAM_NAMES - - def plot(self): - nb_grid_rows, nb_grid_columns = len(self.name_to_margin_function), len(self.grid_columns) - fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row') - fig.subplots_adjust(hspace=0.4, wspace=0.4) - margin_function: AbstractMarginFunction - for i, (name, margin_function) in enumerate(self.name_to_margin_function.items()): - for j, param_name in enumerate(self.grid_columns): - ax = axes[i, j] if nb_grid_rows > 1 else axes[j] - margin_function.visualize_2D(gev_param_name=param_name, ax=ax) - ax.set_title("{} for {}".format(param_name, name)) - fig.suptitle('Spatial2DPlot') - plt.show() diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py index e4268ddede6f64a3fd22dbcb3bad14dc50ab4e0f..829cd65c7305c51cfa8421ba8392a98eb3b89250 100644 --- a/extreme_estimator/estimator/margin_estimator.py +++ b/extreme_estimator/estimator/margin_estimator.py @@ -22,6 +22,17 @@ class PointWiseMarginEstimator(AbstractMarginEstimator): pass +class ConstantMarginEstimator(AbstractMarginEstimator): + + def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel): + super().__init__(dataset) + assert isinstance(margin_model, LinearMarginModel) + self.margin_model = margin_model + + def _fit(self): + self._margin_function_fitted = self.margin_model.margin_function_start_fit + + class SmoothMarginEstimator(AbstractMarginEstimator): """# with different type of marginals: cosntant, linear....""" diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py index 0d65ba0e0436e619da477d35665a40d676fdf937..1072d4a5d4749f32efffa3b6639fce674221ad13 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py @@ -21,6 +21,8 @@ class AbstractMarginFunction(object): self.datapoint_display = False self.spatio_temporal_split = Split.all self.datapoint_marker = 'o' + self.color = 'skyblue' + self.filter = None def get_gev_params(self, coordinate: np.ndarray) -> GevParams: """Main method that maps each coordinate to its GEV parameters""" @@ -40,21 +42,23 @@ class AbstractMarginFunction(object): # Visualization function - def set_datapoint_display_parameters(self, spatio_temporal_split, datapoint_marker): + def set_datapoint_display_parameters(self, spatio_temporal_split, datapoint_marker, filter=None, color=None): self.datapoint_display = True self.spatio_temporal_split = spatio_temporal_split self.datapoint_marker = datapoint_marker + self.filter = filter + self.color = color def visualize(self, axes=None, show=True, dot_display=False): self.datapoint_display = dot_display if axes is None: - fig, axes = plt.subplots(3, 1, sharex='col', sharey='row') - fig.subplots_adjust(hspace=0.4, wspace=0.4, ) + fig, axes = plt.subplots(1, len(GevParams.GEV_VALUE_NAMES)) + fig.subplots_adjust(hspace=1.0, wspace=1.0) self.visualization_axes = axes - for i, gev_param_name in enumerate(GevParams.GEV_PARAM_NAMES): + for i, gev_value_name in enumerate(GevParams.GEV_VALUE_NAMES): ax = axes[i] - self.visualize_single_param(gev_param_name, ax, show=False) - title_str = gev_param_name + self.visualize_single_param(gev_value_name, ax, show=False) + title_str = gev_value_name ax.set_title(title_str) if show: plt.show() @@ -68,45 +72,32 @@ class AbstractMarginFunction(object): else: raise NotImplementedError('3D Margin visualization not yet implemented') + # Visualization 1D + def visualize_1D(self, gev_value_name=GevParams.GEV_LOC, ax=None, show=True): x = self.coordinates.x_coordinates grid, linspace = self.get_grid_values_1D(x) if ax is None: ax = plt.gca() if self.datapoint_display: - ax.plot(linspace, grid[gev_value_name], self.datapoint_marker) + ax.plot(linspace, grid[gev_value_name], self.datapoint_marker, color=self.color) else: - ax.plot(linspace, grid[gev_value_name]) + ax.plot(linspace, grid[gev_value_name], color=self.color) # X axis - ax.set_xlabel('coordinate') + ax.set_xlabel('coordinate X') plt.setp(ax.get_xticklabels(), visible=True) ax.xaxis.set_tick_params(labelbottom=True) - # Y axis - ax.set_ylabel(gev_value_name) - plt.setp(ax.get_yticklabels(), visible=True) - ax.yaxis.set_tick_params(labelbottom=True) if show: plt.show() - def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, ax=None, show=True): - x = self.coordinates.x_coordinates - y = self.coordinates.y_coordinates - grid = self.get_grid_2D(x, y) - gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name) - if ax is None: - ax = plt.gca() - imshow_method = ax.imshow - imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()), - interpolation='nearest', cmap=cm.gist_rainbow) - # todo: add dot display in 2D - if show: - plt.show() - def get_grid_values_1D(self, x): # TODO: to avoid getting the value several times, I could cache the results if self.datapoint_display: + # todo: keep only the index of interest here linspace = self.coordinates.coordinates_values(self.spatio_temporal_split)[:, 0] + if self.filter is not None: + linspace = linspace[self.filter] resolution = len(linspace) else: resolution = 100 @@ -119,10 +110,35 @@ class AbstractMarginFunction(object): grid = {gev_param: [g[gev_param] for g in grid] for gev_param in GevParams.GEV_VALUE_NAMES} return grid, linspace + # Visualization 2D + + def visualize_2D(self, gev_value_name=GevParams.GEV_LOC, ax=None, show=True): + x = self.coordinates.x_coordinates + y = self.coordinates.y_coordinates + grid = self.get_grid_2D(x, y) + if ax is None: + ax = plt.gca() + imshow_method = ax.imshow + imshow_method(grid[gev_value_name], extent=(x.min(), x.max(), y.min(), y.max()), + interpolation='nearest', cmap=cm.gist_rainbow) + # X axis + ax.set_xlabel('coordinate X') + plt.setp(ax.get_xticklabels(), visible=True) + ax.xaxis.set_tick_params(labelbottom=True) + # Y axis + ax.set_ylabel('coordinate Y') + plt.setp(ax.get_yticklabels(), visible=True) + ax.yaxis.set_tick_params(labelbottom=True) + # todo: add dot display in 2D + if show: + plt.show() + def get_grid_2D(self, x, y): resolution = 100 - grid = np.zeros([resolution, resolution, 3]) + grid = [] for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)): for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)): - grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array() + grid.append(self.get_gev_params(np.array([xi, yj])).value_dict) + grid = {value_name: np.array([g[value_name] for g in grid]).reshape([resolution, resolution]) + for value_name in GevParams.GEV_VALUE_NAMES} return grid diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/combined_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/combined_margin_function.py new file mode 100644 index 0000000000000000000000000000000000000000..7d149cb78fea9822f2c0826a8a75f06409dc180d --- /dev/null +++ b/extreme_estimator/extreme_models/margin_model/margin_function/combined_margin_function.py @@ -0,0 +1,30 @@ +from typing import List + +import numpy as np + +from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ + AbstractMarginFunction +from extreme_estimator.gev_params import GevParams +from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates + + +class CombinedMarginFunction(AbstractMarginFunction): + + def __init__(self, coordinates: AbstractCoordinates, margin_functions: List[AbstractMarginFunction]): + super().__init__(coordinates) + self.margin_functions = margin_functions # type: List[AbstractMarginFunction] + + def get_gev_params(self, coordinate: np.ndarray) -> GevParams: + gev_params_list = [margin_function.get_gev_params(coordinate) for margin_function in self.margin_functions] + mean_gev_params = np.mean(np.array([gev_param.to_array() for gev_param in gev_params_list]), axis=0) + gev_param = GevParams(*mean_gev_params) + return gev_param + + @classmethod + def from_margin_functions(cls, margin_functions: List[AbstractMarginFunction]): + assert len(margin_functions) > 0 + assert all([isinstance(margin_function, AbstractMarginFunction) for margin_function in margin_functions]) + common_coordinates = set([margin_function.coordinates for margin_function in margin_functions]) + assert len(common_coordinates) == 1 + coordinates = common_coordinates.pop() + return cls(coordinates, margin_functions) diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/utils.py b/extreme_estimator/extreme_models/margin_model/margin_function/utils.py index 2ea29cc1f5368de65e1deb493ed6ba7162a378d0..4f044b987f4cb0778dfaa87d86e05f95949332c4 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/utils.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/utils.py @@ -20,10 +20,8 @@ def error_dict_between_margin_functions(reference: AbstractMarginFunction, fitte assert reference.coordinates == fitted.coordinates reference_values = reference.gev_value_name_to_serie fitted_values = fitted.gev_value_name_to_serie - gev_param_name_to_error_serie = {} - for value_name in GevParams.GEV_VALUE_NAMES: - print(value_name) - error = relative_abs_error(reference_values[value_name], fitted_values[value_name]) - gev_param_name_to_error_serie[value_name] = error + for gev_value_name in GevParams.GEV_VALUE_NAMES: + error = relative_abs_error(reference_values[gev_value_name], fitted_values[gev_value_name]) + gev_param_name_to_error_serie[gev_value_name] = error return gev_param_name_to_error_serie diff --git a/test/test_experiment/test_split_curve.py b/test/test_experiment/test_split_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9b1d69093ec2ff3269f51008e50bb868ccd8ac --- /dev/null +++ b/test/test_experiment/test_split_curve.py @@ -0,0 +1,56 @@ +import unittest +from typing import Union + +from experiment.fit_diagnosis.split_curve import SplitCurve +from extreme_estimator.estimator.full_estimator import AbstractFullEstimator +from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator, ConstantMarginEstimator +from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset + +import random + +from experiment.fit_diagnosis.split_curve import SplitCurve +from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin +from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator +from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \ + LinearAllParametersAllDimsMarginModel +from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith +from extreme_estimator.gev_params import GevParams +from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates +from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset + + +class TestSplitCurve(unittest.TestCase): + DISPLAY = False + + class SplitCurveFastForTest(SplitCurve): + + def __init__(self, nb_fit: int = 1): + super().__init__(nb_fit) + self.nb_points = 50 + self.nb_obs = 60 + self.coordinates = LinSpaceCoordinates.from_nb_points(nb_points=self.nb_points, train_split_ratio=0.8) + # MarginModel Linear with respect to the shape (from 0.01 to 0.02) + params_sample = { + (GevParams.GEV_LOC, 0): 10, + (GevParams.GEV_SHAPE, 0): 1.0, + (GevParams.GEV_SCALE, 0): 1.0, + } + self.margin_model = ConstantMarginModel(coordinates=self.coordinates, params_sample=params_sample) + self.max_stable_model = Smith() + + def load_dataset(self): + return FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs, margin_model=self.margin_model, + coordinates=self.coordinates, + max_stable_model=self.max_stable_model) + + def load_estimator(self, dataset): + # todo: create a test from that example + return ConstantMarginEstimator(dataset, LinearAllParametersAllDimsMarginModel(dataset.coordinates)) + + def test_split_curve(self): + s = self.SplitCurveFastForTest(nb_fit=2) + s.fit(show=self.DISPLAY) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py index cd7b0a9e1d3c91da72e25e4f54d20abd28d6a60d..f9a51a737613b4fe6bd4855cf9ba475b9c3b9b2c 100644 --- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py +++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py @@ -1,7 +1,6 @@ import unittest from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator -from experiment.return_level_plot.spatial_2D_plot import Spatial2DPlot from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset from test.test_utils import load_smooth_margin_models, load_test_1D_and_2D_coordinates @@ -17,21 +16,15 @@ class TestSmoothMarginEstimator(unittest.TestCase): def test_dependency_estimators(self): for coordinates in self.coordinates: smooth_margin_models = load_smooth_margin_models(coordinates=coordinates) - for margin_model in smooth_margin_models: + for margin_model in smooth_margin_models[1:]: dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model, coordinates=coordinates) # Fit estimator estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model) estimator.fit() - # Map name to their margin functions - name_to_margin_function = { - 'Ground truth margin function': dataset.margin_model.margin_function_sample, - 'Estimated margin function': estimator.margin_function_fitted, - } - # Spatial Plot - if self.DISPLAY: - Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot() + # Plot + margin_model.margin_function_sample.visualize(show=self.DISPLAY) self.assertTrue(True)