diff --git a/extreme_estimator/R_model/abstract_model.py b/extreme_estimator/R_model/abstract_model.py
index 304ca70880064a47f74d24220c3e1e407bd240e1..ea8b2f20af9f8c94fca80cbf70a03ef288321c74 100644
--- a/extreme_estimator/R_model/abstract_model.py
+++ b/extreme_estimator/R_model/abstract_model.py
@@ -2,11 +2,28 @@ from extreme_estimator.R_model.utils import get_loaded_r
 
 
 class AbstractModel(object):
-
     r = get_loaded_r()
 
     def __init__(self, params_start_fit=None, params_sample=None):
         self.default_params_start_fit = None
         self.default_params_sample = None
         self.user_params_start_fit = params_start_fit
-        self.user_params_sample = params_sample
\ No newline at end of file
+        self.user_params_sample = params_sample
+
+    @property
+    def params_start_fit(self) -> dict:
+        return self.merge_params(default_params=self.default_params_start_fit, input_params=self.user_params_start_fit)
+
+    @property
+    def params_sample(self) -> dict:
+        return self.merge_params(default_params=self.default_params_sample, input_params=self.user_params_sample)
+
+    @staticmethod
+    def merge_params(default_params, input_params):
+        assert default_params is not None, 'some default_params need to be specified'
+        merged_params = default_params.copy()
+        if input_params is not None:
+            assert isinstance(default_params, dict) and isinstance(input_params, dict)
+            assert set(input_params.keys()).issubset(set(default_params.keys()))
+            merged_params.update(input_params)
+        return merged_params
diff --git a/extreme_estimator/R_model/gev/__init__.py b/extreme_estimator/R_model/gev/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/extreme_estimator/R_model/margin_model/gev_mle_fit.R b/extreme_estimator/R_model/gev/gev_mle_fit.R
similarity index 100%
rename from extreme_estimator/R_model/margin_model/gev_mle_fit.R
rename to extreme_estimator/R_model/gev/gev_mle_fit.R
diff --git a/extreme_estimator/R_model/margin_model/gev_mle_fit.py b/extreme_estimator/R_model/gev/gev_mle_fit.py
similarity index 64%
rename from extreme_estimator/R_model/margin_model/gev_mle_fit.py
rename to extreme_estimator/R_model/gev/gev_mle_fit.py
index 8778f99053b7fbad12f344abaa3e77036896ed81..f879f2bfd54224fab3affe2684428fccf60cd9b5 100644
--- a/extreme_estimator/R_model/margin_model/gev_mle_fit.py
+++ b/extreme_estimator/R_model/gev/gev_mle_fit.py
@@ -4,10 +4,11 @@ import rpy2.robjects.numpy2ri as rpyn
 import os.path as op
 
 # Defining some constants
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
 from extreme_estimator.R_model.utils import get_associated_r_file
 
 
-def mle_gev(x_gev: np.ndarray, start_loc=0, start_scale=1, start_shape=0):
+def gev_mle_fit(x_gev: np.ndarray, start_loc=0, start_scale=1, start_shape=0):
     assert np.ndim(x_gev) == 1
     assert start_scale > 0
     r = ro.r
@@ -21,13 +22,10 @@ def mle_gev(x_gev: np.ndarray, start_loc=0, start_scale=1, start_shape=0):
 
 
 class GevMleFit(object):
-    GEV_SCALE = 'scale'
-    GEV_LOCATION = 'loc'
-    GEV_SHAPE = 'shape'
 
     def __init__(self, x_gev: np.ndarray, start_loc=0, start_scale=1, start_shape=0):
         self.x_gev = x_gev
