diff --git a/extreme_estimator/R_model/gev/gev_parameters.py b/extreme_estimator/R_model/gev/gev_parameters.py
index 387f8ab7b639430676f9a81a4ea4df6e437aa71b..0f7c7cdd85cceb8a3d9b5af0c8a655942018abdc 100644
--- a/extreme_estimator/R_model/gev/gev_parameters.py
+++ b/extreme_estimator/R_model/gev/gev_parameters.py
@@ -1,8 +1,11 @@
+import numpy as np
+
 
 class GevParams(object):
     GEV_SCALE = 'scale'
     GEV_LOC = 'loc'
     GEV_SHAPE = 'shape'
+    GEV_PARAM_NAMES = [GEV_LOC, GEV_SCALE, GEV_SHAPE]
 
     def __init__(self, loc: float, scale: float, shape: float):
         self.location = loc
@@ -20,10 +23,6 @@ class GevParams(object):
             self.GEV_SHAPE: self.shape,
         }
 
-    def rgev(self, nb_obs):
-        gev_params = {
-            self.GEV_LOC: loc,
-            self.GEV_SCALE: scale,
-            self.GEV_SHAPE: shape,
-        }
-        return self.r.rgev(nb_obs, **gev_params)
+    def to_array(self) -> np.ndarray:
+        gev_param_name_to_value = self.to_dict()
+        return np.array([gev_param_name_to_value[gev_param_name] for gev_param_name in self.GEV_PARAM_NAMES])
diff --git a/extreme_estimator/R_model/margin_function/__init__.py b/extreme_estimator/R_model/margin_function/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/extreme_estimator/R_model/margin_function/abstract_margin_function.py b/extreme_estimator/R_model/margin_function/abstract_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..5392b28799ae261b2bfc470c518f98374e5556c7
--- /dev/null
+++ b/extreme_estimator/R_model/margin_function/abstract_margin_function.py
@@ -0,0 +1,35 @@
+from typing import List, Dict
+
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+import numpy as np
+
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
+from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
+
+
+class AbstractMarginFunction(object):
+    """
+    It represents any function mapping points from a space S (could be 2D, 3D,...) to R^3 (the 3 parameters of the GEV)
+    """
+
+    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams):
+        self.spatial_coordinates = spatial_coordinates
+        self.default_params = default_params.to_dict()
+
+    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
+        pass
+
+    def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, show=False):
+        x = self.spatial_coordinates.x_coordinates
+        y = self.spatial_coordinates.y_coordinates
+        resolution = 100
+        grid = np.zeros([resolution, resolution, 3])
+        for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
+            for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)):
+                grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array()
+        gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
+        plt.imshow(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
+                   interpolation='nearest', cmap=cm.gist_rainbow)
+        if show:
+            plt.show()
diff --git a/extreme_estimator/R_model/margin_function/independent_margin_function.py b/extreme_estimator/R_model/margin_function/independent_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce255e4b1f9a33494d62682163fc86b47faa9e56
--- /dev/null
+++ b/extreme_estimator/R_model/margin_function/independent_margin_function.py
@@ -0,0 +1,78 @@
+from typing import Dict, List
+
+import numpy as np
+
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
+from extreme_estimator.R_model.margin_function.abstract_margin_function import AbstractMarginFunction
+from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
+
+
+class ParamFunction(object):
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        pass
+
+
+class IndependentMarginFunction(AbstractMarginFunction):
+
+    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams):
+        super().__init__(spatial_coordinates, default_params)
+        self.gev_param_name_to_param_function = None  # type: Dict[str, ParamFunction]
+
+    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
+        assert self.gev_param_name_to_param_function is not None
+        assert len(self.gev_param_name_to_param_function) == 3
+        gev_params = {}
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            param_function = self.gev_param_name_to_param_function[gev_param_name]
+            gev_value = param_function.get_gev_param_value(coordinate)
+            gev_params[gev_param_name] = gev_value
+        return GevParams.from_dict(gev_params)
+
+
+class ConstantParamFunction(ParamFunction):
+
+    def __init__(self, constant):
+        self.constant = constant
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        return self.constant
+
+
+class LinearOneAxisParamFunction(ParamFunction):
+
+    def __init__(self, linear_axis, coordinates_axis, start, end=0.0):
+        self.linear_axis = linear_axis
+        self.t_min = coordinates_axis.min()
+        self.t_max = coordinates_axis.max()
+        self.start = start
+        self.end = end
+
+    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+        t = coordinate[self.linear_axis]
+        t_between_zero_and_one = (t - self.t_min) / self.t_max
+        return self.start + t_between_zero_and_one * (self.end - self.start)
+
+
+class LinearMarginFunction(IndependentMarginFunction):
+    """
+    On the minimal point along all the dimension, the GevParms will equal default params
+    Otherwise, it will augment linearly along a single linear axis
+    """
+
+    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams,
+                 gev_param_name_to_linear_axis: Dict[str, int]):
+        super().__init__(spatial_coordinates, default_params)
+        self.param_to_linear_dims = gev_param_name_to_linear_axis
+        assert all([axis < np.ndim(spatial_coordinates.coordinates) for axis in gev_param_name_to_linear_axis.values()])
+        # Initialize gev_parameter_to_param_function
+        self.gev_param_name_to_param_function = {}
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            if gev_param_name not in gev_param_name_to_linear_axis.keys():
+                param_function = ConstantParamFunction(constant=self.default_params[gev_param_name])
+            else:
+                linear_axis = gev_param_name_to_linear_axis.get(gev_param_name, None)
+                coordinates_axis = self.spatial_coordinates.coordinates[:, linear_axis]
+                param_function = LinearOneAxisParamFunction(linear_axis=linear_axis, coordinates_axis=coordinates_axis,
+                                                            start=self.default_params[gev_param_name])
+            self.gev_param_name_to_param_function[gev_param_name] = param_function
diff --git a/extreme_estimator/R_model/margin_function/plot_margin_functions.py b/extreme_estimator/R_model/margin_function/plot_margin_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438
--- /dev/null
+++ b/extreme_estimator/R_model/margin_function/plot_margin_functions.py
@@ -0,0 +1,2 @@
+
+
diff --git a/extreme_estimator/R_model/margin_model/abstract_margin_function.py b/extreme_estimator/R_model/margin_model/abstract_margin_function.py
deleted file mode 100644
index d1eb05dd026b55972528f5563b4a9953e5ac5167..0000000000000000000000000000000000000000
--- a/extreme_estimator/R_model/margin_model/abstract_margin_function.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import List
-
-import numpy as np
-
-from extreme_estimator.R_model.gev.gev_parameters import GevParams
-from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
-
-
-class AbstractMarginFunction(object):
-    """
-    It represents any function mapping points from a space S (could be 2D, 3D,...) to R^3 (the 3 parameters of the GEV)
-    """
-
-    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams):
-        self.spatial_coordinates = spatial_coordinates
-        self.default_params = default_params
-
-    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
-        pass
-
-
-class ConstantMarginFunction(AbstractMarginFunction):
-
-    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
-        return self.default_params
-
-
-class LinearMarginFunction(AbstractMarginFunction):
-
-    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams,
-                 linear_dims: List[int]):
-        super().__init__(spatial_coordinates, default_params)
-        self.linear_dims = linear_dims
-
-# class LinearShapeMarginFunction(AbstractMarginFunction):
-#     """Linear function """
-#
-#     def __init__(self, coordinates, dimension_index_for_linearity=0):
-#         super().__init__(coordinates)
-#         self.dimension_index_for_linearity = dimension_index_for_linearity
-#         assert dimension_index_for_linearity < np.ndim(self.coordinates)
-#         # Compute
-#
-#     def get_gev_params(self, coordinate):
diff --git a/extreme_estimator/R_model/margin_model/abstract_margin_model.py b/extreme_estimator/R_model/margin_model/abstract_margin_model.py
index 747ac64fbc92dfdd17b65fe2f9465af760649509..384179f16667c6503e74dbded64963a6340b64c4 100644
--- a/extreme_estimator/R_model/margin_model/abstract_margin_model.py
+++ b/extreme_estimator/R_model/margin_model/abstract_margin_model.py
@@ -1,22 +1,18 @@
 import numpy as np
 
 from extreme_estimator.R_model.abstract_model import AbstractModel
