diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
index 2abb47e70a174d4fb3ecfe5560c6b169b1844a51..46a37e326865596dce051753a13fcb3c48c1ac2f 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
@@ -13,7 +13,7 @@ from experiment.utils import average_smoothing_with_sliding_window
 from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
     FullEstimatorInASingleStepWithSmoothMargin
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
-from extreme_estimator.extreme_models.margin_model.param_function.param_function import ParamFunction
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction
 from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearAllParametersAllDimsMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import CovarianceFunction
 from extreme_estimator.extreme_models.max_stable_model.max_stable_models import BrownResnick
@@ -57,7 +57,7 @@ class StudyVisualizer(object):
         self.coef_zoom_map = 1
 
         # Remove some assert
-        ParamFunction.OUT_OF_BOUNDS_ASSERT = False
+        AbstractParamFunction.OUT_OF_BOUNDS_ASSERT = False
 
     @property
     def observations(self):
diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index d21f3063aceb481e6006271ca139ee67f752826f..a5da6543fed542878695e2790cbcd220877e9dab 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -1,5 +1,7 @@
 import time
 
+from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \
+    ParametricMarginFunction
 from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
@@ -50,10 +52,10 @@ class AbstractEstimator(object):
         assert self._margin_function_fitted is not None, 'No margin function has been fitted'
         return self._margin_function_fitted
 
-    def extract_fitted_models_from_fitted_params(self, margin_function_to_fit, full_params_fitted):
+    def extract_fitted_models_from_fitted_params(self, margin_function_to_fit: ParametricMarginFunction, full_params_fitted):
         coef_dict = {k: v for k, v in full_params_fitted.items() if LinearCoef.COEFF_STR in k}
         self._margin_function_fitted = LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates,
-                                                                           gev_param_name_to_linear_dims=margin_function_to_fit.gev_param_name_to_linear_dims,
+                                                                           gev_param_name_to_dims=margin_function_to_fit.gev_param_name_to_dims,
                                                                            coef_dict=coef_dict)
 
     @property
diff --git a/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py b/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py
index c052dc184085846fef40f06d503bdeb5ff6e4e0d..7e44d0dc3dce6b50589658b369a3120cd6eb22d4 100644
--- a/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py
@@ -13,6 +13,11 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
 
 
 class AbstractMarginModel(AbstractModel, ABC):
+    """
+    An AbstractMarginModel has two main AbstractMarginFunction attributes:
+        -margin_function_sample for sampling
+        -margin_function_start_fit for starting to fit
+    """
 
     def __init__(self, coordinates: AbstractCoordinates, use_start_value=False,
                  params_start_fit=None, params_sample=None):
diff --git a/extreme_estimator/extreme_models/margin_model/linear_margin_model.py b/extreme_estimator/extreme_models/margin_model/linear_margin_model.py
index 25eeba991a0211c7ff43640ea9fe31d1b63d2ec9..27cc3c6fa9314744415b91b40712b8dcc20cfbfa 100644
--- a/extreme_estimator/extreme_models/margin_model/linear_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/linear_margin_model.py
@@ -1,102 +1,103 @@
-from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
 from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
 from extreme_estimator.extreme_models.margin_model.parametric_margin_model import ParametricMarginModel
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
 class LinearMarginModel(ParametricMarginModel):
 
-    def load_margin_functions(self, gev_param_name_to_linear_dims=None):
-        assert gev_param_name_to_linear_dims is not None, 'LinearMarginModel cannot be used for sampling/fitting \n' \
+    @classmethod
+    def from_coef_list(cls, coordinates, gev_param_name_to_coef_list):
+        params = {}
+        for gev_param_name in GevParams.PARAM_NAMES:
+            for idx, coef in enumerate(gev_param_name_to_coef_list[gev_param_name], -1):
+                params[(gev_param_name, idx)] = coef
+        return cls(coordinates, params_sample=params, params_start_fit=params)
+
+    def load_margin_functions(self, gev_param_name_to_dims=None):
+        assert gev_param_name_to_dims is not None, 'LinearMarginModel cannot be used for sampling/fitting \n' \
                                                           'load_margin_functions needs to be implemented in child class'
+        # Load default params (with a dictionary format to enable quick replacement)
+        # IMPORTANT: Using a dictionary format enable using the default/user params methodology
+        self.default_params_sample = self.default_param_name_and_dim_to_coef
+        self.default_params_start_fit = self.default_param_name_and_dim_to_coef
+
         # Load sample coef
-        self.default_params_sample = self.default_param_name_and_dim_to_coef()
-        linear_coef_sample = self.gev_param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample)
+        coef_sample = self.gev_param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample)
         self.margin_function_sample = LinearMarginFunction(coordinates=self.coordinates,
-                                                           gev_param_name_to_linear_coef=linear_coef_sample,
-                                                           gev_param_name_to_linear_dims=gev_param_name_to_linear_dims)
+                                                           gev_param_name_to_coef=coef_sample,
+                                                           gev_param_name_to_dims=gev_param_name_to_dims)
 
         # Load start fit coef
-        self.default_params_start_fit = self.default_param_name_and_dim_to_coef()
-        linear_coef_start_fit = self.gev_param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_start_fit)
+        coef_start_fit = self.gev_param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_start_fit)
         self.margin_function_start_fit = LinearMarginFunction(coordinates=self.coordinates,
-                                                              gev_param_name_to_linear_coef=linear_coef_start_fit,
-                                                              gev_param_name_to_linear_dims=gev_param_name_to_linear_dims)
+                                                              gev_param_name_to_coef=coef_start_fit,
+                                                              gev_param_name_to_dims=gev_param_name_to_dims)
 