-        self.mle_params = mle_gev(x_gev, start_loc, start_scale, start_shape)
-        self.shape = self.mle_params[self.GEV_SHAPE]
-        self.scale = self.mle_params[self.GEV_SCALE]
-        self.location = self.mle_params[self.GEV_LOCATION]
+        self.mle_params = gev_mle_fit(x_gev, start_loc, start_scale, start_shape)
+        self.shape = self.mle_params[GevParams.GEV_SHAPE]
+        self.scale = self.mle_params[GevParams.GEV_SCALE]
+        self.loc = self.mle_params[GevParams.GEV_LOC]
diff --git a/extreme_estimator/R_model/gev/gev_parameters.py b/extreme_estimator/R_model/gev/gev_parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..387f8ab7b639430676f9a81a4ea4df6e437aa71b
--- /dev/null
+++ b/extreme_estimator/R_model/gev/gev_parameters.py
@@ -0,0 +1,29 @@
+
+class GevParams(object):
+    GEV_SCALE = 'scale'
+    GEV_LOC = 'loc'
+    GEV_SHAPE = 'shape'
+
+    def __init__(self, loc: float, scale: float, shape: float):
+        self.location = loc
+        self.scale = scale
+        self.shape = shape
+
+    @classmethod
+    def from_dict(cls, params: dict):
+        return cls(**params)
+
+    def to_dict(self) -> dict:
+        return {
+            self.GEV_LOC: self.location,
+            self.GEV_SCALE: self.scale,
+            self.GEV_SHAPE: self.shape,
+        }
+
+    def rgev(self, nb_obs):
+        gev_params = {
+            self.GEV_LOC: loc,
+            self.GEV_SCALE: scale,
+            self.GEV_SHAPE: shape,
+        }
+        return self.r.rgev(nb_obs, **gev_params)
diff --git a/extreme_estimator/R_model/margin_model/abstract_margin_function.py b/extreme_estimator/R_model/margin_model/abstract_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1eb05dd026b55972528f5563b4a9953e5ac5167
--- /dev/null
+++ b/extreme_estimator/R_model/margin_model/abstract_margin_function.py
@@ -0,0 +1,44 @@
+from typing import List
+
+import numpy as np
+
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
+from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
+
+
+class AbstractMarginFunction(object):
+    """
+    It represents any function mapping points from a space S (could be 2D, 3D,...) to R^3 (the 3 parameters of the GEV)
+    """
+
+    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams):
+        self.spatial_coordinates = spatial_coordinates
+        self.default_params = default_params
+
+    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
+        pass
+
+
+class ConstantMarginFunction(AbstractMarginFunction):
+
+    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
+        return self.default_params
+
+
+class LinearMarginFunction(AbstractMarginFunction):
+
+    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, default_params: GevParams,
+                 linear_dims: List[int]):
+        super().__init__(spatial_coordinates, default_params)
+        self.linear_dims = linear_dims
+
+# class LinearShapeMarginFunction(AbstractMarginFunction):
+#     """Linear function """
+#
+#     def __init__(self, coordinates, dimension_index_for_linearity=0):
+#         super().__init__(coordinates)
+#         self.dimension_index_for_linearity = dimension_index_for_linearity
+#         assert dimension_index_for_linearity < np.ndim(self.coordinates)
+#         # Compute
+#
+#     def get_gev_params(self, coordinate):
diff --git a/extreme_estimator/R_model/margin_model/abstract_margin_model.py b/extreme_estimator/R_model/margin_model/abstract_margin_model.py
index 25e59524214fe7f7c7a1b284f19e573cc95a29ce..747ac64fbc92dfdd17b65fe2f9465af760649509 100644
--- a/extreme_estimator/R_model/margin_model/abstract_margin_model.py
+++ b/extreme_estimator/R_model/margin_model/abstract_margin_model.py
@@ -1,82 +1,72 @@
 import numpy as np
-import pandas as pd
 
 from extreme_estimator.R_model.abstract_model import AbstractModel
-from extreme_estimator.R_model.margin_model.gev_mle_fit import GevMleFit, mle_gev
+from extreme_estimator.R_model.margin_model.abstract_margin_function import AbstractMarginFunction
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
+from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
 
 
 class AbstractMarginModel(AbstractModel):
