From e23e220ac6b6056c7e420d9cd1723ae97cdf6d48 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Tue, 27 Nov 2018 15:02:11 +0100
Subject: [PATCH] [MARGIN FUNCTION] add linear coef object - refactor
 accordingly

---
 extreme_estimator/estimator/full_estimator.py | 23 ++++-
 .../abstract_margin_function.py               |  3 +-
 .../independent_margin_function.py            | 86 +------------------
 .../margin_function/linear_margin_function.py | 74 ++++++++++++++++
 .../margin_function/param_function.py         | 49 -----------
 .../margin_function/plot_margin_functions.py  |  2 -
 .../margin_model/param_function/__init__.py   |  0
 .../param_function/linear_coef.py             | 54 ++++++++++++
 .../param_function/param_function.py          | 52 +++++++++++
 .../margin_model/smooth_margin_model.py       | 60 ++++++++-----
 .../max_stable_model/max_stable_fit.R         |  6 +-
 11 files changed, 246 insertions(+), 163 deletions(-)
 create mode 100644 extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
 delete mode 100644 extreme_estimator/extreme_models/margin_model/margin_function/param_function.py
 delete mode 100644 extreme_estimator/extreme_models/margin_model/margin_function/plot_margin_functions.py
 create mode 100644 extreme_estimator/extreme_models/margin_model/param_function/__init__.py
 create mode 100644 extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
 create mode 100644 extreme_estimator/extreme_models/margin_model/param_function/param_function.py

diff --git a/extreme_estimator/estimator/full_estimator.py b/extreme_estimator/estimator/full_estimator.py
index 33b3e5da..1fbad398 100644
--- a/extreme_estimator/estimator/full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator.py
@@ -1,4 +1,6 @@
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
@@ -8,7 +10,21 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
 class AbstractFullEstimator(AbstractEstimator):
-    pass
+
+    def __init__(self, dataset: AbstractDataset):
+        super().__init__(dataset)
+        self._margin_function_fitted = None
+        self._max_stable_model_fitted = None
+
+    @property
+    def margin_function_fitted(self) -> AbstractMarginFunction:
+        assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted'
+        return self._margin_function_fitted
+
+    # @property
+    # def max_stable_fitted(self) -> AbstractMarginFunction:
+    #     assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted'
+    #     return self._margin_function_fitted
 
 
 class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
@@ -52,10 +68,11 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
             maxima_frech=self.dataset.maxima_frech,
             df_coordinates=self.dataset.df_coordinates,
             fit_marge=True,
-            fit_marge_form_dict=self.smooth_margin_function_to_fit.fit_marge_form_dict,
-            margin_start_dict=self.smooth_margin_function_to_fit.margin_start_dict
+            fit_marge_form_dict=self.smooth_margin_function_to_fit.form_dict,
+            margin_start_dict=self.smooth_margin_function_to_fit.coef_dict
         )
         # Initialize
+        # self._margin_function_fitted =
 
 
 class PointwiseAndThenUnitaryMsp(AbstractFullEstimator):
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 32bf0301..090f8f33 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -9,9 +9,8 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
 class AbstractMarginFunction(object):
     """ Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)"""
 
-    def __init__(self, coordinates: AbstractCoordinates, default_params: GevParams):
+    def __init__(self, coordinates: AbstractCoordinates):
         self.coordinates = coordinates
