From 7a0f83b577e760fde82daff211b168eb5864a838 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 16 Oct 2019 14:25:53 +0200
Subject: [PATCH] [EXTREME ESTIMATOR][ESTIMATOR] refactor estimator

---
 .../unidimensional_robustness.py              |  2 +-
 .../estimator/abstract_estimator.py           | 32 ++++---------------
 .../full_estimator/abstract_full_estimator.py |  8 +++--
 .../abstract_margin_estimator.py              | 12 +++----
 .../abstract_max_stable_estimator.py          |  4 +--
 5 files changed, 20 insertions(+), 38 deletions(-)

diff --git a/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py b/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py
index 068faa64..937a65c9 100644
--- a/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py
+++ b/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py
@@ -45,7 +45,7 @@ def multiple_unidimensional_robustness():
 
     # Put only the parameter that will vary
     spatial_robustness.robustness_grid_plot(**{
-        SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
+        # SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
         MaxStableProcessPlot.NbStationItem.name: nb_stations,
         MaxStableProcessPlot.NbObservationItem.name: nb_observation,
         MaxStableProcessPlot.MaxStableModelItem.name: msp_models,
diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index ae736464..2becbbfb 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -1,5 +1,7 @@
 import time
 
+from cached_property import cached_property
+
 from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel
 from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
@@ -9,42 +11,26 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
 class AbstractEstimator(object):
-    DURATION = 'Average duration'
-    MAE_ERROR = 'Mean Average Error'
-
-    # For each estimator, we shall define:
-    # - The loss
-    # - The optimization method for each part of the process
 
     def __init__(self, dataset: AbstractDataset):
         self.dataset = dataset  # type: AbstractDataset
-        self.additional_information = dict()
         self._result_from_fit = None  # type: ResultFromFit
-        self._margin_function_fitted = None
-        self._max_stable_model_fitted = None
 
     @classmethod
     def from_dataset(cls, dataset: AbstractDataset):
-        # raise NotImplementedError('from_dataset class constructor has not been implemented for this class')
-        pass
+        raise NotImplementedError
 
     def fit(self):
-        ts = time.time()
-        self._fit()
-        te = time.time()
-        self.additional_information[self.DURATION] = int((te - ts) * 1000)
+        raise NotImplementedError
 
     @property
     def result_from_fit(self) -> ResultFromFit:
         assert self._result_from_fit is not None, 'Fit has not be done'
         return self._result_from_fit
 
-    @property
+    @cached_property
     def margin_function_fitted(self) -> AbstractMarginFunction:
-        if self._margin_function_fitted is None:
-            self._margin_function_fitted = self.extract_function_fitted()
-        assert self._margin_function_fitted is not None, 'No margin function has been fitted'
-        return self._margin_function_fitted
+        return self.extract_function_fitted()
 
     def extract_function_fitted(self) -> AbstractMarginFunction:
         raise NotImplementedError
@@ -55,11 +41,7 @@ class AbstractEstimator(object):
                                                    coef_dict=self.result_from_fit.margin_coef_dict,
                                                    starting_point=margin_model.starting_point)
 
+    #
     @property
     def train_split(self):
         return self.dataset.train_split
-
-    # Methods to override in the child class
-
-    def _fit(self):
-        raise NotImplementedError
diff --git a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
index cdce4f7b..85fcb135 100644
--- a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
@@ -1,3 +1,5 @@
+from cached_property import cached_property
+
 from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
 from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator
@@ -23,7 +25,7 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
         self.margin_estimator = LinearMarginEstimator(dataset=dataset, margin_model=margin_model)
         self.max_stable_estimator = MaxStableEstimator(dataset=dataset, max_stable_model=max_stable_model)
 
-    def _fit(self):
+    def fit(self):
         # Estimate the margin parameters
         self.margin_estimator.fit()
         # Compute the maxima_frech
@@ -65,7 +67,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
         return self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split,
                                                                         starting_point=self.linear_margin_model.starting_point)
 
-    def _fit(self):
+    def fit(self):
         # Estimate both the margin and the max-stable structure
         self._result_from_fit = self.max_stable_model.fitmaxstab(
             data_gev=self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split),
@@ -79,7 +81,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
     def extract_function_fitted(self):
         return self.extract_function_fitted_from_the_model_shape(self.linear_margin_model)
 
-    @property
+    @cached_property
     def margin_function_fitted(self) -> LinearMarginFunction:
         return super().margin_function_fitted
 
diff --git a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
index 03aa7d4b..afdf32c8 100644
--- a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
@@ -1,5 +1,7 @@
 from abc import ABC
 
+from cached_property import cached_property
+
 from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
 from extreme_estimator.extreme_models.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel
 from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
@@ -11,8 +13,6 @@ class AbstractMarginEstimator(AbstractEstimator, ABC):
     def __init__(self, dataset: AbstractDataset):
         super().__init__(dataset)
         assert self.dataset.maxima_gev() is not None
-        self._margin_function_fitted = None
-
 
 class LinearMarginEstimator(AbstractMarginEstimator):
     """# with different type of marginals: cosntant, linear...."""
@@ -25,7 +25,7 @@ class LinearMarginEstimator(AbstractMarginEstimator):
         assert isinstance(margin_model, LinearMarginModel)
         self.margin_model = margin_model
 
-    def _fit(self):
+    def fit(self):
         maxima_gev_specialized = self.dataset.maxima_gev_for_spatial_extremes_package(self.train_split)
         df_coordinates_spat = self.dataset.coordinates.df_spatial_coordinates(self.train_split)
         df_coordinates_temp = self.dataset.coordinates.df_temporal_coordinates_for_fit(split=self.train_split,
@@ -34,11 +34,9 @@ class LinearMarginEstimator(AbstractMarginEstimator):
                                                                             df_coordinates_spat=df_coordinates_spat,
                                                                             df_coordinates_temp=df_coordinates_temp)
 
-    @property
+    @cached_property
     def margin_function_fitted(self) -> LinearMarginFunction:
-        margin_function_fitted = super().margin_function_fitted
-        assert isinstance(margin_function_fitted, LinearMarginFunction)
-        return margin_function_fitted
+        return super().margin_function_fitted
 
     def extract_function_fitted(self) -> LinearMarginFunction:
         return self.extract_function_fitted_from_the_model_shape(self.margin_model)
diff --git a/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py b/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py
index 868c7c07..8a9bed24 100644
--- a/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py
+++ b/extreme_estimator/estimator/max_stable_estimator/abstract_max_stable_estimator.py
@@ -16,7 +16,7 @@ class AbstractMaxStableEstimator(AbstractEstimator):
 
 class MaxStableEstimator(AbstractMaxStableEstimator):
 
-    def _fit(self):
+    def fit(self):
         assert self.dataset.maxima_frech(split=self.train_split) is not None
         self._result_from_fit = self.max_stable_model.fitmaxstab(
             data_frech=self.dataset.maxima_frech_for_spatial_extremes_package(split=self.train_split),
@@ -25,7 +25,7 @@ class MaxStableEstimator(AbstractMaxStableEstimator):
 
     def scalars(self, true_max_stable_params: dict):
         error = self._error(true_max_stable_params)
-        return {**error, **self.additional_information}
+        return {**error}
 
     def _error(self, true_max_stable_params: dict):
         absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name])
-- 
GitLab