-    GEV_SCALE = GevMleFit.GEV_SCALE
-    GEV_LOCATION = GevMleFit.GEV_LOCATION
-    GEV_SHAPE = GevMleFit.GEV_SHAPE
-    GEV_PARAMETERS = [GEV_LOCATION, GEV_SCALE, GEV_SHAPE]
+    GEV_SCALE = GevParams.GEV_SCALE
+    GEV_LOC = GevParams.GEV_LOC
+    GEV_SHAPE = GevParams.GEV_SHAPE
+    GEV_PARAMETERS = [GEV_LOC, GEV_SCALE, GEV_SHAPE]
 
-    def __init__(self, params_start_fit=None, params_sample=None):
+    def __init__(self, spatial_coordinates: AbstractSpatialCoordinates, params_start_fit=None, params_sample=None):
         super().__init__(params_start_fit, params_sample)
+        self.spatial_coordinates = spatial_coordinates
+        self.margin_function_sample = None  # type: AbstractMarginFunction
+        self.margin_function_start_fit = None # type: AbstractMarginFunction
+        self.load_margin_functions()
 
-    # Define the method to sample/fit a single gev
+    def load_margin_functions(self, margin_function_class: type = None):
+        assert margin_function_class is not None
+        self.margin_function_sample = margin_function_class(spatial_coordinates=self.spatial_coordinates,
+                                                            default_params=GevParams.from_dict(self.params_sample))
+        self.margin_function_start_fit = margin_function_class(spatial_coordinates=self.spatial_coordinates,
+                                                               default_params=GevParams.from_dict(
+                                                                   self.params_start_fit))
 
-    def rgev(self, nb_obs, loc, scale, shape):
-        gev_params = {
-            self.GEV_LOCATION: loc,
-            self.GEV_SCALE: scale,
-            self.GEV_SHAPE: shape,
-        }
-        return self.r.rgev(nb_obs, **gev_params)
+    # Conversion class methods
 
-    def fitgev(self, x_gev, estimator=GevMleFit):
-        mle_params = mle_gev(x_gev=x_gev)
+    @classmethod
+    def convert_maxima(cls, convertion_r_function, maxima: np.ndarray, coordinates: np.ndarray,
+                       margin_function: AbstractMarginFunction):
+        assert len(maxima) == len(coordinates)
+        converted_maxima = []
+        for x, coordinate in zip(maxima, coordinates):
+            gev_params = margin_function.get_gev_params(coordinate)
+            x_gev = convertion_r_function(x, **gev_params.to_dict())
+            converted_maxima.append(x_gev)
+        return np.array(converted_maxima)
 
-    def gev_params_sample(self, coordinate) -> dict:
-        pass
+    @classmethod
+    def gev2frech(cls, maxima_gev: np.ndarray, coordinates: np.ndarray, margin_function: AbstractMarginFunction):
+        return cls.convert_maxima(cls.r.gev2frech, maxima_gev, coordinates, margin_function)
 
-    # Define the method to sample/fit all marginals globally in the child classes
+    @classmethod
+    def frech2gev(cls, maxima_frech: np.ndarray, coordinates: np.ndarray, margin_function: AbstractMarginFunction):
+        return cls.convert_maxima(cls.r.frech2gev, maxima_frech, coordinates, margin_function)
 
-    def fitmargin(self, maxima, coord):
-        df_fit_gev_params = None
-        return df_fit_gev_params
+    # Sampling methods
 
-    def rmargin(self, nb_obs, coord):
-        maxima_gev = None
+    def rmargin_from_maxima_frech(self, maxima_frech: np.ndarray, coordinates: np.ndarray):
+        maxima_gev = self.frech2gev(maxima_frech, coordinates, self.margin_function_sample)
         return maxima_gev
 
-    def frech2gev(self, maxima_frech: np.ndarray, coordinates: np.ndarray):
-        assert len(maxima_frech) == len(coordinates)
+    def rmargin_from_nb_obs(self, nb_obs, coordinates):
         maxima_gev = []