-from extreme_estimator.R_model.margin_model.abstract_margin_function import AbstractMarginFunction
+from extreme_estimator.R_model.margin_function.abstract_margin_function import AbstractMarginFunction
 from extreme_estimator.R_model.gev.gev_parameters import GevParams
 from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
 
 
 class AbstractMarginModel(AbstractModel):
-    GEV_SCALE = GevParams.GEV_SCALE
-    GEV_LOC = GevParams.GEV_LOC
-    GEV_SHAPE = GevParams.GEV_SHAPE
-    GEV_PARAMETERS = [GEV_LOC, GEV_SCALE, GEV_SHAPE]
 
     def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, params_start_fit=None, params_sample=None):
         super().__init__(params_start_fit, params_sample)
         self.spatial_coordinates = spatial_coordinates
         self.margin_function_sample = None  # type: AbstractMarginFunction
-        self.margin_function_start_fit = None # type: AbstractMarginFunction
+        self.margin_function_start_fit = None  # type: AbstractMarginFunction
         self.load_margin_functions()
 
     def load_margin_functions(self, margin_function_class: type = None):
@@ -68,5 +64,3 @@ class AbstractMarginModel(AbstractModel):
         pass
 
     # Define the method to sample/fit a single gev
-
-
diff --git a/extreme_estimator/R_model/margin_model/smooth_margin_model.py b/extreme_estimator/R_model/margin_model/smooth_margin_model.py
index c1210d6cc672ada49b2ef577ed96579c6106e1bf..bf29a650a0100facc3d14c8d3fd0537b7dad35f6 100644
--- a/extreme_estimator/R_model/margin_model/smooth_margin_model.py
+++ b/extreme_estimator/R_model/margin_model/smooth_margin_model.py
@@ -1,29 +1,40 @@
 import numpy as np
 
