diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index a88c4beb20ec5bd06ea82e5bd26f76a436308775..3ae470f8a597fc6868d57d890e027c61edd42d5e 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -9,13 +9,15 @@ from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, imshow_shifted
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.slicer.split import Split
+from utils import cached_property
 
 
 class AbstractMarginFunction(object):
     """ Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)"""
 
-    def __init__(self, coordinates: AbstractCoordinates):
+    def __init__(self, coordinates: AbstractCoordinates, resolution=100):
         self.coordinates = coordinates
+        self.resolution = resolution
 
         # Visualization parameters
         self.visualization_axes = None
@@ -29,6 +31,14 @@ class AbstractMarginFunction(object):
         self._grid_2D = None
         self._grid_1D = None
 
+    @property
+    def x(self):
+        return self.coordinates.x_coordinates
+
+    @property
+    def y(self):
+        return self.coordinates.y_coordinates
+
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Main method that maps each coordinate to its GEV parameters"""
         pass
@@ -131,14 +141,11 @@ class AbstractMarginFunction(object):
     # Visualization 2D
 
     def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
-        x = self.coordinates.x_coordinates
-        y = self.coordinates.y_coordinates
-        grid = self.grid_2D(x, y)
         if ax is None:
             ax = plt.gca()
 
         # Special display
-        imshow_shifted(ax, gev_param_name, grid[gev_param_name], x, y)
+        imshow_shifted(ax, gev_param_name, self.grid_2D[gev_param_name], self.x, self.y)
 
         # X axis
         ax.set_xlabel('coordinate X')
@@ -152,13 +159,15 @@ class AbstractMarginFunction(object):
         if show:
             plt.show()
 
-    def grid_2D(self, x, y):
-        resolution = 100
+    @cached_property
+    def grid_2D(self):
+        x = self.x
+        y = self.y
         grid = []
-        for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
-            for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)):
+        for i, xi in enumerate(np.linspace(x.min(), x.max(), self.resolution)):
+            for j, yj in enumerate(np.linspace(y.min(), y.max(), self.resolution)):
                 grid.append(self.get_gev_params(np.array([xi, yj])).summary_dict)
-        grid = {value_name: np.array([g[value_name] for g in grid]).reshape([resolution, resolution])
+        grid = {value_name: np.array([g[value_name] for g in grid]).reshape([self.resolution, self.resolution])
                 for value_name in GevParams.SUMMARY_NAMES}
         return grid
 
diff --git a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
index f176a58587b9fb78a56e4a66bb0e310fb5294449..128954fa47485c779a0a4c730fb28ddee43836ff 100644
--- a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
+++ b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
@@ -53,5 +53,6 @@ def imshow_shifted(ax, gev_param_name, values, x, y):
         value = np.min(values)
         # The right blue corner will be blue (but most of the time, another display will be on top)
         masked_array[-1, -1] = value - epsilon
-    ax.imshow(masked_array, extent=(x.min(), x.max(), y.min(), y.max()), cmap=shifted_cmap)
+    # IMPORTANT: Origin for all the plots is at the bottom left corner
+    ax.imshow(masked_array, extent=(x.min(), x.max(), y.min(), y.max()), cmap=shifted_cmap, origin='lower')
 
diff --git a/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py b/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c845c9e5e656ee56b4779be95a7172c154480c4
--- /dev/null
+++ b/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py
@@ -0,0 +1,21 @@
+import numpy as np
+import pandas as pd
+from rpy2.robjects import r
+
+from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
+    AbstractSpatialCoordinates
+from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates
+
+
+class AbstractBiDimensionalSpatialCoordinates(AbstractSpatialCoordinates):
+    pass
+
+
+class LinSpaceSpatial2DCoordinates(AbstractBiDimensionalSpatialCoordinates):
+
+    @classmethod
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, start=-1.0, end=1.0):
+        axis_coordinates = np.linspace(start, end, nb_points)
+        df = pd.DataFrame.from_dict({cls.COORDINATE_X: axis_coordinates,
+                                     cls.COORDINATE_Y: axis_coordinates})
+        return cls.from_df(df, train_split_ratio)
diff --git a/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py b/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py
index f2716fbb8107a6ea48a2eca62625c79235238a7c..cb48486ab6d9d98b990e2460e19912abed9c5115 100644
--- a/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py
@@ -12,17 +12,17 @@ from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_co
 class CircleSpatialCoordinates(AbstractSpatialCoordinates):
 
     @classmethod
-    def df_spatial(cls, nb_points, max_radius=1.0):
+    def df_spatial(cls, nb_points, max_radius=1.0, random=True):
         # Sample uniformly inside the circle
-        angles = np.array(r.runif(nb_points, max=2 * math.pi))
-        radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius)))
+        angles = np.array(r.runif(nb_points, max=2 * math.pi)) if random else np.linspace(0.0, 2 * math.pi, nb_points+1)[:-1]
+        radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius))) if random else np.ones(nb_points) * max_radius
         df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles),
                                      cls.COORDINATE_Y: radius * np.sin(angles)})
         return df
 
     @classmethod
-    def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0):
-        return cls.from_df(cls.df_spatial(nb_points, max_radius), train_split_ratio)
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0, random=True):
+        return cls.from_df(cls.df_spatial(nb_points, max_radius, random), train_split_ratio)
 
     def visualization_2D(self):
         radius = 1.0
@@ -36,5 +36,5 @@ class CircleSpatialCoordinates(AbstractSpatialCoordinates):
 class CircleSpatialCoordinatesRadius2(CircleSpatialCoordinates):
 
     @classmethod
-    def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0):
-        return 2 * super().from_nb_points(nb_points, train_split_ratio, max_radius)
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0, random=True):
+        return 2 * super().from_nb_points(nb_points, train_split_ratio, max_radius, random)
diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
index 1be2e00e012a4c7dc3c6ccd854648c79eb3a3546..5f60a476fc9573134678682d3928774f1355cc39 100644
--- a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
@@ -1,16 +1,18 @@
+import numpy as np
 import unittest
 
 from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearShapeDim1MarginModel, \
     LinearAllParametersAllDimsMarginModel
+from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import \
     CircleSpatialCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates
 from test.test_utils import load_test_spatiotemporal_coordinates
 
 
-class VisualizationMarginModel(unittest.TestCase):
+class TestVisualizationMarginModel(unittest.TestCase):
     DISPLAY = False
     nb_points = 2
     margin_model_class = [LinearShapeDim1MarginModel, LinearAllParametersAllDimsMarginModel][-1]
@@ -24,8 +26,13 @@ class VisualizationMarginModel(unittest.TestCase):
         self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
 
     def test_example_visualization_2D_spatial(self):
-        spatial_coordinates = CircleSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
+        spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
         self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
+        # Assert that the grid correspond to what we expect in a simple case
+        self.margin_model.margin_function_sample.resolution = 2
+        grid = self.margin_model.margin_function_sample.grid_2D['loc']
+        true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
+        self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
 
     # def test_example_visualization_2D_spatio_temporal(self):
     #     self.nb_steps = 2
@@ -49,5 +56,6 @@ class VisualizationMarginModel(unittest.TestCase):
 
 if __name__ == '__main__':
     unittest.main()
-    # v = VisualizationMarginModel()
-    # v.test_example_visualization_2D_spatio_temporal()
+    # v = TestVisualizationMarginModel()
+    # v.test_example_visualization_2D_spatial()
+    # v.tearDown()