-        for x_frech, coordinate in zip(maxima_frech, coordinates):
-            gev_params = self.gev_params_sample(coordinate)
-            x_gev = self.r.frech2gev(x_frech, **gev_params)
+        for coordinate in coordinates:
+            gev_params = self.margin_function_sample.get_gev_params(coordinate)
+            x_gev = self.r.rgev(nb_obs, **gev_params.to_dict())
             maxima_gev.append(x_gev)
         return np.array(maxima_gev)
 
-    @classmethod
-    def gev2frech(cls, maxima_gev: np.ndarray, df_gev_params: pd.DataFrame):
-        assert len(maxima_gev) == len(df_gev_params)
-        maxima_frech = []
-        for x_gev, (_, s_gev_params) in zip(maxima_gev, df_gev_params.iterrows()):
-            gev_params = dict(s_gev_params)
-            gev2frech_param = {'emp': False}
-            x_frech = cls.r.gev2frech(x_gev, **gev_params, **gev2frech_param)
-            maxima_frech.append(x_frech)
-        return np.array(maxima_frech)
-
-
-class SmoothMarginModel(AbstractMarginModel):
-    pass
-
-
-class ConstantMarginModel(SmoothMarginModel):
-    def __init__(self, params_start_fit=None, params_sample=None):
-        super().__init__(params_start_fit, params_sample)
-        self.default_params_sample = {gev_param: 1.0 for gev_param in self.GEV_PARAMETERS}
-        self.default_params_start_fit = {gev_param: 1.0 for gev_param in self.GEV_PARAMETERS}
-
-    def gev_params_sample(self, coordinate):
-        return self.default_params_sample
-
-    def fitmargin(self, maxima, coord):
-        return pd.DataFrame([pd.Series(self.default_params_start_fit) for _ in maxima])
+    # Fitting methods
 
+    def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
+        pass
 
+    # Define the method to sample/fit a single gev
 
 
diff --git a/extreme_estimator/R_model/margin_model/smooth_margin_model.py b/extreme_estimator/R_model/margin_model/smooth_margin_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1210d6cc672ada49b2ef577ed96579c6106e1bf
--- /dev/null
+++ b/extreme_estimator/R_model/margin_model/smooth_margin_model.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+from extreme_estimator.R_model.margin_model.abstract_margin_function import ConstantMarginFunction, \
+    AbstractMarginFunction
+from extreme_estimator.R_model.margin_model.abstract_margin_model import AbstractMarginModel
+from extreme_estimator.R_model.gev.gev_parameters import GevParams
+from spatio_temporal_dataset.spatial_coordinates.abstract_spatial_coordinates import AbstractSpatialCoordinates
+
+
+class SmoothMarginModel(AbstractMarginModel):
+    pass
+
+
+class ConstantMarginModel(SmoothMarginModel):
+
+    def load_margin_functions(self, margin_function_class: type = None):
+        self.default_params_sample = GevParams(1.0, 1.0, 1.0).to_dict()
+        self.default_params_start_fit = GevParams(1.0, 1.0, 1.0).to_dict()
+        super().load_margin_functions(margin_function_class=ConstantMarginFunction)
+
+    def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
+        return self.margin_function_start_fit
+
+
+class LinearShapeMarginModel(SmoothMarginModel):
+    pass
+
+
+if __name__ == '__main__':
+    pass
diff --git a/extreme_estimator/R_model/max_stable_model/abstract_max_stable_model.py b/extreme_estimator/R_model/max_stable_model/abstract_max_stable_model.py
index ab0c717c5ed2b0abd04154f56258f4d9cce1b3c3..1075d383d824a0bab1a3c427c2031b0f8cbf5c3a 100644
--- a/extreme_estimator/R_model/max_stable_model/abstract_max_stable_model.py
+++ b/extreme_estimator/R_model/max_stable_model/abstract_max_stable_model.py
@@ -13,6 +13,10 @@ class AbstractMaxStableModel(AbstractModel):
         super().__init__(params_start_fit, params_sample)
         self.cov_mod = None
 