-from extreme_estimator.R_model.margin_model.abstract_margin_function import ConstantMarginFunction, \
-    AbstractMarginFunction
+from extreme_estimator.R_model.margin_function.abstract_margin_function import AbstractMarginFunction
+from extreme_estimator.R_model.margin_function.independent_margin_function import LinearMarginFunction
 from extreme_estimator.R_model.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.R_model.gev.gev_parameters import GevParams
-from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
 
 
-class SmoothMarginModel(AbstractMarginModel):
-    pass
-
-
-class ConstantMarginModel(SmoothMarginModel):
+class LinearMarginModel(AbstractMarginModel):
 
-    def load_margin_functions(self, margin_function_class: type = None):
+    def load_margin_functions(self, gev_param_name_to_linear_axis=None):
         self.default_params_sample = GevParams(1.0, 1.0, 1.0).to_dict()
         self.default_params_start_fit = GevParams(1.0, 1.0, 1.0).to_dict()
-        super().load_margin_functions(margin_function_class=ConstantMarginFunction)
+        self.margin_function_sample = LinearMarginFunction(spatial_coordinates=self.spatial_coordinates,
+                                                           default_params=GevParams.from_dict(self.params_sample),
+                                                           gev_param_name_to_linear_axis=gev_param_name_to_linear_axis)
+        self.margin_function_start_fit = LinearMarginFunction(spatial_coordinates=self.spatial_coordinates,
+                                                              default_params=GevParams.from_dict(self.params_start_fit),
+                                                              gev_param_name_to_linear_axis=gev_param_name_to_linear_axis)
+
+
+class ConstantMarginModel(LinearMarginModel):
+
+    def load_margin_functions(self, gev_param_name_to_linear_axis=None):
+        super().load_margin_functions({})
 
     def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
         return self.margin_function_start_fit
 
 