-    @staticmethod
-    def default_param_name_and_dim_to_coef() -> dict:
+    @property
+    def default_param_name_and_dim_to_coef(self) -> dict:
         default_intercept = 1
         default_slope = 0.01
         gev_param_name_and_dim_to_coef = {}
         for gev_param_name in GevParams.PARAM_NAMES:
-            gev_param_name_and_dim_to_coef[(gev_param_name, 0)] = default_intercept
-            for dim in [1, 2, 3]:
+            gev_param_name_and_dim_to_coef[(gev_param_name, -1)] = default_intercept
+            for dim in self.coordinates.coordinates_dims:
                 gev_param_name_and_dim_to_coef[(gev_param_name, dim)] = default_slope
         return gev_param_name_and_dim_to_coef
 
-    @staticmethod
-    def gev_param_name_to_linear_coef(param_name_and_dim_to_coef):
+    def gev_param_name_to_linear_coef(self, param_name_and_dim_to_coef):
         gev_param_name_to_linear_coef = {}
         for gev_param_name in GevParams.PARAM_NAMES:
-            dim_to_coef = {dim: param_name_and_dim_to_coef[(gev_param_name, dim)] for dim in [0, 1, 2, 3]}
-            linear_coef = LinearCoef(gev_param_name=gev_param_name, dim_to_coef=dim_to_coef)
+            idx_to_coef = {idx: param_name_and_dim_to_coef[(gev_param_name, idx)] for idx in [-1] + self.coordinates.coordinates_dims}
+            linear_coef = LinearCoef(gev_param_name=gev_param_name, idx_to_coef=idx_to_coef)
             gev_param_name_to_linear_coef[gev_param_name] = linear_coef
         return gev_param_name_to_linear_coef
 
-    @classmethod
-    def from_coef_list(cls, coordinates, gev_param_name_to_coef_list):
-        params = {}
-        for gev_param_name in GevParams.PARAM_NAMES:
-            for dim, coef in enumerate(gev_param_name_to_coef_list[gev_param_name]):
-                params[(gev_param_name, dim)] = coef
-        return cls(coordinates, params_sample=params, params_start_fit=params)
-
 
 class ConstantMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, gev_param_name_to_linear_dims=None):
+    def load_margin_functions(self, gev_param_name_to_dims=None):
         super().load_margin_functions({})
 
 
 class LinearShapeDim1MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [1]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
+        super().load_margin_functions({GevParams.SHAPE: [0]})
 
 
 class LinearScaleDim1MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