+    @property
+    def cov_mod_param(self):
+        return {'cov.mod': self.cov_mod}
+
     def fitmaxstab(self, maxima_frech: np.ndarray, coord: np.ndarray, fit_marge=False):
         assert all([isinstance(arr, np.ndarray) for arr in [maxima_frech, coord]])
         #  Specify the fit params
@@ -33,36 +37,14 @@ class AbstractMaxStableModel(AbstractModel):
         fitted_values = {key: fitted_values.rx2(key)[0] for key in fitted_values.names}
         return fitted_values
 
-    def rmaxstab(self, nb_obs: int, coord: np.ndarray) -> np.ndarray:
+    def rmaxstab(self, nb_obs: int, coordinates: np.ndarray) -> np.ndarray:
         """
         Return an numpy of maxima. With rows being the stations and columns being the years of maxima
         """
         maxima_frech = np.array(
-            self.r.rmaxstab(nb_obs, coord, *list(self.cov_mod_param.values()), **self.params_sample))
+            self.r.rmaxstab(nb_obs, coordinates, *list(self.cov_mod_param.values()), **self.params_sample))
         return np.transpose(maxima_frech)
 
-    @property
-    def cov_mod_param(self):
-        return {'cov.mod': self.cov_mod}
-
-    @property
-    def params_start_fit(self):
-        return self.merge_params(default_params=self.default_params_start_fit, input_params=self.user_params_start_fit)
-
-    @property
-    def params_sample(self):
-        return self.merge_params(default_params=self.default_params_sample, input_params=self.user_params_sample)
-
-    @staticmethod
-    def merge_params(default_params, input_params):
-        assert default_params is not None, 'some default_params need to be specified'
-        merged_params = default_params.copy()
-        if input_params is not None:
-            assert isinstance(default_params, dict) and isinstance(input_params, dict)
-            assert set(input_params.keys()).issubset(set(default_params.keys()))
-            merged_params.update(input_params)
-        return merged_params
-
 
 class CovarianceFunction(Enum):
     whitmat = 0
diff --git a/extreme_estimator/estimator/full_estimator.py b/extreme_estimator/estimator/full_estimator.py
index 3d31d97f58e8b9aa98d627790ae40b772be5a74e..f674fb3fbb74ec3343d3b79ea7319f73657cfbc3 100644
--- a/extreme_estimator/estimator/full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator.py
@@ -24,7 +24,8 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
         self.margin_estimator.fit()
         # Compute the maxima_frech
         maxima_frech = AbstractMarginModel.gev2frech(maxima_gev=self.dataset.maxima_gev,
-                                                     df_gev_params=self.margin_estimator.df_gev_params)
+                                                     coordinates=self.dataset.coordinates,
+                                                     margin_function=self.margin_estimator.margin_function_fitted)
         # Update maxima frech field through the dataset object
         self.dataset.maxima_frech = maxima_frech
         # Estimate the max stable parameters
diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py
index 0bcf900652207361dd01a46e120d1b85114bb1a6..fd9c3e95d566581e8416d4f23d7e8d72a0e1b47c 100644
--- a/extreme_estimator/estimator/margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator.py
@@ -8,6 +8,12 @@ class AbstractMarginEstimator(AbstractEstimator):
     def __init__(self, dataset: AbstractDataset):
         super().__init__(dataset)
         assert self.dataset.maxima_gev is not None
+        self._margin_function_fitted = None
+
+    @property
+    def margin_function_fitted(self):
+        assert self._margin_function_fitted is not None, 'Error: estimator has not been not fitted yet'
+        return self._margin_function_fitted
 
 
 class PointWiseMarginEstimator(AbstractMarginEstimator):
@@ -19,9 +25,9 @@ class SmoothMarginEstimator(AbstractMarginEstimator):
 
     def __init__(self, dataset: AbstractDataset, margin_model: AbstractMarginModel):
         super().__init__(dataset)