-class LinearShapeMarginModel(SmoothMarginModel):
-    pass
+class LinearShapeAxis0MarginModel(LinearMarginModel):
+
+    def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
+        super().load_margin_functions({GevParams.GEV_SHAPE: 0})
+
+    # def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
+    #     return self.margin_function_start_fit
 
 
 if __name__ == '__main__':
diff --git a/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py b/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py
index 2def2b96ece46ca9a6a071fa0882dd02f7023774..fe54ce003134a8c374d404c587d9beb70d85d6b7 100644
--- a/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py
+++ b/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py
@@ -78,6 +78,14 @@ class AbstractSpatialCoordinates(object):
     def coordinates(self) -> np.ndarray:
         return self.coordinates_values(df_coordinates=self.df_coordinates)
 
+    @property
+    def x_coordinates(self) -> np.ndarray:
+        return self.df_coordinates.loc[:, self.COORD_X].values.copy()
+
+    @property
+    def y_coordinates(self) -> np.ndarray:
+        return self.df_coordinates.loc[:, self.COORD_Y].values.copy()
+
     @property
     def coordinates_train(self) -> np.ndarray:
         return self.coordinates_values(df_coordinates=self.df_coordinates_split(self.TRAIN_SPLIT_STR))
diff --git a/test/test_extreme_estimator/test_R_model/test_margin_function.py b/test/test_extreme_estimator/test_R_model/test_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a75c26427c6bab4bb005bc90b85af158ea452c
--- /dev/null
+++ b/test/test_extreme_estimator/test_R_model/test_margin_function.py
@@ -0,0 +1,22 @@
+import unittest
+
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
+from extreme_estimator.R_model.margin_function.independent_margin_function import LinearMarginFunction
+from extreme_estimator.R_model.margin_model.smooth_margin_model import ConstantMarginModel, LinearShapeAxis0MarginModel
+from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
+
+
+class TestLinearMarginModel(unittest.TestCase):
+    DISPLAY = True
+
+    def test_visualization_2D(self):
+        spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=50)
+        margin_model = LinearShapeAxis0MarginModel(spatial_coordinates=spatial_coordinates)
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            margin_model.margin_function_sample.visualize_2D(gev_param_name=gev_param_name, show=self.DISPLAY)
+        # maxima_gev = margin_model.rmargin_from_nb_obs(nb_obs=10, coordinates=coordinates)
+        # fitted_margin_function = margin_model.fitmargin_from_maxima_gev(maxima_gev=maxima_gev, coordinates=coordinates)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_extreme_estimator/test_R_model/test_margin_model.py b/test/test_extreme_estimator/test_R_model/test_margin_model.py
deleted file mode 100644
index ef3ccb84bf014c0f587d2074d7cad9bf2771ac73..0000000000000000000000000000000000000000
--- a/test/test_extreme_estimator/test_R_model/test_margin_model.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import unittest
-
-from extreme_estimator.R_model.margin_model.smooth_margin_model import ConstantMarginModel
-from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
-
-
-class TestMarginModel(unittest.TestCase):
-    DISPLAY = True
-    MARGIN_TYPES = [ConstantMarginModel]
-
-    # def test_visualization(self):
-    #     coord_2D = CircleCoordinatesRadius1.from_nb_points(nb_points=50)
-    #     if self.DISPLAY:
-    #         coord_2D.visualization_2D()
-    #     for margin_class in self.MARGIN_TYPES:
-    #         margin_model = margin_class()
-    #         margin_model.visualize(coordinates=coord_2D.coordinates)
-
-
-if __name__ == '__main__':
-    unittest.main()