-        super().load_margin_functions({GevParams.SCALE: [1]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
+        super().load_margin_functions({GevParams.SCALE: [0]})
 
 
 class LinearShapeDim1and2MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [1, 2]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
+        super().load_margin_functions({GevParams.SHAPE: [0, 1]})
 
 
 class LinearAllParametersDim1MarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [1],
-                                       GevParams.LOC: [1],
-                                       GevParams.SCALE: [1]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
+        super().load_margin_functions({GevParams.SHAPE: [0],
+                                       GevParams.LOC: [0],
+                                       GevParams.SCALE: [0]})
 
 
 class LinearMarginModelExample(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
-        super().load_margin_functions({GevParams.SHAPE: [1],
-                                       GevParams.LOC: [2],
-                                       GevParams.SCALE: [1]})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
+        super().load_margin_functions({GevParams.SHAPE: [0],
+                                       GevParams.LOC: [1],
+                                       GevParams.SCALE: [0]})
 
 
 class LinearAllParametersAllDimsMarginModel(LinearMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_dims=None):
-        all_dims = list(range(1, self.coordinates.nb_coordinates + 1))
-        super().load_margin_functions({GevParams.SHAPE: all_dims.copy(),
-                                       GevParams.LOC: all_dims.copy(),
-                                       GevParams.SCALE: all_dims.copy()})
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
+        super().load_margin_functions({GevParams.SHAPE: self.coordinates.coordinates_dims,
+                                       GevParams.LOC: self.coordinates.coordinates_dims,
+                                       GevParams.SCALE: self.coordinates.coordinates_dims})
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 44f704459e1b13374a0b144b6ea3b95531b89c05..81e42ecaf7b3883d0cd7d88094795cbf281351a8 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -13,7 +13,9 @@ from utils import cached_property
 
 
 class AbstractMarginFunction(object):
-    """ Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)"""
+    """
+    AbstractMarginFunction maps points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)
+    """
     VISUALIZATION_RESOLUTION = 100
 
     def __init__(self, coordinates: AbstractCoordinates):
@@ -138,7 +140,8 @@ class AbstractMarginFunction(object):
         grid = []
         for i, xi in enumerate(linspace):
             gev_param = self.get_gev_params(np.array([xi]))
-            assert not gev_param.has_undefined_parameters, 'This case needs to be handled during display'
+            assert not gev_param.has_undefined_parameters, 'This case needs to be handled during display,' \
+                                                           'gev_parameter for xi={} is undefined'.format(xi)
             grid.append(gev_param.summary_dict)
         grid = {gev_param: [g[gev_param] for g in grid] for gev_param in GevParams.SUMMARY_NAMES}
         return grid, linspace
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
index bc97b77b85ef197e361255b3ccd6bffc151bcaf1..c4d34c8d382cb58b3c71813dfad406a515466749 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
@@ -2,7 +2,7 @@ from typing import Dict
 
 import numpy as np
 
-from extreme_estimator.extreme_models.margin_model.param_function.param_function import ParamFunction
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
@@ -10,12 +10,14 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
 
 
 class IndependentMarginFunction(AbstractMarginFunction):
-    """Margin Function where each parameter of the GEV are modeled independently"""
+    """
+        IndependentMarginFunction: each parameter of the GEV are modeled independently
+    """
 
     def __init__(self, coordinates: AbstractCoordinates):
         """Attribute 'gev_param_name_to_param_function' maps each GEV parameter to its corresponding function"""
         super().__init__(coordinates)
-        self.gev_param_name_to_param_function = None  # type: Dict[str, ParamFunction]
+        self.gev_param_name_to_param_function = None  # type: Dict[str, AbstractParamFunction]
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Each GEV parameter is computed independently through its corresponding param_function"""
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
index 34134e31cedbf5fe1d22fbb3f969276334051733..427d0d38ee175a1a85638a3bc30a34b79dd8ac5c 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
@@ -2,9 +2,10 @@ from typing import Dict, List
 
 from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \
     ParametricMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
 from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
 from extreme_estimator.extreme_models.margin_model.param_function.param_function import ConstantParamFunction, \
-    ParamFunction, LinearParamFunction
+    AbstractParamFunction, LinearParamFunction
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
@@ -21,90 +22,50 @@ class LinearMarginFunction(ParametricMarginFunction):
 
         gev_param_name_to_linear_coef             maps each parameter of the GEV distribution to its linear coefficients
 
-        gev_param_name_to_start_fit_linear_coef   maps each parameter of the GEV distribution to its starting fitting
-                                                   value for the linear coefficients
     """
 
-    def __init__(self, coordinates: AbstractCoordinates,
-                 gev_param_name_to_linear_dims: Dict[str, List[int]],
-                 gev_param_name_to_linear_coef: Dict[str, LinearCoef]):
-        super().__init__(coordinates)
-        self.gev_param_name_to_linear_coef = gev_param_name_to_linear_coef  # type: Dict[str, LinearCoef]
-        self.gev_param_name_to_linear_dims = gev_param_name_to_linear_dims  # type: Dict[str, List[int]]
-        # Build gev_parameter_to_param_function dictionary
-        self.gev_param_name_to_param_function = {}  # type: Dict[str, ParamFunction]
+    COEF_CLASS = LinearCoef
 
-        # Check the linear_dim are well-defined with respect to the coordinates
-        for linear_dims in self.gev_param_name_to_linear_dims.values():
-            for dim in linear_dims:
-                assert 0 < dim <= coordinates.nb_coordinates, "dim={}, nb_columns={}".format(dim,
-                                                                                             coordinates.nb_coordinates)
+    def __init__(self, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
+                 gev_param_name_to_coef: Dict[str, AbstractCoef]):
+        self.gev_param_name_to_coef = None  # type: Dict[str, LinearCoef]
+        super().__init__(coordinates, gev_param_name_to_dims, gev_param_name_to_coef)
 
-        # Map each gev_param_name to its corresponding param_function
-        for gev_param_name in GevParams.PARAM_NAMES:
-            linear_coef = self.gev_param_name_to_linear_coef[gev_param_name]
-            # By default, if linear_dims are not specified, a constantParamFunction is chosen
-            if gev_param_name not in self.gev_param_name_to_linear_dims.keys():
-                param_function = ConstantParamFunction(constant=linear_coef.get_coef(dim=0))
-            # Otherwise, we fit a LinearParamFunction
-            else:
-                param_function = LinearParamFunction(linear_dims=self.gev_param_name_to_linear_dims[gev_param_name],
-                                                     coordinates=self.coordinates.coordinates_values(),
-                                                     linear_coef=linear_coef)
-            # Add the param_function to the dictionary
-            self.gev_param_name_to_param_function[gev_param_name] = param_function
-
-    @classmethod
-    def from_coef_dict(cls, coordinates: AbstractCoordinates, gev_param_name_to_linear_dims: Dict[str, List[int]],
-                       coef_dict: Dict[str, float]):
-        gev_param_name_to_linear_coef = {}
-        for gev_param_name in GevParams.PARAM_NAMES:
-            linear_dims = gev_param_name_to_linear_dims.get(gev_param_name, [])
-            linear_coef = LinearCoef.from_coef_dict(coef_dict=coef_dict, gev_param_name=gev_param_name,
-                                                    linear_dims=linear_dims,
-                                                    dim_to_coefficient_name=cls.dim_to_coefficient_name(coordinates))
-            gev_param_name_to_linear_coef[gev_param_name] = linear_coef
-        return cls(coordinates, gev_param_name_to_linear_dims, gev_param_name_to_linear_coef)
+    def load_specific_param_function(self, gev_param_name) -> AbstractParamFunction:
+        return LinearParamFunction(dims=self.gev_param_name_to_dims[gev_param_name],
+                                   coordinates=self.coordinates.coordinates_values(),
+                                   linear_coef=self.gev_param_name_to_coef[gev_param_name])
 
     @classmethod
-    def dim_to_coefficient_name(cls, coordinates: AbstractCoordinates) -> Dict[int, str]:
+    def idx_to_coefficient_name(cls, coordinates: AbstractCoordinates) -> Dict[int, str]:
         # Intercept correspond to the dimension 0
-        dim_to_coefficient_name = {0: LinearCoef.INTERCEPT_NAME}
+        idx_to_coefficient_name = {-1: LinearCoef.INTERCEPT_NAME}
         # Coordinates correspond to the dimension starting from 1
-        for i, coordinate_name in enumerate(coordinates.coordinates_names, 1):
-            dim_to_coefficient_name[i] = coordinate_name
-        return dim_to_coefficient_name
+        for idx, coordinate_name in enumerate(coordinates.coordinates_names):
+            idx_to_coefficient_name[idx] = coordinate_name
+        return idx_to_coefficient_name
 
     @classmethod
     def coefficient_name_to_dim(cls, coordinates: AbstractCoordinates) -> Dict[int, str]:
-        return {v: k for k, v in cls.dim_to_coefficient_name(coordinates).items()}
+        return {v: k for k, v in cls.idx_to_coefficient_name(coordinates).items()}
 
     @property
     def form_dict(self) -> Dict[str, str]:
         form_dict = {}
         for gev_param_name in GevParams.PARAM_NAMES:
-            linear_dims = self.gev_param_name_to_linear_dims.get(gev_param_name, [])
+            linear_dims = self.gev_param_name_to_dims.get(gev_param_name, [])
             # Load spatial form_dict (only if we have some spatial coordinates)
             if self.coordinates.coordinates_spatial_names:
                 spatial_names = [name for name in self.coordinates.coordinates_spatial_names
                                  if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims]
-                spatial_form = self.gev_param_name_to_linear_coef[gev_param_name].spatial_form_dict(spatial_names)
+                spatial_form = self.gev_param_name_to_coef[gev_param_name].spatial_form_dict(spatial_names)
                 form_dict.update(spatial_form)
             # Load temporal form dict (only if we have some temporal coordinates)
             if self.coordinates.coordinates_temporal_names:
                 temporal_names = [name for name in self.coordinates.coordinates_temporal_names
                                   if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims]
-                temporal_form = self.gev_param_name_to_linear_coef[gev_param_name].temporal_form_dict(temporal_names)
+                temporal_form = self.gev_param_name_to_coef[gev_param_name].temporal_form_dict(temporal_names)
                 # Specifying a formula '~ 1' creates a bug in fitspatgev of SpatialExtreme R package
                 assert not any(['1' in formula for formula in temporal_form.values()])
                 form_dict.update(temporal_form)
         return form_dict
-
-    @property
-    def coef_dict(self) -> Dict[str, float]:
-        coef_dict = {}
-        for gev_param_name in GevParams.PARAM_NAMES:
-            linear_dims = self.gev_param_name_to_linear_dims.get(gev_param_name, [])
-            linear_coef = self.gev_param_name_to_linear_coef[gev_param_name]
-            coef_dict.update(linear_coef.coef_dict(linear_dims, self.dim_to_coefficient_name(self.coordinates)))
-        return coef_dict
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py
index 0305bcfbfad0f7e69ab458fa508ebe830c59d654..8a0e51ca55ea2e0123db14359a45815fa3e8800a 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py
@@ -1,15 +1,83 @@
-from typing import Dict
+from typing import Dict, List
 
 from extreme_estimator.extreme_models.margin_model.margin_function.independent_margin_function import \
     IndependentMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction, \
+    ConstantParamFunction
+from extreme_estimator.margin_fits.gev.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
 class ParametricMarginFunction(IndependentMarginFunction):
+    """
+    ParametricMarginFunction each parameter of the GEV will:
 
-    @property
-    def form_dict(self) -> Dict[str, str]:
+        -depend on some integer dimensions (dimension 1 or/and dimension 2 for instance).
+        Coordinate name corresponding to the dimension depends on the order of the columns of self.coordinates
+        gev_param_name_to_dims maps each GEV parameter to its corresponding dimensions
+
+        -have a set of all potential coefficient that could be used to define a function
+        gev_param_name_to_coef maps each GEV parameter to an AbstractCoef object. This object contains
+
+        -combining the integer dimensions & the set of all potential coefficient
+        to keep only the relevant coefficient, and build the corresponding function from that
+        gev_param_name_to_param_function maps each GEV parameter to a AbstractParamFunction object.
+
+    """
+
+    COEF_CLASS = None
+
+    def __init__(self, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
+                 gev_param_name_to_coef: Dict[str, AbstractCoef]):
+        super().__init__(coordinates)
+        self.gev_param_name_to_dims = gev_param_name_to_dims  # type: Dict[str, List[int]]
+
+        # Check the dimension are well-defined with respect to the coordinates
+        for dims in self.gev_param_name_to_dims.values():
+            for dim in dims:
+                assert 0 <= dim < coordinates.nb_coordinates, \
+                    "dim={}, nb_columns={}".format(dim, coordinates.nb_coordinates)
+
+        self.gev_param_name_to_coef = gev_param_name_to_coef  # type: Dict[str, AbstractCoef]
+
+        # Build gev_parameter_to_param_function dictionary
+        self.gev_param_name_to_param_function = {}  # type: Dict[str, AbstractParamFunction]
+        # Map each gev_param_name to its corresponding param_function
+        for gev_param_name in GevParams.PARAM_NAMES:
+            # By default, if dims are not specified, a constantParamFunction is chosen
+            if self.gev_param_name_to_dims.get(gev_param_name) is None:
+                param_function = ConstantParamFunction(constant=self.gev_param_name_to_coef[gev_param_name].intercept)
+            # Otherwise, we load a specific param function
+            else:
+                param_function = self.load_specific_param_function(gev_param_name)
+            # In both cases, we add the param_function to the dictionary
+            self.gev_param_name_to_param_function[gev_param_name] = param_function
+
+    def load_specific_param_function(self, gev_param_name) -> AbstractParamFunction:
         raise NotImplementedError
 
+    @classmethod
+    def from_coef_dict(cls, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
+                       coef_dict: Dict[str, float]):
+        assert cls.COEF_CLASS is not None, 'a COEF_CLASS class attributes needs to be defined'
+        gev_param_name_to_coef = {}
+        for gev_param_name in GevParams.PARAM_NAMES:
+            dims = gev_param_name_to_dims.get(gev_param_name, [])
+            coef = cls.COEF_CLASS.from_coef_dict(coef_dict=coef_dict, gev_param_name=gev_param_name, dims=dims,
+                                                 coordinates=coordinates)
+            gev_param_name_to_coef[gev_param_name] = coef
+        return cls(coordinates, gev_param_name_to_dims, gev_param_name_to_coef)
+
     @property
     def coef_dict(self) -> Dict[str, float]:
-        raise NotImplementedError
\ No newline at end of file
+        coef_dict = {}
+        for gev_param_name in GevParams.PARAM_NAMES:
+            dims = self.gev_param_name_to_dims.get(gev_param_name, [])
+            coef = self.gev_param_name_to_coef[gev_param_name]
+            coef_dict.update(coef.coef_dict(dims, self.idx_to_coefficient_name(self.coordinates)))
+        return coef_dict
+
+    @property
+    def form_dict(self) -> Dict[str, str]:
+        raise NotImplementedError
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/spline_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/spline_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7560ea664cae179fe0753bc531b8d6e448b725f
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/spline_margin_function.py
@@ -0,0 +1,65 @@
+from typing import Dict, List
+
+import numpy as np
+
+from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \
+    ParametricMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction, \
+    SplineParamFunction
+from extreme_estimator.extreme_models.margin_model.param_function.spline_coef import SplineCoef
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+
+
+class SplineMarginFunction(ParametricMarginFunction):
+    """
+    -gev_param_name_to_dims maps each GEV parameters to its correspond knot dimensions.
+        For instance, dims = [1,2] means the knot will be realized with 2D knots
+        dims = [1] means the knot will lie only on the first axis
+
+    """
+
+    COEF_CLASS = SplineCoef
+
+    def __init__(self, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
+                 gev_param_name_to_coef: Dict[str, AbstractCoef],
+                 gev_param_name_to_nb_knots: Dict[str, int],
+                 degree=3):
+        self.gev_param_name_to_coef = None  # type: Dict[str, SplineCoef]
+        # Attributes specific for SplineMarginFunction
+        self.gev_param_name_to_nb_knots = gev_param_name_to_nb_knots
+        assert degree % 2 == 1
+        self.degree = degree
+        super().__init__(coordinates, gev_param_name_to_dims, gev_param_name_to_coef)
+
+
+    def compute_knots(self, dims, nb_knots) -> np.ndarray:
+        """Define the knots as the quantiles"""
+        return np.quantile(a=self.coordinates.df_all_coordinates.iloc[:, dims], q=np.linspace(0, 1, nb_knots+2)[1:-1])
+
+    @property
+    def form_dict(self) -> Dict[str, str]:
+        """
+        3 examples of potential form dict:
+            loc.form <- y ~ rb(locations[,1], knots = knots, degree = 3, penalty = .5)
+            scale.form <- y ~ rb(locations[,2], knots = knots2, degree = 3, penalty = .5)
+            shape.form <- y ~ rb(locations, knots = knots_tot, degree = 3, penalty = .5)
+        """
+        pass
+
+    def load_specific_param_function(self, gev_param_name) -> AbstractParamFunction:
+        dims = self.gev_param_name_to_dims[gev_param_name]
+        coef = self.gev_param_name_to_coef[gev_param_name]
+        nb_knots = self.gev_param_name_to_nb_knots[gev_param_name]
+        knots = self.compute_knots(dims, nb_knots)
+        return SplineParamFunction(dims=dims, degree=self.degree, spline_coef=coef, knots=knots)
+
+
+
+
+
+
+
+
+
+
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/abstract_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/abstract_coef.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba91ecc3855e908447ec134f7675619970c1946
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/param_function/abstract_coef.py
@@ -0,0 +1,39 @@
+from typing import Dict, List
+
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+
+
+class AbstractCoef(object):
+
+    def __init__(self, gev_param_name: str, default_value: float = 0.0, idx_to_coef=None):
+        self.gev_param_name = gev_param_name
+        self.default_value = default_value
+        self.idx_to_coef = idx_to_coef
+
+    def get_coef(self, idx) -> float:
+        if self.idx_to_coef is None:
+            return self.default_value
+        else:
+            return self.idx_to_coef.get(idx, self.compute_default_value(idx))
+
+    def compute_default_value(self, idx):
+        return self.default_value
+
+    @property
+    def intercept(self) -> float:
+        return self.default_value
+
+    """ Coef dict """
+
+    def coef_dict(self, dims: List[int], coordinates: AbstractCoordinates) -> Dict[str, float]:
+        raise NotImplementedError
+
+    @classmethod
+    def from_coef(cls, coef_dict: Dict[str, float], gev_param_name: str, dims: List[int], coordinates: AbstractCoordinates):
+        raise NotImplementedError
+
+    """ Form dict """
+
+    def form_dict(self, names: List[str]) -> Dict[str, str]:
+        formula_str = '1' if not names else '+'.join(names)
+        return {self.gev_param_name + '.form': self.gev_param_name + ' ~ ' + formula_str}
\ No newline at end of file
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
index 1e67c94c17d2ab5a35cd6f946bfe222297c44d67..6270d0cf6bf78cf8cfa445b9cb1f427c9da608a4 100644
--- a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
+++ b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
@@ -1,9 +1,10 @@
 from typing import Dict, List
 
+from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
-class LinearCoef(object):
+class LinearCoef(AbstractCoef):
     """
     Object that maps each dimension to its corresponding coefficient.
         dim = 0 correspond to the intercept
@@ -14,20 +15,9 @@ class LinearCoef(object):
     INTERCEPT_NAME = 'intercept'
     COEFF_STR = 'Coeff'
 
-    def __init__(self, gev_param_name: str, dim_to_coef: Dict[int, float] = None, default_value: float = 0.0):
-        self.gev_param_name = gev_param_name
-        self.dim_to_coef = dim_to_coef
-        self.default_value = default_value
-
-    def get_coef(self, dim: int) -> float:
-        if self.dim_to_coef is None:
-            return self.default_value
-        else:
-            return self.dim_to_coef.get(dim, self.default_value)
-
     @property
     def intercept(self) -> float:
-        return self.get_coef(dim=0)
+        return self.get_coef(idx=-1)
 
     @classmethod
     def coef_template_str(cls, gev_param_name: str, coefficient_name: str) -> str:
@@ -53,41 +43,38 @@ class LinearCoef(object):
                     for coefficient_name in dim_to_coefficient_name.values()])
 
     @classmethod
-    def add_intercept_dim(cls, dims):
-        return [0] + dims
+    def add_intercept_idx(cls, dims):
+        return [-1] + dims
+
+    """ Coef dict """
 
     @classmethod
-    def from_coef_dict(cls, coef_dict: Dict[str, float], gev_param_name: str, linear_dims: List[int],
-                       dim_to_coefficient_name: Dict[int, str]):
-        dims = cls.add_intercept_dim(linear_dims)
-        dim_to_coef = {}
-        j = 1
+    def from_coef_dict(cls, coef_dict: Dict[str, float], gev_param_name: str, dims: List[int],
+                       coordinates: AbstractCoordinates):
+        idx_to_coef = {-1: coef_dict[cls.coef_template_str(gev_param_name, coefficient_name=cls.INTERCEPT_NAME).format(1)]}
+        j = 2
         for dim in dims:
-            coefficient_name = dim_to_coefficient_name[dim]
+            coefficient_name = coordinates.coordinates_names[dim]
             if coefficient_name == AbstractCoordinates.COORDINATE_T:
                 j = 1
             coef = coef_dict[cls.coef_template_str(gev_param_name, coefficient_name).format(j)]
-            dim_to_coef[dim] = coef
+            idx_to_coef[dim] = coef
             j += 1
-        return cls(gev_param_name, dim_to_coef)
+        return cls(gev_param_name=gev_param_name, idx_to_coef=idx_to_coef)
 
-    def coef_dict(self, linear_dims, dim_to_coefficient_name: Dict[int, str]) -> Dict[str, float]:
-        dims = self.add_intercept_dim(linear_dims)
+    def coef_dict(self, dims, dim_to_coefficient_name: Dict[int, str]) -> Dict[str, float]:
+        dims = self.add_intercept_idx(dims)
         coef_dict = {}
         j = 1
         for dim in dims:
             coefficient_name = dim_to_coefficient_name[dim]
             if coefficient_name == AbstractCoordinates.COORDINATE_T:
                 j = 1
-            coef = self.dim_to_coef[dim]
+            coef = self.idx_to_coef[dim]
             coef_dict[self.coef_template_str(self.gev_param_name, coefficient_name).format(j)] = coef
             j += 1
         return coef_dict
 
-    def form_dict(self, names: List[str]) -> Dict[str, str]:
-        formula_str = '1' if not names else '+'.join(names)
-        return {self.gev_param_name + '.form': self.gev_param_name + ' ~ ' + formula_str}
-
     def spatial_form_dict(self, coordinate_spatial_names: List[str]) -> Dict[str, str]:
         """
         Example of formula that could be specified:
@@ -105,7 +92,7 @@ class LinearCoef(object):
         Example of formula that could be specified:
         temp.form.loc = loc ~ coord_t
         Example of formula that could not be specified
-        temp.loc.form = shape ~ 1
+        temp.loc.form = loc ~ 1
         :return:
         """
         assert all([name in [AbstractCoordinates.COORDINATE_T] for name in coordinate_temporal_names])
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/param_function.py b/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
index f8b882d65ac74f99b70200e1cf70d3d245f2fc40..a4803125e31a9408d056e127eb711d88799981d0 100644
--- a/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
+++ b/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
@@ -1,16 +1,17 @@
 from typing import List
 import numpy as np
 from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
+from extreme_estimator.extreme_models.margin_model.param_function.spline_coef import SplineCoef
 
 
-class ParamFunction(object):
+class AbstractParamFunction(object):
     OUT_OF_BOUNDS_ASSERT = True
 
     def get_gev_param_value(self, coordinate: np.ndarray) -> float:
         pass
 
 
-class ConstantParamFunction(ParamFunction):
+class ConstantParamFunction(AbstractParamFunction):
 
     def __init__(self, constant):
         self.constant = constant
@@ -19,33 +20,30 @@ class ConstantParamFunction(ParamFunction):
         return self.constant
 
 
-class LinearOneAxisParamFunction(ParamFunction):
+class LinearOneAxisParamFunction(AbstractParamFunction):
 
-    def __init__(self, linear_axis: int, coordinates: np.ndarray, coef: float = 0.01):
-        self.linear_axis = linear_axis
-        self.t_min = coordinates[:, linear_axis].min()
-        self.t_max = coordinates[:, linear_axis].max()
+    def __init__(self, dim: int, coordinates: np.ndarray, coef: float = 0.01):
+        self.dim = dim
+        self.t_min = coordinates[:, dim].min()
+        self.t_max = coordinates[:, dim].max()
         self.coef = coef
 
-    def get_gev_param_value_normalized(self, coordinate: np.ndarray) -> float:
-        return self.get_gev_param_value(coordinate) / (self.t_max - self.t_min)
-
     def get_gev_param_value(self, coordinate: np.ndarray) -> float:
-        t = coordinate[self.linear_axis]
+        t = coordinate[self.dim]
         if self.OUT_OF_BOUNDS_ASSERT:
             assert self.t_min <= t <= self.t_max, 'Out of bounds'
         return t * self.coef
 
 
-class LinearParamFunction(ParamFunction):
+class LinearParamFunction(AbstractParamFunction):
 
-    def __init__(self, linear_dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None):
+    def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None):
         self.linear_coef = linear_coef
         # Load each one axis linear function
         self.linear_one_axis_param_functions = []  # type: List[LinearOneAxisParamFunction]
-        for linear_dim in linear_dims:
-            param_function = LinearOneAxisParamFunction(linear_axis=linear_dim - 1, coordinates=coordinates,
-                                                        coef=self.linear_coef.get_coef(dim=linear_dim))
+        for dim in dims:
+            param_function = LinearOneAxisParamFunction(dim=dim, coordinates=coordinates,
+                                                        coef=self.linear_coef.get_coef(idx=dim))
             self.linear_one_axis_param_functions.append(param_function)
 
     def get_gev_param_value(self, coordinate: np.ndarray) -> float:
@@ -54,3 +52,29 @@ class LinearParamFunction(ParamFunction):
         for linear_one_axis_param_function in self.linear_one_axis_param_functions:
             gev_param_value += linear_one_axis_param_function.get_gev_param_value(coordinate)
         return gev_param_value
+
+
+class SplineParamFunction(AbstractParamFunction):
+
+    def __init__(self, dims, degree, spline_coef: SplineCoef, knots: np.ndarray) -> None:
+        self.spline_coef = spline_coef
+        self.degree = degree
+        self.dims = dims
+        self.knots = knots
+
+    @property
+    def m(self) -> int:
+        return int((self.degree + 1) / 2)
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        gev_param_value = self.spline_coef.intercept
+        # Polynomial part
+        for dim in self.dims:
+            polynomial_coef = self.spline_coef.dim_to_polynomial_coef[dim]
+            for degree in range(1, self.m):
+                gev_param_value += polynomial_coef.get_coef(degree) * coordinate[dim]
+        # Knot part
+        for idx, knot in enumerate(self.knots):
+            distance = np.power(np.linalg.norm(coordinate - knot), self.degree)
+            gev_param_value += self.spline_coef.knot_coef.get_coef(idx) * distance
+        return gev_param_value
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/spline_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/spline_coef.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e3f86e5a55afec9b69eda81756e4185848d5d3d
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/param_function/spline_coef.py
@@ -0,0 +1,32 @@
+from typing import Dict
+
+from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
+
+
+class PolynomialCoef(AbstractCoef):
+    """
+    Object that maps each degree to its corresponding coefficient.
+        degree = 1 correspond to the coefficient of the first order polynomial
+        degree = 2 correspond to the the coefficient of the first order polynomial
+        degree = 3 correspond to the the coefficient of the first order polynomial
+    """
+
+    def __init__(self, gev_param_name: str, default_value: float = 1.0, degree_to_coef=None):
+        super().__init__(gev_param_name, default_value, idx_to_coef=degree_to_coef)
+
+    def compute_default_value(self, idx):
+        return self.default_value / idx
+
+
+class KnotCoef(AbstractCoef):
+
+    def __init__(self, gev_param_name: str, default_value: float = 1.0, idx_to_coef=None):
+        super().__init__(gev_param_name, default_value, idx_to_coef)
+
+
+class SplineCoef(AbstractCoef):
+
+    def __init__(self, gev_param_name: str, knot_coef: KnotCoef, dim_to_polynomial_coef: Dict[int, PolynomialCoef]):
+        super().__init__(gev_param_name, 1.0, None)
+        self.knot_coef = knot_coef
+        self.dim_to_polynomial_coef = dim_to_polynomial_coef
diff --git a/extreme_estimator/extreme_models/margin_model/spline_margin_model.py b/extreme_estimator/extreme_models/margin_model/spline_margin_model.py
index 269e4db6d3ad76ce7498779c8f82af8c6dff02ae..d0f3b3f001e4a9ab98b44a100c310544bb594176 100644
--- a/extreme_estimator/extreme_models/margin_model/spline_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/spline_margin_model.py
@@ -1,12 +1,76 @@
+import numpy as np
+from typing import Dict, List
 
-"""
-Potentially, we could implement a spline model for the margin, to check results from Gaume
+import numpy as np
 
-rbpspline
+from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \
+    ParametricMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction, \
+    SplineParamFunction
+from extreme_estimator.extreme_models.margin_model.param_function.spline_coef import SplineCoef, KnotCoef, \
+    PolynomialCoef
+from extreme_estimator.margin_fits.gev.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
-Fits a penalized spline with radial basis functions to data
+from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
+from extreme_estimator.extreme_models.margin_model.margin_function.spline_margin_function import SplineMarginFunction
+from extreme_estimator.extreme_models.margin_model.parametric_margin_model import ParametricMarginModel
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
-Examplesn <- 200x <- runif(n)fun <- function(x) sin(3 * pi * x)y <- fun(x) + rnorm(n, 0, sqrt(0.4))
-knots <- quantile(x, prob = 1:(n/4) / (n/4 + 1))fitted <- rbpspline(y, x, knots = knots, degree = 3)fittedplot(x, y)lines(fitted, col = 2)
 
-"""
\ No newline at end of file
+class SplineMarginModel(ParametricMarginModel):
+
+    def __init__(self, coordinates: AbstractCoordinates, use_start_value=False, params_start_fit=None,
+                 params_sample=None):
+        super().__init__(coordinates, use_start_value, params_start_fit, params_sample)
+
+    def load_margin_functions(self, gev_param_name_to_dims: Dict[str, List[int]] = None,
+                              gev_param_name_to_coef: Dict[str, AbstractCoef] = None,
+                              gev_param_name_to_nb_knots: Dict[str, int] = None,
+                              degree=3):
+        # Default parameters
+        if gev_param_name_to_dims is None:
+            gev_param_name_to_dims = {gev_param_name: self.coordinates.coordinates_dims
+                                      for gev_param_name in GevParams.PARAM_NAMES}
+        if gev_param_name_to_coef is None:
+            gev_param_name_to_coef = {}
+            for gev_param_name in GevParams.PARAM_NAMES:
+                knot_coef = KnotCoef(gev_param_name)
+                polynomial_coef = PolynomialCoef(gev_param_name)
+                dim_to_polynomial_coef = {dim: polynomial_coef for dim in self.coordinates.coordinates_dims}
+                spline_coef = SplineCoef(gev_param_name, knot_coef, dim_to_polynomial_coef)
+                gev_param_name_to_coef[gev_param_name] = spline_coef
+        if gev_param_name_to_nb_knots is None:
+            gev_param_name_to_nb_knots = {gev_param_name: 2 for gev_param_name in GevParams.PARAM_NAMES}
+
+        # Load sample coef
+        self.margin_function_sample = SplineMarginFunction(coordinates=self.coordinates,
+                                                           gev_param_name_to_dims=gev_param_name_to_dims,
+                                                           gev_param_name_to_coef=gev_param_name_to_coef,
+                                                           gev_param_name_to_nb_knots=gev_param_name_to_nb_knots,
+                                                           degree=degree)
+        # Load start fit coef
+        self.margin_function_start_fit = SplineMarginFunction(coordinates=self.coordinates,
+                                                              gev_param_name_to_dims=gev_param_name_to_dims,
+                                                              gev_param_name_to_coef=gev_param_name_to_coef,
+                                                              gev_param_name_to_nb_knots=gev_param_name_to_nb_knots,
+                                                              degree=degree)
+
+
+class ConstantSplineMarginModel(SplineMarginModel):
+
+    def load_margin_functions(self, gev_param_name_to_dims: Dict[str, List[int]] = None,
+                              gev_param_name_to_coef: Dict[str, AbstractCoef] = None,
+                              gev_param_name_to_nb_knots: Dict[str, int] = None, degree=3):
+        super().load_margin_functions({}, gev_param_name_to_coef, gev_param_name_to_nb_knots,
+                                      degree)
+
+
+class Degree1SplineMarginModel(SplineMarginModel):
+
+    def load_margin_functions(self, gev_param_name_to_dims: Dict[str, List[int]] = None,
+                              gev_param_name_to_coef: Dict[str, AbstractCoef] = None,
+                              gev_param_name_to_nb_knots: Dict[str, int] = None, degree=3):
+        super().load_margin_functions(gev_param_name_to_dims, gev_param_name_to_coef, gev_param_name_to_nb_knots,
+                                      degree=1)
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 5adcffcffe8bb37d249e78cfe37172d7cb89ad70..4eafbd1c013519e2f0b5b18091ccddfb177df4af 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -148,6 +148,10 @@ class AbstractCoordinates(object):
     def nb_coordinates(self) -> int:
         return len(self.coordinates_names)
 
+    @property
+    def coordinates_dims(self) -> List[int]:
+        return list(range(self.nb_coordinates))
+
     # Spatial attributes
 
     @property
diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
index 5980064c76c6fa338443e8aec27de97b8f006e5b..b7cd717275f3489a0a960ccf64f8813cc50f705d 100644
--- a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
@@ -1,18 +1,15 @@
-import numpy as np
 import unittest
 
+import numpy as np
+
+from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearAllParametersAllDimsMarginModel
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
-from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
-from extreme_estimator.extreme_models.margin_model.spline_margin_model import SplineMarginModel
+from extreme_estimator.extreme_models.margin_model.spline_margin_model import ConstantSplineMarginModel, \
+    SplineMarginModel, Degree1SplineMarginModel
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
-from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearShapeDim1MarginModel, \
-    LinearAllParametersAllDimsMarginModel
-from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
-from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import \
-    CircleSpatialCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates
-from test.test_utils import load_test_spatiotemporal_coordinates
+from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
 
 
 class TestVisualizationLinearMarginModel(unittest.TestCase):
@@ -24,11 +21,9 @@ class TestVisualizationLinearMarginModel(unittest.TestCase):
         self.margin_model.margin_function_sample.visualize_function(show=self.DISPLAY)
         self.assertTrue(True)
 
-
-
     def test_example_visualization_1D(self):
         coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
-        self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
+        self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 0): 0.02})
 
     def test_example_visualization_2D_spatial(self):
         spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
@@ -59,25 +54,28 @@ class TestVisualizationLinearMarginModel(unittest.TestCase):
     #     # Load
 
 
-# class TestVisualizationSplineMarginModel(TestVisualizationMarginModel):
-#     margin_model_class = SplineMarginModel
-#
-#     def tearDown(self) -> None:
-#         self.margin_model.margin_function_sample.visualize_function(show=self.DISPLAY)
-#         self.assertTrue(True)
+class TestVisualizationSplineMarginModel(unittest.TestCase):
+    DISPLAY = False
+    nb_points = 2
+    margin_model_class = Degree1SplineMarginModel
+
+    def tearDown(self) -> None:
+        self.margin_model.margin_function_sample.visualize_function(show=self.DISPLAY)
+        self.assertTrue(True)
 
-    # def test_example_visualization_1D_spline(self):
-    #     coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
-    #     self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
+    def test_example_visualization_1D_spline(self):
+        coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points, start=0.0)
+        self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
 
-    # def test_example_visualization_2D_spatial_spline(self):
-    #     spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
-    #     self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
-    #     # Assert that the grid correspond to what we expect in a simple case
-    #     AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
-    #     grid = self.margin_model.margin_function_sample.grid_2D['loc']
-    #     true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
-    #     self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
+    def test_example_visualization_2D_spatial_spline(self):
+        spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
+        self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
+        # TODO: add a similar test than in the linear case
+        # # Assert that the grid correspond to what we expect in a simple case
+        # AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
+        # grid = self.margin_model.margin_function_sample.grid_2D['loc']
+        # true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
+        # self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
 
 
 if __name__ == '__main__':