+        assert isinstance(margin_model, AbstractMarginModel)
         self.margin_model = margin_model
-        self.df_gev_params = None
 
     def _fit(self):
-        self.df_gev_params = self.margin_model.fitmargin(maxima=self.dataset.maxima_gev,
-                                                         coord=self.dataset.coordinates)
+        self._margin_function_fitted = self.margin_model.fitmargin_from_maxima_gev(maxima_gev=self.dataset.maxima_gev,
+                                                                                   coordinates=self.dataset.coordinates)
diff --git a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py b/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
index 9c4f138706eb2037aafb45ba9179958b902aff58..a24f6cdf8be57a77f47e7f3f956bef6ee8276b52 100644
--- a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
+++ b/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
@@ -19,7 +19,7 @@ class MarginAnnualMaxima(AnnualMaxima):
     @classmethod
     def from_sampling(cls, nb_obs: int, spatial_coordinates: AbstractSpatialCoordinates,
                       margin_model: AbstractMarginModel):
-        maxima_gev = margin_model.rmargin(nb_obs=nb_obs, coord=spatial_coordinates.coordinates)
+        maxima_gev = margin_model.rmargin_from_nb_obs(nb_obs=nb_obs, coordinates=spatial_coordinates.coordinates)
         df_maxima_gev = pd.DataFrame(data=maxima_gev, index=spatial_coordinates.index)
         return cls(df_maxima_gev=df_maxima_gev)
 
@@ -29,7 +29,7 @@ class MaxStableAnnualMaxima(AbstractTemporalObservations):
     @classmethod
     def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
                       spatial_coordinates: AbstractSpatialCoordinates):
-        maxima_frech = max_stable_model.rmaxstab(nb_obs=nb_obs, coord=spatial_coordinates.coordinates)
+        maxima_frech = max_stable_model.rmaxstab(nb_obs=nb_obs, coordinates=spatial_coordinates.coordinates)
         df_maxima_frech = pd.DataFrame(data=maxima_frech, index=spatial_coordinates.index)
         return cls(df_maxima_frech=df_maxima_frech)
 
@@ -42,6 +42,7 @@ class FullAnnualMaxima(MaxStableAnnualMaxima):
                              margin_model: AbstractMarginModel):
         max_stable_annual_maxima = super().from_sampling(nb_obs, max_stable_model, spatial_coordinates)
         #  Compute df_maxima_gev from df_maxima_frech
-        maxima_gev = margin_model.frech2gev(max_stable_annual_maxima.maxima_frech, spatial_coordinates.coordinates)
+        maxima_gev = margin_model.rmargin_from_maxima_frech(maxima_frech=max_stable_annual_maxima.maxima_frech,
+                                                            coordinates=spatial_coordinates.coordinates)
         max_stable_annual_maxima.df_maxima_gev = pd.DataFrame(data=maxima_gev, index=spatial_coordinates.index)
         return max_stable_annual_maxima
