Commit 7a0f83b5 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[EXTREME ESTIMATOR][ESTIMATOR] refactor estimator

parent 486ab59d
No related merge requests found
Showing with 20 additions and 38 deletions
+20 -38
...@@ -45,7 +45,7 @@ def multiple_unidimensional_robustness(): ...@@ -45,7 +45,7 @@ def multiple_unidimensional_robustness():
# Put only the parameter that will vary # Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{ spatial_robustness.robustness_grid_plot(**{
SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION], # SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
MaxStableProcessPlot.NbStationItem.name: nb_stations, MaxStableProcessPlot.NbStationItem.name: nb_stations,
MaxStableProcessPlot.NbObservationItem.name: nb_observation, MaxStableProcessPlot.NbObservationItem.name: nb_observation,
MaxStableProcessPlot.MaxStableModelItem.name: msp_models, MaxStableProcessPlot.MaxStableModelItem.name: msp_models,
......
import time import time
from cached_property import cached_property
from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel
from extreme_estimator.extreme_models.result_from_fit import ResultFromFit from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
...@@ -9,42 +11,26 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset ...@@ -9,42 +11,26 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
class AbstractEstimator(object): class AbstractEstimator(object):
DURATION = 'Average duration'
MAE_ERROR = 'Mean Average Error'
# For each estimator, we shall define:
# - The loss
# - The optimization method for each part of the process
def __init__(self, dataset: AbstractDataset): def __init__(self, dataset: AbstractDataset):
self.dataset = dataset # type: AbstractDataset self.dataset = dataset # type: AbstractDataset
self.additional_information = dict()
self._result_from_fit = None # type: ResultFromFit self._result_from_fit = None # type: ResultFromFit
self._margin_function_fitted = None
self._max_stable_model_fitted = None
@classmethod @classmethod
def from_dataset(cls, dataset: AbstractDataset): def from_dataset(cls, dataset: AbstractDataset):
# raise NotImplementedError('from_dataset class constructor has not been implemented for this class') raise NotImplementedError
pass
def fit(self): def fit(self):
ts = time.time() raise NotImplementedError
self._fit()
te = time.time()
self.additional_information[self.DURATION] = int((te - ts) * 1000)
@property @property
def result_from_fit(self) -> ResultFromFit: def result_from_fit(self) -> ResultFromFit:
assert self._result_from_fit is not None, 'Fit has not be done' assert self._result_from_fit is not None, 'Fit has not be done'
return self._result_from_fit return self._result_from_fit
@property @cached_property
def margin_function_fitted(self) -> AbstractMarginFunction: def margin_function_fitted(self) -> AbstractMarginFunction:
if self._margin_function_fitted is None: return self.extract_function_fitted()
self._margin_function_fitted = self.extract_function_fitted()
assert self._margin_function_fitted is not None, 'No margin function has been fitted'
return self._margin_function_fitted
def extract_function_fitted(self) -> AbstractMarginFunction: def extract_function_fitted(self) -> AbstractMarginFunction:
raise NotImplementedError raise NotImplementedError
...@@ -55,11 +41,7 @@ class AbstractEstimator(object): ...@@ -55,11 +41,7 @@ class AbstractEstimator(object):
coef_dict=self.result_from_fit.margin_coef_dict, coef_dict=self.result_from_fit.margin_coef_dict,
starting_point=margin_model.starting_point) starting_point=margin_model.starting_point)
#
@property @property
def train_split(self): def train_split(self):
return self.dataset.train_split return self.dataset.train_split
# Methods to override in the child class
def _fit(self):
raise NotImplementedError
from cached_property import cached_property
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator
...@@ -23,7 +25,7 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator): ...@@ -23,7 +25,7 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
self.margin_estimator = LinearMarginEstimator(dataset=dataset, margin_model=margin_model) self.margin_estimator = LinearMarginEstimator(dataset=dataset, margin_model=margin_model)
self.max_stable_estimator = MaxStableEstimator(dataset=dataset, max_stable_model=max_stable_model) self.max_stable_estimator = MaxStableEstimator(dataset=dataset, max_stable_model=max_stable_model)
def _fit(self): def fit(self):
# Estimate the margin parameters # Estimate the margin parameters
self.margin_estimator.fit() self.margin_estimator.fit()
# Compute the maxima_frech # Compute the maxima_frech
...@@ -65,7 +67,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator): ...@@ -65,7 +67,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
return self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split, return self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split,
starting_point=self.linear_margin_model.starting_point) starting_point=self.linear_margin_model.starting_point)
def _fit(self): def fit(self):
# Estimate both the margin and the max-stable structure # Estimate both the margin and the max-stable structure
self._result_from_fit = self.max_stable_model.fitmaxstab( self._result_from_fit = self.max_stable_model.fitmaxstab(
data_gev=self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split), data_gev=self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split),
...@@ -79,7 +81,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator): ...@@ -79,7 +81,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
def extract_function_fitted(self): def extract_function_fitted(self):
return self.extract_function_fitted_from_the_model_shape(self.linear_margin_model) return self.extract_function_fitted_from_the_model_shape(self.linear_margin_model)
@property @cached_property
def margin_function_fitted(self) -> LinearMarginFunction: def margin_function_fitted(self) -> LinearMarginFunction:
return super().margin_function_fitted return super().margin_function_fitted
......
from abc import ABC from abc import ABC
from cached_property import cached_property
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
...@@ -11,8 +13,6 @@ class AbstractMarginEstimator(AbstractEstimator, ABC): ...@@ -11,8 +13,6 @@ class AbstractMarginEstimator(AbstractEstimator, ABC):
def __init__(self, dataset: AbstractDataset): def __init__(self, dataset: AbstractDataset):
super().__init__(dataset) super().__init__(dataset)
assert self.dataset.maxima_gev() is not None assert self.dataset.maxima_gev() is not None
self._margin_function_fitted = None
class LinearMarginEstimator(AbstractMarginEstimator): class LinearMarginEstimator(AbstractMarginEstimator):
"""# with different type of marginals: cosntant, linear....""" """# with different type of marginals: cosntant, linear...."""
...@@ -25,7 +25,7 @@ class LinearMarginEstimator(AbstractMarginEstimator): ...@@ -25,7 +25,7 @@ class LinearMarginEstimator(AbstractMarginEstimator):
assert isinstance(margin_model, LinearMarginModel) assert isinstance(margin_model, LinearMarginModel)
self.margin_model = margin_model self.margin_model = margin_model
def _fit(self): def fit(self):
maxima_gev_specialized = self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split) maxima_gev_specialized = self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split)
df_coordinates_spat = self.dataset.coordinates.df_spatial_coordinates(self.train_split) df_coordinates_spat = self.dataset.coordinates.df_spatial_coordinates(self.train_split)
df_coordinates_temp = self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split, df_coordinates_temp = self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split,
...@@ -34,11 +34,9 @@ class LinearMarginEstimator(AbstractMarginEstimator): ...@@ -34,11 +34,9 @@ class LinearMarginEstimator(AbstractMarginEstimator):
df_coordinates_spat=df_coordinates_spat, df_coordinates_spat=df_coordinates_spat,
df_coordinates_temp=df_coordinates_temp) df_coordinates_temp=df_coordinates_temp)
@property @cached_property
def margin_function_fitted(self) -> LinearMarginFunction: def margin_function_fitted(self) -> LinearMarginFunction:
margin_function_fitted = super().margin_function_fitted return super().margin_function_fitted
assert isinstance(margin_function_fitted, LinearMarginFunction)
return margin_function_fitted
def extract_function_fitted(self) -> LinearMarginFunction: def extract_function_fitted(self) -> LinearMarginFunction:
return self.extract_function_fitted_from_the_model_shape(self.margin_model) return self.extract_function_fitted_from_the_model_shape(self.margin_model)
...@@ -16,7 +16,7 @@ class AbstractMaxStableEstimator(AbstractEstimator): ...@@ -16,7 +16,7 @@ class AbstractMaxStableEstimator(AbstractEstimator):
class MaxStableEstimator(AbstractMaxStableEstimator): class MaxStableEstimator(AbstractMaxStableEstimator):
def _fit(self): def fit(self):
assert self.dataset.maxima_frech(split=self.train_split) is not None assert self.dataset.maxima_frech(split=self.train_split) is not None
self._result_from_fit = self.max_stable_model.fitmaxstab( self._result_from_fit = self.max_stable_model.fitmaxstab(
data_frech=self.dataset.maxima_frech_for_spatial_extremes_package(split=self.train_split), data_frech=self.dataset.maxima_frech_for_spatial_extremes_package(split=self.train_split),
...@@ -25,7 +25,7 @@ class MaxStableEstimator(AbstractMaxStableEstimator): ...@@ -25,7 +25,7 @@ class MaxStableEstimator(AbstractMaxStableEstimator):
def scalars(self, true_max_stable_params: dict): def scalars(self, true_max_stable_params: dict):
error = self._error(true_max_stable_params) error = self._error(true_max_stable_params)
return {**error, **self.additional_information} return {**error}
def _error(self, true_max_stable_params: dict): def _error(self, true_max_stable_params: dict):
absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name]) absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment