From 3914f93e8608c603e1857320a48e3cab2d51f5a9 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Sun, 19 Apr 2020 19:13:27 +0200
Subject: [PATCH] [refactor] modify margin function

---
 .../parametric_margin_function.py             |  4 +
 extreme_fit/model/abstract_model.py           | 14 ++--
 .../margin_model/abstract_margin_model.py     | 20 +++--
 .../abstract_temporal_linear_margin_model.py  |  3 +-
 .../linear_margin_model.py                    | 78 +++++++++----------
 .../temporal_linear_margin_exp_models.py      |  4 +-
 .../temporal_linear_margin_models.py          | 26 +++----
 .../margin_model/parametric_margin_model.py   | 11 ++-
 .../model/margin_model/spline_margin_model.py | 37 ++++-----
 9 files changed, 101 insertions(+), 96 deletions(-)

diff --git a/extreme_fit/function/margin_function/parametric_margin_function.py b/extreme_fit/function/margin_function/parametric_margin_function.py
index e47e4054..ccf2a83d 100644
--- a/extreme_fit/function/margin_function/parametric_margin_function.py
+++ b/extreme_fit/function/margin_function/parametric_margin_function.py
@@ -91,3 +91,7 @@ class ParametricMarginFunction(IndependentMarginFunction):
     @property
     def form_dict(self) -> Dict[str, str]:
         raise NotImplementedError
+
+    @property
+    def coef_dict(self) -> Dict[str, str]:
+        raise NotImplementedError
diff --git a/extreme_fit/model/abstract_model.py b/extreme_fit/model/abstract_model.py
index cd6e7355..640fe738 100644
--- a/extreme_fit/model/abstract_model.py
+++ b/extreme_fit/model/abstract_model.py
@@ -1,7 +1,7 @@
 class AbstractModel(object):
 
     def __init__(self, params_user=None):
-        self.user_params_sample = params_user
+        self.params_user = params_user
 
     @property
     def default_params(self):
@@ -9,14 +9,14 @@ class AbstractModel(object):
 
     @property
     def params_sample(self) -> dict:
-        return self.merge_params(default_params=self.default_params, input_params=self.user_params_sample)
+        return self.merge_params(default_params=self.default_params, params_user=self.params_user)
 
     @staticmethod
-    def merge_params(default_params, input_params):
+    def merge_params(default_params, params_user):
         assert default_params is not None, 'some default_params need to be specified'
         merged_params = default_params.copy()
-        if input_params is not None:
-            assert isinstance(default_params, dict) and isinstance(input_params, dict)
-            assert set(input_params.keys()).issubset(set(default_params.keys()))
-            merged_params.update(input_params)
+        if params_user is not None:
+            assert isinstance(default_params, dict) and isinstance(params_user, dict)
+            assert set(params_user.keys()).issubset(set(default_params.keys()))
+            merged_params.update(params_user)
         return merged_params
diff --git a/extreme_fit/model/margin_model/abstract_margin_model.py b/extreme_fit/model/margin_model/abstract_margin_model.py
index d87146e6..33b22ddc 100644
--- a/extreme_fit/model/margin_model/abstract_margin_model.py
+++ b/extreme_fit/model/margin_model/abstract_margin_model.py
@@ -2,6 +2,7 @@ from abc import ABC
 
 import numpy as np
 import pandas as pd
+from cached_property import cached_property
 
 from extreme_fit.model.abstract_model import AbstractModel
 from extreme_fit.function.margin_function.abstract_margin_function \
@@ -18,23 +19,20 @@ class AbstractMarginModel(AbstractModel, ABC):
         -margin_function
     """
 
-    def __init__(self, coordinates: AbstractCoordinates, params_user=None,
-                 params_class=GevParams):
+    def __init__(self, coordinates: AbstractCoordinates, params_user=None, params_class=GevParams):
         super().__init__(params_user)
         assert isinstance(coordinates, AbstractCoordinates), type(coordinates)
         self.coordinates = coordinates
-        self.margin_function = None  # type: AbstractMarginFunction
-        self.margin_function = None  # type: AbstractMarginFunction
         self.params_class = params_class
-        self.load_margin_functions()
 
-    def load_margin_functions(self):
-        raise NotImplementedError
+    @cached_property
+    def margin_function(self) -> AbstractMarginFunction:
+        margin_function = self.load_margin_function()
+        assert margin_function is not None
+        return margin_function
 
-    def default_load_margin_functions(self, margin_function_class):
-        # todo: check it i could remove these attributes
-        self.margin_function = margin_function_class(coordinates=self.coordinates,
-                                                     default_params=self.params_class.from_dict(self.params_sample))
+    def load_margin_function(self):
+        raise NotImplementedError
 
     # Conversion class methods
 
diff --git a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
index 18c2afd5..c0901173 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
@@ -26,8 +26,7 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel):
                  params_initial_fit_bayesian=None,
                  type_for_MLE="GEV",
                  params_class=GevParams):
-        super().__init__(coordinates, params_user, starting_point,
-                         params_class)
+        super().__init__(coordinates, params_user, starting_point, params_class)
         self.type_for_mle = type_for_MLE
         self.params_initial_fit_bayesian = params_initial_fit_bayesian
         self.nb_iterations_for_bayesian_fit = nb_iterations_for_bayesian_fit
diff --git a/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py b/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py
index e3c86ebf..07d6e8cf 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py
@@ -14,16 +14,16 @@ class LinearMarginModel(ParametricMarginModel):
                 params[(param_name, idx)] = coef
         return cls(coordinates, params_user=params, params_class=params_class, **kwargs)
 
-    def load_margin_functions(self, param_name_to_dims=None):
+    def load_margin_function(self, param_name_to_dims=None):
         assert param_name_to_dims is not None, 'LinearMarginModel cannot be used for sampling/fitting \n' \
-                                                   'load_margin_functions needs to be implemented in child class'
+                                               'load_margin_functions needs to be implemented in child class'
         # Load sample coef
         coef_sample = self.param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample)
-        self.margin_function = LinearMarginFunction(coordinates=self.coordinates,
-                                                    param_name_to_coef=coef_sample,
-                                                    param_name_to_dims=param_name_to_dims,
-                                                    starting_point=self.starting_point,
-                                                    params_class=self.params_class)
+        return LinearMarginFunction(coordinates=self.coordinates,
+                                    param_name_to_coef=coef_sample,
+                                    param_name_to_dims=param_name_to_dims,
+                                    starting_point=self.starting_point,
+                                    params_class=self.params_class)
 
     @property
     def default_params(self) -> dict:
@@ -49,75 +49,75 @@ class LinearMarginModel(ParametricMarginModel):
 
 class ConstantMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({})
 
 
 class LinearShapeDim0MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [0]})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: [0]})
 
 
 class LinearScaleDim0MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SCALE: [0]})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SCALE: [0]})
 
 
 class LinearShapeDim0and1MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [0, 1]})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: [0, 1]})
 
 
 class LinearAllParametersDim0MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [0],
-                                       GevParams.LOC: [0],
-                                       GevParams.SCALE: [0]})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: [0],
+                                             GevParams.LOC: [0],
+                                             GevParams.SCALE: [0]})
 
 
 class LinearMarginModelExample(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [0],
-                                       GevParams.LOC: [1],
-                                       GevParams.SCALE: [0]})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: [0],
+                                             GevParams.LOC: [1],
+                                             GevParams.SCALE: [0]})
 
 
 class LinearLocationAllDimsMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.LOC: self.coordinates.coordinates_dims})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: self.coordinates.coordinates_dims})
 
 
 class LinearShapeAllDimsMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: self.coordinates.coordinates_dims})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: self.coordinates.coordinates_dims})
 
 
 class LinearAllParametersAllDimsMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: self.coordinates.coordinates_dims,
-                                       GevParams.LOC: self.coordinates.coordinates_dims,
-                                       GevParams.SCALE: self.coordinates.coordinates_dims})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: self.coordinates.coordinates_dims,
+                                             GevParams.LOC: self.coordinates.coordinates_dims,
+                                             GevParams.SCALE: self.coordinates.coordinates_dims})
 
 
 class LinearStationaryMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
-                                       GevParams.LOC: self.coordinates.spatial_coordinates_dims,
-                                       GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
+                                             GevParams.LOC: self.coordinates.spatial_coordinates_dims,
+                                             GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
 
 
 class LinearNonStationaryLocationMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
-                                       GevParams.LOC: self.coordinates.coordinates_dims,
-                                       GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
+    def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
+                                             GevParams.LOC: self.coordinates.coordinates_dims,
+                                             GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
diff --git a/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_exp_models.py b/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_exp_models.py
index 0204a860..2bf516b4 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_exp_models.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_exp_models.py
@@ -11,5 +11,5 @@ class NonStationaryRateTemporalModel(AbstractTemporalLinearMarginModel, Abstract
         super().__init__(*arg, **kwargs)
         self.drop_duplicates = False
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({ExpParams.RATE: [self.coordinates.idx_temporal_coordinates]})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({ExpParams.RATE: [self.coordinates.idx_temporal_coordinates]})
diff --git a/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_models.py b/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_models.py
index b84dcd7f..989db471 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_models.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/temporal_linear_margin_models.py
@@ -9,14 +9,14 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
 
 class StationaryTemporalModel(AbstractTemporalLinearMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({})
 
 
 class NonStationaryLocationTemporalModel(AbstractTemporalLinearMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.LOC: [self.coordinates.idx_temporal_coordinates]})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [self.coordinates.idx_temporal_coordinates]})
 
     @property
     def mul(self):
@@ -25,8 +25,8 @@ class NonStationaryLocationTemporalModel(AbstractTemporalLinearMarginModel):
 
 class NonStationaryScaleTemporalModel(AbstractTemporalLinearMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
 
     @property
     def sigl(self):
@@ -35,8 +35,8 @@ class NonStationaryScaleTemporalModel(AbstractTemporalLinearMarginModel):
 
 class NonStationaryLogScaleTemporalModel(NonStationaryScaleTemporalModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
 
     @property
     def siglink(self):
@@ -45,8 +45,8 @@ class NonStationaryLogScaleTemporalModel(NonStationaryScaleTemporalModel):
 
 class NonStationaryShapeTemporalModel(AbstractTemporalLinearMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [self.coordinates.idx_temporal_coordinates]})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.SHAPE: [self.coordinates.idx_temporal_coordinates]})
 
     @property
     def shl(self):
@@ -55,9 +55,9 @@ class NonStationaryShapeTemporalModel(AbstractTemporalLinearMarginModel):
 
 class NonStationaryLocationAndScaleTemporalModel(AbstractTemporalLinearMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims=None):
-        super().load_margin_functions({GevParams.LOC: [self.coordinates.idx_temporal_coordinates],
-                                       GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [self.coordinates.idx_temporal_coordinates],
+                                      GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
 
     @property
     def mul(self):
diff --git a/extreme_fit/model/margin_model/parametric_margin_model.py b/extreme_fit/model/margin_model/parametric_margin_model.py
index 6aba0d05..04aa53e3 100644
--- a/extreme_fit/model/margin_model/parametric_margin_model.py
+++ b/extreme_fit/model/margin_model/parametric_margin_model.py
@@ -2,8 +2,10 @@ from abc import ABC
 
 import numpy as np
 import pandas as pd
+from cached_property import cached_property
 
 from extreme_fit.distribution.gev.gev_params import GevParams
+from extreme_fit.function.margin_function.abstract_margin_function import AbstractMarginFunction
 from extreme_fit.function.margin_function.parametric_margin_function import \
     ParametricMarginFunction
 from extreme_fit.model.margin_model.abstract_margin_model import AbstractMarginModel
@@ -22,11 +24,16 @@ class ParametricMarginModel(AbstractMarginModel, ABC):
         """
         :param starting_point: starting coordinate for the temporal trend
         """
+        super().__init__(coordinates, params_user, params_class)
         self.fit_method = fit_method
         self.starting_point = starting_point
-        self.margin_function = None  # type: ParametricMarginFunction
         self.drop_duplicates = True
-        super().__init__(coordinates, params_user, params_class)
+
+    @cached_property
+    def margin_function(self) -> ParametricMarginFunction:
+        margin_function = super().margin_function
+        assert isinstance(margin_function, ParametricMarginFunction)
+        return margin_function
 
     def fitmargin_from_maxima_gev(self, data: np.ndarray, df_coordinates_spat: pd.DataFrame,
                                   df_coordinates_temp: pd.DataFrame) -> ResultFromSpatialExtreme:
diff --git a/extreme_fit/model/margin_model/spline_margin_model.py b/extreme_fit/model/margin_model/spline_margin_model.py
index 83addd3c..b5d97d22 100644
--- a/extreme_fit/model/margin_model/spline_margin_model.py
+++ b/extreme_fit/model/margin_model/spline_margin_model.py
@@ -15,10 +15,10 @@ class SplineMarginModel(ParametricMarginModel):
                  params_user=None):
         super().__init__(coordinates, params_user)
 
-    def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None,
-                              param_name_to_coef: Dict[str, AbstractCoef] = None,
-                              param_name_to_nb_knots: Dict[str, int] = None,
-                              degree=3):
+    def load_margin_function(self, param_name_to_dims: Dict[str, List[int]] = None,
+                             param_name_to_coef: Dict[str, AbstractCoef] = None,
+                             param_name_to_nb_knots: Dict[str, int] = None,
+                             degree=3):
         # Default parameters
         # todo: for the default parameters: take inspiration from the linear_margin_model
         # also implement the class method thing
@@ -36,27 +36,24 @@ class SplineMarginModel(ParametricMarginModel):
         if param_name_to_nb_knots is None:
             param_name_to_nb_knots = {param_name: 2 for param_name in GevParams.PARAM_NAMES}
 
-        # Load sample coef
-        self.margin_function = SplineMarginFunction(coordinates=self.coordinates,
-                                                    param_name_to_dims=param_name_to_dims,
-                                                    param_name_to_coef=param_name_to_coef,
-                                                    param_name_to_nb_knots=param_name_to_nb_knots,
-                                                    degree=degree)
+        return SplineMarginFunction(coordinates=self.coordinates,
+                                    param_name_to_dims=param_name_to_dims,
+                                    param_name_to_coef=param_name_to_coef,
+                                    param_name_to_nb_knots=param_name_to_nb_knots,
+                                    degree=degree)
 
 
 class ConstantSplineMarginModel(SplineMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None,
-                              param_name_to_coef: Dict[str, AbstractCoef] = None,
-                              param_name_to_nb_knots: Dict[str, int] = None, degree=3):
-        super().load_margin_functions({}, param_name_to_coef, param_name_to_nb_knots,
-                                      degree)
+    def load_margin_function(self, param_name_to_dims: Dict[str, List[int]] = None,
+                             param_name_to_coef: Dict[str, AbstractCoef] = None,
+                             param_name_to_nb_knots: Dict[str, int] = None, degree=3):
+        return super().load_margin_function({}, param_name_to_coef, param_name_to_nb_knots, degree)
 
 
 class Degree1SplineMarginModel(SplineMarginModel):
 
-    def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None,
-                              param_name_to_coef: Dict[str, AbstractCoef] = None,
-                              param_name_to_nb_knots: Dict[str, int] = None, degree=3):
-        super().load_margin_functions(param_name_to_dims, param_name_to_coef, param_name_to_nb_knots,
-                                      degree=1)
+    def load_margin_function(self, param_name_to_dims: Dict[str, List[int]] = None,
+                             param_name_to_coef: Dict[str, AbstractCoef] = None,
+                             param_name_to_nb_knots: Dict[str, int] = None, degree=3):
+        return super().load_margin_function(param_name_to_dims, param_name_to_coef, param_name_to_nb_knots, degree=1)
-- 
GitLab