diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_function.py b/test/test_extreme_estimator/test_extreme_models/test_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..567221095c85c444f3c0bbeb1a106dcfd2881767
--- /dev/null
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_function.py
@@ -0,0 +1,37 @@
+import unittest
+
+import numpy as np
+
+from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearAllParametersAllDimsMarginModel
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
+from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
+from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
+from test.test_utils import load_test_spatiotemporal_coordinates
+
+
+class MarginFunction(unittest.TestCase):
+    nb_points = 2
+    margin_function_class = LinearMarginFunction
+    margin_model_class = LinearAllParametersAllDimsMarginModel
+
+    def test_grid_2D_orientation(self):
+        # Assert that the grid correspond to what we expect in a simple case
+        margin_model = self.margin_model_class(LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points))
+        AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
+        grid = margin_model.margin_function_sample.grid_2D()['loc']
+        true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
+        self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
+
+    def test_coef_dict(self):
+        coordinates = load_test_spatiotemporal_coordinates(self.nb_points, self.nb_points)[0]
+        margin_model = self.margin_model_class(coordinates)
+        # Test to check loading of margin function from coef dict
+        coef_dict = {'locCoeff1': 0, 'locCoeff2': 1, 'scaleCoeff1': 0,
+                     'scaleCoeff2': 1, 'shapeCoeff1': 0,
+                     'shapeCoeff2': 1,
+                     'tempCoeffLoc1': 1, 'tempCoeffScale1': 1,
+                     'tempCoeffShape1': 1}
+        self.margin_function_class.from_coef_dict(coordinates,
+                                                  margin_model.margin_function_sample.gev_param_name_to_dims,
+                                                  coef_dict)
diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
index a7e613cdc7792c4dae2358db99984d99400e6009..f317c22ab602c229f96f4c3ca9d0d43121a90896 100644
--- a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
@@ -29,26 +29,11 @@ class TestVisualizationLinearMarginModel(unittest.TestCase):
     def test_example_visualization_2D_spatial(self):
         spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
         self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
-        # Assert that the grid correspond to what we expect in a simple case
-        AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
-        grid = self.margin_model.margin_function_sample.grid_2D()['loc']
-        true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
-        self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
 
     def test_example_visualization_2D_spatio_temporal(self):
         self.nb_steps = 2
         coordinates = load_test_spatiotemporal_coordinates(nb_steps=self.nb_steps, nb_points=self.nb_points)[1]
         self.margin_model = self.margin_model_class(coordinates)
-        # Test to check loading of margin function from coef dict
-        # coef_dict = {'locCoeff1': 0, 'locCoeff2': 1, 'scaleCoeff1': 0,
-        #              'scaleCoeff2': 1, 'shapeCoeff1': 0,
-        #              'shapeCoeff2': 1,
-        #              'tempCoeffLoc1': 1, 'tempCoeffScale1': 1,
-        #              'tempCoeffShape1': 1}
-        # margin_function = LinearMarginFunction.from_coef_dict(coordinates,
-        #                                                       self.margin_model.margin_function_sample.gev_param_name_to_linear_dims,
-        #                                                       coef_dict)
-        # self.margin_model.margin_function_sample = margin_function
 
 
 class TestVisualizationSplineMarginModel(unittest.TestCase):