From 7a0f83b577e760fde82daff211b168eb5864a838 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Wed, 16 Oct 2019 14:25:53 +0200 Subject: [PATCH] [EXTREME ESTIMATOR][ESTIMATOR] refactor estimator --- .../unidimensional_robustness.py | 2 +- .../estimator/abstract_estimator.py | 32 ++++--------------- .../full_estimator/abstract_full_estimator.py | 8 +++-- .../abstract_margin_estimator.py | 12 +++---- .../abstract_max_stable_estimator.py | 4 +-- 5 files changed, 20 insertions(+), 38 deletions(-) diff --git a/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py b/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py index 068faa64..937a65c9 100644 --- a/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py +++ b/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py @@ -45,7 +45,7 @@ def multiple_unidimensional_robustness(): # Put only the parameter that will vary spatial_robustness.robustness_grid_plot(**{ - SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION], + # SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION], MaxStableProcessPlot.NbStationItem.name: nb_stations, MaxStableProcessPlot.NbObservationItem.name: nb_observation, MaxStableProcessPlot.MaxStableModelItem.name: msp_models, diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py index ae736464..2becbbfb 100644 --- a/extreme_estimator/estimator/abstract_estimator.py +++ b/extreme_estimator/estimator/abstract_estimator.py @@ -1,5 +1,7 @@ import time +from cached_property import cached_property + from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel from extreme_estimator.extreme_models.result_from_fit import ResultFromFit from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ @@ -9,42 +11,26 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset class AbstractEstimator(object): - DURATION = 'Average duration' - MAE_ERROR = 'Mean Average Error' - - # For each estimator, we shall define: - # - The loss - # - The optimization method for each part of the process def __init__(self, dataset: AbstractDataset): self.dataset = dataset # type: AbstractDataset - self.additional_information = dict() self._result_from_fit = None # type: ResultFromFit - self._margin_function_fitted = None - self._max_stable_model_fitted = None @classmethod def from_dataset(cls, dataset: AbstractDataset): - # raise NotImplementedError('from_dataset class constructor has not been implemented for this class') - pass + raise NotImplementedError def fit(self): - ts = time.time() - self._fit() - te = time.time() - self.additional_information[self.DURATION] = int((te - ts) * 1000) + raise NotImplementedError @property def result_from_fit(self) -> ResultFromFit: assert self._result_from_fit is not None, 'Fit has not be done' return self._result_from_fit - @property + @cached_property def margin_function_fitted(self) -> AbstractMarginFunction: - if self._margin_function_fitted is None: - self._margin_function_fitted = self.extract_function_fitted() - assert self._margin_function_fitted is not None, 'No margin function has been fitted' - return self._margin_function_fitted + return self.extract_function_fitted() def extract_function_fitted(self) -> AbstractMarginFunction: raise NotImplementedError @@ -55,11 +41,7 @@ class AbstractEstimator(object): coef_dict=self.result_from_fit.margin_coef_dict, starting_point=margin_model.starting_point) + # @property def train_split(self): return self.dataset.train_split - - # Methods to override in the child class - - def _fit(self): - raise NotImplementedError diff --git a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py index cdce4f7b..85fcb135 100644 --- a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py +++ b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py @@ -1,3 +1,5 @@ +from cached_property import cached_property + from extreme_estimator.estimator.abstract_estimator import AbstractEstimator from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator @@ -23,7 +25,7 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator): self.margin_estimator = LinearMarginEstimator(dataset=dataset, margin_model=margin_model) self.max_stable_estimator = MaxStableEstimator(dataset=dataset, max_stable_model=max_stable_model) - def _fit(self): + def fit(self): # Estimate the margin parameters self.margin_estimator.fit() # Compute the maxima_frech @@ -65,7 +67,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator): return self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split, starting_point=self.linear_margin_model.starting_point) - def _fit(self): + def fit(self): # Estimate both the margin and the max-stable structure self._result_from_fit = self.max_stable_model.fitmaxstab( data_gev=self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split), @@ -79,7 +81,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator): def extract_function_fitted(self): return self.extract_function_fitted_from_the_model_shape(self.linear_margin_model) - @property + @cached_property def margin_function_fitted(self) -> LinearMarginFunction: return super().margin_function_fitted diff --git a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py index 03aa7d4b..afdf32c8 100644 --- a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py +++ b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py @@ -1,5 +1,7 @@ from abc import ABC +from cached_property import cached_property + from extreme_estimator.estimator.abstract_estimator import AbstractEstimator from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction @@ -11,8 +13,6 @@ class AbstractMarginEstimator(AbstractEstimator, ABC): def __init__(self, dataset: AbstractDataset): super().__init__(dataset) assert self.dataset.maxima_gev() is not None - self._margin_function_fitted = None - class LinearMarginEstimator(AbstractMarginEstimator): """# with different type of marginals: cosntant, linear....""" @@ -25,7 +25,7 @@ class LinearMarginEstimator(AbstractMarginEstimator): assert isinstance(margin_model, LinearMarginModel) self.margin_model = margin_model - def _fit(self): + def fit(self): maxima_gev_specialized = self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split) df_coordinates_spat = self.dataset.coordinates.df_spatial_coordinates(self.train_split) df_coordinates_temp = self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split, @@ -34,11 +34,9 @@ class LinearMarginEstimator(AbstractMarginEstimator): df_coordinates_spat=df_coordinates_spat, df_coordinates_temp=df_coordinates_temp) - @property + @cached_property def margin_function_fitted(self) -> LinearMarginFunction: - margin_function_fitted = super().margin_function_fitted - assert isinstance(margin_function_fitted, LinearMarginFunction) - return margin_function_fitted + return super().margin_function_fitted def extract_function_fitted(self) -> LinearMarginFunction: return self.extract_function_fitted_from_the_model_shape(self.margin_model) diff --git a/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py b/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py index 868c7c07..8a9bed24 100644 --- a/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py +++ b/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py @@ -16,7 +16,7 @@ class AbstractMaxStableEstimator(AbstractEstimator): class MaxStableEstimator(AbstractMaxStableEstimator): - def _fit(self): + def fit(self): assert self.dataset.maxima_frech(split=self.train_split) is not None self._result_from_fit = self.max_stable_model.fitmaxstab( data_frech=self.dataset.maxima_frech_for_spatial_extremes_package(split=self.train_split), @@ -25,7 +25,7 @@ class MaxStableEstimator(AbstractMaxStableEstimator): def scalars(self, true_max_stable_params: dict): error = self._error(true_max_stable_params) - return {**error, **self.additional_information} + return {**error} def _error(self, true_max_stable_params: dict): absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name]) -- GitLab