diff --git a/test/test_extreme_estimator/test_R_model/__init__.py b/test/test_extreme_estimator/test_R_model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test/test_extreme_estimator/test_R_model/test_gev_mle_fit.py b/test/test_extreme_estimator/test_R_model/test_gev_mle_fit.py
new file mode 100644
index 0000000000000000000000000000000000000000..60da9accb22df9d30c2336167bca949f34046d44
--- /dev/null
+++ b/test/test_extreme_estimator/test_R_model/test_gev_mle_fit.py
@@ -0,0 +1,33 @@
+import unittest
+
+import numpy as np
+
+from extreme_estimator.R_model.gev.gev_mle_fit import GevMleFit
+from extreme_estimator.R_model.utils import get_loaded_r
+
+
+class TestGevMleFit(unittest.TestCase):
+
+    def test_unitary_gev_mle_fit(self):
+        r = get_loaded_r()
+        r("""
+        set.seed(42)
+        N <- 50
+        loc = 0; scale = 1; shape <- 1
+        x_gev <- rgev(N, loc = loc, scale = scale, shape = shape)
+        start_loc = 0; start_scale = 1; start_shape = 1
+        """)
+        # Get the MLE estimator
+        estimator = GevMleFit(x_gev=np.array(r['x_gev']),
+                              start_loc=np.float(r['start_loc'][0]),
+                              start_scale=np.float(r['start_scale'][0]),
+                              start_shape=np.float(r['start_shape'][0]))
+        # Compare the MLE estimated parameters to the reference
+        mle_params_estimated = estimator.mle_params
+        mle_params_ref = {'loc': 0.0219, 'scale': 1.0347, 'shape': 0.8290}
+        for key in mle_params_ref.keys():
+            self.assertAlmostEqual(mle_params_ref[key], mle_params_estimated[key], places=3)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_extreme_estimator/test_R_model/test_margin_model.py b/test/test_extreme_estimator/test_R_model/test_margin_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3ccb84bf014c0f587d2074d7cad9bf2771ac73
--- /dev/null
+++ b/test/test_extreme_estimator/test_R_model/test_margin_model.py
@@ -0,0 +1,21 @@
+import unittest
+
+from extreme_estimator.R_model.margin_model.smooth_margin_model import ConstantMarginModel
+from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
+
+
+class TestMarginModel(unittest.TestCase):
+    DISPLAY = True
+    MARGIN_TYPES = [ConstantMarginModel]
+
+    # def test_visualization(self):
+    #     coord_2D = CircleCoordinatesRadius1.from_nb_points(nb_points=50)
+    #     if self.DISPLAY:
+    #         coord_2D.visualization_2D()
+    #     for margin_class in self.MARGIN_TYPES:
+    #         margin_model = margin_class()
+    #         margin_model.visualize(coordinates=coord_2D.coordinates)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_extreme_estimator/test_estimator/__init__.py b/test/test_extreme_estimator/test_estimator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test/test_extreme_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
similarity index 79%
rename from test/test_extreme_estimator/test_full_estimators.py
rename to test/test_extreme_estimator/test_estimator/test_full_estimators.py
index 2a9aa6b32f020567f7e360ae9adb3b452a3c09b9..5e03f8fd0d6a070d239cba3705fab66c00d271ba 100644
--- a/test/test_extreme_estimator/test_full_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
@@ -4,8 +4,8 @@ from itertools import product
 from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
 from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
-from test.test_extreme_estimator.test_margin_estimators import TestMarginEstimators
-from test.test_extreme_estimator.test_max_stable_estimators import TestMaxStableEstimators
+from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestMarginEstimators
+from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators
 
 
 class TestFullEstimators(unittest.TestCase):
@@ -14,14 +14,14 @@ class TestFullEstimators(unittest.TestCase):
 
     def setUp(self):
         super().setUp()
-        self.spatial_coord = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
+        self.spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
         self.max_stable_models = TestMaxStableEstimators.load_max_stable_models()
-        self.margin_models = TestMarginEstimators.load_margin_models()
+        self.margin_models = TestMarginEstimators.load_margin_models(spatial_coordinates=self.spatial_coordinates)
 
     def test_full_estimators(self):
         for margin_model, max_stable_model in product(self.margin_models, self.max_stable_models):
             dataset = FullSimulatedDataset.from_double_sampling(nb_obs=10, margin_model=margin_model,
-                                                                spatial_coordinates=self.spatial_coord,
+                                                                spatial_coordinates=self.spatial_coordinates,
                                                                 max_stable_model=max_stable_model)
 
             for estimator_class in self.FULL_ESTIMATORS:
diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
new file mode 100644
index 0000000000000000000000000000000000000000..322fd3094fa690574690d029098bbdc5d54df9fb
--- /dev/null
+++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
@@ -0,0 +1,40 @@
+import unittest
+
+from extreme_estimator.R_model.margin_model.abstract_margin_model import AbstractMarginModel
+from extreme_estimator.R_model.margin_model.smooth_margin_model import ConstantMarginModel
+from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
+from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
+from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
+
+
+class TestMarginEstimators(unittest.TestCase):
+    DISPLAY = False
+    MARGIN_TYPES = [ConstantMarginModel]
+    MARGIN_ESTIMATORS = [SmoothMarginEstimator]
+
+    def setUp(self):
+        super().setUp()
+        self.spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
+        self.margin_models = self.load_margin_models(spatial_coordinates=self.spatial_coordinates)
+
+    @classmethod
+    def load_margin_models(cls, spatial_coordinates):
+        return [margin_class(spatial_coordinates=spatial_coordinates) for margin_class in cls.MARGIN_TYPES]
+
+    def test_dependency_estimators(self):
+        for margin_model in self.margin_models:
+            dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model,
+                                                  spatial_coordinates=self.spatial_coordinates)
+
+            for estimator_class in self.MARGIN_ESTIMATORS:
+                estimator = estimator_class(dataset=dataset, margin_model=margin_model)
+                estimator.fit()
+                if self.DISPLAY:
+                    print(type(margin_model))
+                    print(dataset.df_dataset.head())
+                    print(estimator.additional_information)
+            self.assertTrue(True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_extreme_estimator/test_max_stable_estimators.py b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
similarity index 100%
rename from test/test_extreme_estimator/test_max_stable_estimators.py
rename to test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
diff --git a/test/test_extreme_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_margin_estimators.py
deleted file mode 100644
index 2b896d91ff933401d27f1a50c861b77305eadc8d..0000000000000000000000000000000000000000
--- a/test/test_extreme_estimator/test_margin_estimators.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import unittest
-
-import numpy as np
-
-from extreme_estimator.R_model.margin_model.abstract_margin_model import ConstantMarginModel
-from extreme_estimator.R_model.margin_model.gev_mle_fit import GevMleFit
-from extreme_estimator.R_model.utils import get_loaded_r
-from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
-from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
-from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
-
-
-class TestMarginEstimators(unittest.TestCase):
-    DISPLAY = False
-    MARGIN_TYPES = [ConstantMarginModel]
-    MARGIN_ESTIMATORS = [SmoothMarginEstimator]
-
-    def test_unitary_mle_gev_fit(self):
-        r = get_loaded_r()
-        r("""
-        set.seed(42)
-        N <- 50
-        loc = 0; scale = 1; shape <- 1
-        x_gev <- rgev(N, loc = loc, scale = scale, shape = shape)
-        start_loc = 0; start_scale = 1; start_shape = 1
-        """)
-        # Get the MLE estimator
-        estimator = GevMleFit(x_gev=np.array(r['x_gev']),
-                              start_loc=np.float(r['start_loc'][0]),
-                              start_scale=np.float(r['start_scale'][0]),
-                              start_shape=np.float(r['start_shape'][0]))
-        # Compare the MLE estimated parameters to the reference
-        mle_params_estimated = estimator.mle_params
-        mle_params_ref = {'loc': 0.0219, 'scale': 1.0347, 'shape': 0.8290}
-        for key in mle_params_ref.keys():
-            self.assertAlmostEqual(mle_params_ref[key], mle_params_estimated[key], places=3)
-
-    def setUp(self):
-        super().setUp()
-        self.spatial_coord = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
-        self.margin_models = self.load_margin_models()
-
-    @classmethod
-    def load_margin_models(cls):
-        return [margin_class() for margin_class in cls.MARGIN_TYPES]
-
-    def test_dependency_estimators(self):
-        for margin_model in self.margin_models:
-            dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model,
-                                                  spatial_coordinates=self.spatial_coord)
-
-            for estimator_class in self.MARGIN_ESTIMATORS:
-                estimator = estimator_class(dataset=dataset, margin_model=margin_model)
-                estimator.fit()
-                if self.DISPLAY:
-                    print(type(margin_model))
-                    print(dataset.df_dataset.head())
-                    print(estimator.additional_information)
-            self.assertTrue(True)
-
-
-if __name__ == '__main__':
-    unittest.main()