-        self.default_params = default_params.to_dict()
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Main method that maps each coordinate to its GEV parameters"""
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
index 25de6fe2..0f87191e 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
@@ -1,9 +1,8 @@
-from typing import Dict, List, Tuple
+from typing import Dict
 
 import numpy as np
 
-from extreme_estimator.extreme_models.margin_model.margin_function.param_function import ConstantParamFunction, \
-    LinearOneAxisParamFunction, ParamFunction, LinearParamFunction
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import ParamFunction
 from extreme_estimator.gev_params import GevParams
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
@@ -13,9 +12,9 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
 class IndependentMarginFunction(AbstractMarginFunction):
     """Margin Function where each parameter of the GEV are modeled independently"""
 
-    def __init__(self, coordinates: AbstractCoordinates, default_params: GevParams):
+    def __init__(self, coordinates: AbstractCoordinates):
         """Attribute 'gev_param_name_to_param_function' maps each GEV parameter to its corresponding function"""
-        super().__init__(coordinates, default_params)
+        super().__init__(coordinates)
         self.gev_param_name_to_param_function = None  # type: Dict[str, ParamFunction]
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
@@ -29,80 +28,3 @@ class IndependentMarginFunction(AbstractMarginFunction):
         return GevParams.from_dict(gev_params)
 
 
-class LinearMarginFunction(IndependentMarginFunction):
-    """ Margin Function, where each parameter can augment linearly as follows:
-        On the extremal point along all the dimension, the GEV parameters will equal the default_param value
-        Then, it will augment linearly along a single linear axis"""
-
-    def __init__(self, coordinates: AbstractCoordinates,
-                 default_params: GevParams,
-                 gev_param_name_to_linear_axes: Dict[str, List[int]],
-                 gev_param_name_and_axis_to_start_fit: Dict[Tuple[str, int], float] = None):
-        """
-        -Attribute 'gev_param_name_to_linear_axis'        maps each GEV parameter to its corresponding function
-        -Attribute 'gev_param_name_and_axis_to_start_fit' maps each tuple (GEV parameter, axis) to its start value for
-            fitting (by default equal to 1). Also start value for the intercept is equal to 0 by default."""
-        super().__init__(coordinates, default_params)
-        self.gev_param_name_and_axis_to_start_fit = gev_param_name_and_axis_to_start_fit
-        self.gev_param_name_to_linear_axes = gev_param_name_to_linear_axes
-
-        # Check the axes are well-defined with respect to the coordinates
-        for axes in self.gev_param_name_to_linear_axes.values():
-            for axis in axes:
-                assert axis < coordinates.nb_columns, "axis={}, nb_columns={}".format(axis, coordinates.nb_columns)
-
-        # Build gev_parameter_to_param_function dictionary
-        self.gev_param_name_to_param_function = {}  # type: Dict[str, ParamFunction]
-        # Map each gev_param_name to its corresponding param_function
-        for gev_param_name in GevParams.GEV_PARAM_NAMES:
-            # By default, if gev_param_name linear_axis is not specified, a constantParamFunction is chosen
-            if gev_param_name not in self.gev_param_name_to_linear_axes.keys():
-                param_function = ConstantParamFunction(constant=self.default_params[gev_param_name])
-            # Otherwise, we fit a LinearParamFunction
-            else:
-                param_function = LinearParamFunction(linear_axes=self.gev_param_name_to_linear_axes[gev_param_name],
-                                                     coordinates=self.coordinates.coordinates_values,
-                                                     start=self.default_params[gev_param_name])
-                # Some check on the Linear param function
-                if gev_param_name == GevParams.GEV_SCALE:
-                    assert param_function.end > 0 and param_function.start > 0, \
-                        'Impossible start/end value for Scale parameter'
-
-            # Add the param_function to the dictionary
-            self.gev_param_name_to_param_function[gev_param_name] = param_function
-
-    @property
-    def fit_marge_form_dict(self) -> dict:
-        """
-        Example of formula that could be specified:
-        loc.form = loc ~ coord_x
-        scale.form = scale ~ coord_y
-        shape.form = shape ~ coord_x+coord_y
-        :return:
-        """
-        fit_marge_form_dict = {}
-        axis_to_name = {i: name for i, name in enumerate(AbstractCoordinates.COORDINATE_NAMES)}
-        for gev_param_name in GevParams.GEV_PARAM_NAMES:
-            axes = self.gev_param_name_to_linear_axes.get(gev_param_name, None)
-            formula_str = '1' if axes is None else '+'.join([axis_to_name[axis] for axis in axes])
-            fit_marge_form_dict[gev_param_name + '.form'] = gev_param_name + ' ~ ' + formula_str
-        return fit_marge_form_dict
-
-    @property
-    def margin_start_dict(self) -> dict:
-        # Define default values
-        default_start_fit_coef = 1.0
-        default_start_fit_intercept = 0.0
-        # Build the dictionary containing all the parameters
-        margin_start_dict = {}
-        for gev_param_name in GevParams.GEV_PARAM_NAMES:
-            coef_template_str = gev_param_name + 'Coeff{}'
-            # Constant param must be specified for all the parameters
-            margin_start_dict[coef_template_str.format(1)] = default_start_fit_intercept
-            for j, axis in enumerate(self.gev_param_name_to_linear_axes.get(gev_param_name, []), 2):
-                if self.gev_param_name_and_axis_to_start_fit is None:
-                    coef = default_start_fit_coef
-                else:
-                    coef = self.gev_param_name_and_axis_to_start_fit.get((gev_param_name, axis), default_start_fit_coef)
-                margin_start_dict[coef_template_str.format(j)] = coef
-        return margin_start_dict
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
new file mode 100644
index 00000000..62377cea
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
@@ -0,0 +1,74 @@
+from typing import Dict, List, Tuple
+
+from extreme_estimator.extreme_models.margin_model.margin_function.independent_margin_function import \
+    IndependentMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import ConstantParamFunction, \
+    ParamFunction, LinearParamFunction
+from extreme_estimator.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+
+
+class LinearMarginFunction(IndependentMarginFunction):
+    """ Margin Function, where each parameter can augment linearly along any dimension.
+
+        dim = 0 correspond to the intercept
+        dim = 1 correspond to the coordinate X
+        dim = 2 correspond to the coordinate Y
+        dim = 3 correspond to the coordinate Z
+
+        gev_param_name_to_linear_dims             maps each parameter of the GEV distribution to its linear dimensions
+
+        gev_param_name_to_linear_coef             maps each parameter of the GEV distribution to its linear coefficients
+
+        gev_param_name_to_start_fit_linear_coef   maps each parameter of the GEV distribution to its starting fitting
+                                                   value for the linear coefficients
+    """
+
+    def __init__(self, coordinates: AbstractCoordinates,
+                 gev_param_name_to_linear_dims: Dict[str, List[int]],
+                 gev_param_name_to_linear_coef: Dict[str, LinearCoef]):
+        super().__init__(coordinates)
+        self.gev_param_name_to_linear_coef = gev_param_name_to_linear_coef
+        self.gev_param_name_to_linear_dims = gev_param_name_to_linear_dims
+        # Build gev_parameter_to_param_function dictionary
+        self.gev_param_name_to_param_function = {}  # type: Dict[str, ParamFunction]
+
+        # Check the linear_dim are well-defined with respect to the coordinates
+        for linear_dims in self.gev_param_name_to_linear_dims.values():
+            for dim in linear_dims:
+                assert 0 < dim <= coordinates.nb_columns, "dim={}, nb_columns={}".format(dim, coordinates.nb_columns)
+
+        # Map each gev_param_name to its corresponding param_function
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            linear_coef = self.gev_param_name_to_linear_coef[gev_param_name]
+            # By default, if linear_dims are not specified, a constantParamFunction is chosen
+            if gev_param_name not in self.gev_param_name_to_linear_dims.keys():
+                param_function = ConstantParamFunction(constant=linear_coef.get_coef(dim=0))
+            # Otherwise, we fit a LinearParamFunction
+            else:
+                param_function = LinearParamFunction(linear_dims=self.gev_param_name_to_linear_dims[gev_param_name],
+                                                     coordinates=self.coordinates.coordinates_values,
+                                                     linear_coef=linear_coef)
+            # Add the param_function to the dictionary
+            self.gev_param_name_to_param_function[gev_param_name] = param_function
+
+    @classmethod
+    def from_margin_coeff_dict(cls, coordinates, margin_coeff_dict):
+        pass
+
+    @property
+    def form_dict(self) -> dict:
+        form_dict = {}
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            linear_dims = self.gev_param_name_to_linear_dims.get(gev_param_name, [])
+            form_dict.update(self.gev_param_name_to_linear_coef[gev_param_name].form_dict(linear_dims=linear_dims))
+        return form_dict
+
+    @property
+    def coef_dict(self) -> dict:
+        coef_dict = {}
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            linear_dims = self.gev_param_name_to_linear_dims.get(gev_param_name, [])
+            coef_dict.update(self.gev_param_name_to_linear_coef[gev_param_name].coef_dict(linear_dims=linear_dims))
+        return coef_dict
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/param_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/param_function.py
deleted file mode 100644
index 24a81791..00000000
--- a/extreme_estimator/extreme_models/margin_model/margin_function/param_function.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import List
-
-import numpy as np
-
-
-class ParamFunction(object):
-
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
-        pass
-
-
-class ConstantParamFunction(ParamFunction):
-
-    def __init__(self, constant):
-        self.constant = constant
-
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
-        return self.constant
-
-
-class LinearOneAxisParamFunction(ParamFunction):
-
-    def __init__(self, linear_axis: int, coordinates_axis: np.ndarray, start: float, end: float = 0.01):
-        self.linear_axis = linear_axis
-        self.t_min = coordinates_axis.min()
-        self.t_max = coordinates_axis.max()
-        self.start = start
-        self.end = end
-
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
-        t = coordinate[self.linear_axis]
-        t_between_zero_and_one = (t - self.t_min) / (self.t_max - self.t_min)
-        assert 0 <= t_between_zero_and_one <= 1, 'Out of bounds'
-        return self.start + t_between_zero_and_one * (self.end - self.start)
-
-
-class LinearParamFunction(ParamFunction):
-
-    def __init__(self, linear_axes: List[int], coordinates: np.ndarray, start: float, end: float = 0.01):
-        self.linear_one_axis_param_functions = []  # type: List[LinearOneAxisParamFunction]
-        self.start = start
-        self.end = end
-        for linear_axis in linear_axes:
-            param_function = LinearOneAxisParamFunction(linear_axis, coordinates[:, linear_axis], start, end)
-            self.linear_one_axis_param_functions.append(param_function)
-
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
-        values = [param_funct.get_gev_param_value(coordinate) for param_funct in self.linear_one_axis_param_functions]
-        return float(np.mean(values))
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/plot_margin_functions.py b/extreme_estimator/extreme_models/margin_model/margin_function/plot_margin_functions.py
deleted file mode 100644
index 139597f9..00000000
--- a/extreme_estimator/extreme_models/margin_model/margin_function/plot_margin_functions.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/__init__.py b/extreme_estimator/extreme_models/margin_model/param_function/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
new file mode 100644
index 00000000..8374737d
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
@@ -0,0 +1,54 @@
+from typing import Dict
+
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+
+
+class LinearCoef(object):
+    """
+    Object that maps each dimension to its corresponding coefficient.
+        dim = 0 correspond to the intercept
+        dim = 1 correspond to the coordinate X
+        dim = 2 correspond to the coordinate Y
+        dim = 3 correspond to the coordinate Z
+    """
+
+    def __init__(self, gev_param_name: str, default_value: float = 0.0, dim_to_coef: Dict[int, float] = None):
+        self.gev_param_name = gev_param_name
+        self.dim_to_coef = dim_to_coef
+        self.default_value = default_value
+
+    def get_coef(self, dim: int) -> float:
+        if self.dim_to_coef is None:
+            return self.default_value
+        else:
+            return self.dim_to_coef.get(dim, self.default_value)
+
+    @property
+    def intercept(self):
+        return self.get_coef(dim=0)
+
+    @classmethod
+    def from_dict(cls, coef_dict: Dict[int, float], gev_param_name: str, default_value: float = 0.0):
+        pass
+
+    def coef_dict(self, linear_dims):
+        coef_dict = {}
+        coef_template_str = self.gev_param_name + 'Coeff{}'
+        # Constant param must be specified for all the parameters
+        coef_dict[coef_template_str.format(1)] = self.intercept
+        # Specify only the param that belongs to dim_to_coef
+        for j, dim in enumerate(linear_dims, 2):
+            coef_dict[coef_template_str.format(j)] = self.dim_to_coef[dim]
+        return coef_dict
+
+    def form_dict(self, linear_dims):
+        """
+        Example of formula that could be specified:
+        loc.form = loc ~ coord_x
+        scale.form = scale ~ coord_y
+        shape.form = shape ~ coord_x+coord_y
+        :return:
+        """
+        dim_to_name = {i: name for i, name in enumerate(AbstractCoordinates.COORDINATE_NAMES, 1)}
+        formula_str = '1' if not linear_dims else '+'.join([dim_to_name[dim] for dim in linear_dims])
+        return {self.gev_param_name + '.form': self.gev_param_name + ' ~ ' + formula_str}
\ No newline at end of file
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/param_function.py b/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
new file mode 100644
index 00000000..df9f485a
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
@@ -0,0 +1,52 @@
+from typing import List
+import numpy as np
+from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
+
+
+class ParamFunction(object):
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        pass
+
+
+class ConstantParamFunction(ParamFunction):
+
+    def __init__(self, constant):
+        self.constant = constant
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        return self.constant
+
+
+class LinearOneAxisParamFunction(ParamFunction):
+
+    def __init__(self, linear_axis: int, coordinates: np.ndarray, coef: float = 0.01):
+        self.linear_axis = linear_axis
+        self.t_min = coordinates[:, linear_axis].min()
+        self.t_max = coordinates[:, linear_axis].max()
+        self.coef = coef
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        t = coordinate[self.linear_axis]
+        t_between_zero_and_one = t / (self.t_max - self.t_min)
+        assert -1 <= t_between_zero_and_one <= 1, 'Out of bounds'
+        return t_between_zero_and_one * self.coef
+
+
+class LinearParamFunction(ParamFunction):
+
+    def __init__(self, linear_dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None):
+        self.linear_coef = linear_coef
+        # Load each one axis linear function
+        self.linear_one_axis_param_functions = []  # type: List[LinearOneAxisParamFunction]
+        for linear_dim in linear_dims:
+            param_function = LinearOneAxisParamFunction(linear_axis=linear_dim-1, coordinates=coordinates,
+                                                        coef=self.linear_coef.get_coef(dim=linear_dim))
+            self.linear_one_axis_param_functions.append(param_function)
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        # Add the intercept and the value with respect to each axis
+        gev_param_value = self.linear_coef.intercept
+        for linear_one_axis_param_function in self.linear_one_axis_param_functions:
+            gev_param_value += linear_one_axis_param_function.get_gev_param_value(coordinate)
+        return gev_param_value
diff --git a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
index 88951e44..83e5f6d0 100644
--- a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
@@ -2,23 +2,39 @@ import numpy as np
 
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
-from extreme_estimator.extreme_models.margin_model.margin_function.independent_margin_function import \
-    LinearMarginFunction
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
+from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
 from extreme_estimator.gev_params import GevParams
 
 
 class LinearMarginModel(AbstractMarginModel):
 
-    def load_margin_functions(self, gev_param_name_to_linear_axis=None):
+    def load_margin_functions(self, gev_param_name_to_linear_dims=None):
+        # Load sample coef
         self.default_params_sample = GevParams(1.0, 1.0, 1.0).to_dict()
-        self.default_params_start_fit = GevParams(1.0, 1.0, 1.0).to_dict()
+        linear_coef_sample = self.get_standard_linear_coef(gev_param_name_to_intercept=self.params_sample)
         self.margin_function_sample = LinearMarginFunction(coordinates=self.coordinates,
-                                                           default_params=GevParams.from_dict(self.params_sample),
-                                                           gev_param_name_to_linear_axes=gev_param_name_to_linear_axis)
+                                                           gev_param_name_to_linear_coef=linear_coef_sample,
+                                                           gev_param_name_to_linear_dims=gev_param_name_to_linear_dims)
+
+        # Load start fit coef
+        self.default_params_start_fit = GevParams(1.0, 1.0, 1.0).to_dict()
+        linear_coef_start_fit = self.get_standard_linear_coef(gev_param_name_to_intercept=self.params_start_fit)
         self.margin_function_start_fit = LinearMarginFunction(coordinates=self.coordinates,
-                                                              default_params=GevParams.from_dict(self.params_start_fit),
-                                                              gev_param_name_to_linear_axes=gev_param_name_to_linear_axis)
+                                                              gev_param_name_to_linear_coef=linear_coef_start_fit,
+                                                              gev_param_name_to_linear_dims=gev_param_name_to_linear_dims)
+
+    @staticmethod
+    def get_standard_linear_coef(gev_param_name_to_intercept, slope=0.1):
+        gev_param_name_to_linear_coef = {}
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            dim_to_coef = {dim: slope for dim in range(1, 4)}
+            dim_to_coef[0] = gev_param_name_to_intercept[gev_param_name]
+            linear_coef = LinearCoef(gev_param_name=gev_param_name, dim_to_coef=dim_to_coef)
+            gev_param_name_to_linear_coef[gev_param_name] = linear_coef
+        return gev_param_name_to_linear_coef
+
 
     def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates_values: np.ndarray) -> AbstractMarginFunction:
         return self.margin_function_start_fit
@@ -26,34 +42,34 @@ class LinearMarginModel(AbstractMarginModel):
 
 class ConstantMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, gev_param_name_to_linear_axis=None):
+    def load_margin_functions(self, gev_param_name_to_linear_dims=None):
         super().load_margin_functions({})
 
 
 class LinearShapeAxis0MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
-        super().load_margin_functions({GevParams.GEV_SHAPE: [0]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
+        super().load_margin_functions({GevParams.GEV_SHAPE: [1]})
 
 
 class LinearShapeAxis0and1MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
-        super().load_margin_functions({GevParams.GEV_SHAPE: [0, 1]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
+        super().load_margin_functions({GevParams.GEV_SHAPE: [1, 2]})
 
 
 class LinearAllParametersAxis0MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
-        super().load_margin_functions({GevParams.GEV_SHAPE: [0],
-                                       GevParams.GEV_LOC: [0],
-                                       GevParams.GEV_SCALE: [0]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
+        super().load_margin_functions({GevParams.GEV_SHAPE: [1],
+                                       GevParams.GEV_LOC: [1],
+                                       GevParams.GEV_SCALE: [1]})
 
 
 class LinearAllParametersAllAxisMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
-        all_axis = list(range(self.coordinates.nb_columns))
-        super().load_margin_functions({GevParams.GEV_SHAPE: all_axis.copy(),
-                                       GevParams.GEV_LOC: all_axis.copy(),
-                                       GevParams.GEV_SCALE: all_axis.copy()})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
+        all_dims = list(range(1, self.coordinates.nb_columns + 1))
+        super().load_margin_functions({GevParams.GEV_SHAPE: all_dims.copy(),
+                                       GevParams.GEV_LOC: all_dims.copy(),
+                                       GevParams.GEV_SCALE: all_dims.copy()})
diff --git a/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R b/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R
index b7cf2f13..0de2193b 100644
--- a/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R
+++ b/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R
@@ -20,11 +20,11 @@ rmaxstab2D <- function (n.obs){
     print(class(coord))
     print(colnames(coord))
 
-    loc.form = loc ~ 1
+    loc.form = loc ~ N
     scale.form = scale ~ 1
     shape.form = shape ~ 1
 
-    namedlist = list(cov11 = 1.0, cov12 = 1.2, cov22 = 2.2, locCoeff1=1.0, scaleCoeff1=1.0, shapeCoeff1=1.0)
+    namedlist = list(cov11 = 1.0, cov12 = 1.2, cov22 = 2.2, locCoeff1=1.0, locCoeff2=1.0, scaleCoeff1=1.0, shapeCoeff1=1.0)
     res = fitmaxstab(data=data, coord=coord, cov.mod="gauss", start=namedlist, fit.marge=TRUE, loc.form=loc.form, scale.form=scale.form,shape.form=shape.form)
     print(res['fitted.values'])
 }
@@ -63,7 +63,7 @@ if (call_main) {
     set.seed(42)
     n.obs = 500
     rmaxstab2D(n.obs)
-    rmaxstab1D(n.obs)
+    # rmaxstab1D(n.obs)
 
     # namedlist = list(cov11 = 1.0, cov12 = 1.2, cov22 = 2.2)
     # res = fitmaxstab(data=data, coord=coord, cov.mod="gauss", start=namedlist)
-- 
GitLab