diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py
index be948c6cf9c2bee4a638244bd57e8424519006b9..f6f208fcc883fbbe85f1bcc4d98baa4da6adb04e 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py
@@ -77,7 +77,7 @@ def normal_visualization():
             study_visualizer = StudyVisualizer(study, save_to_file=save_to_file)
             # study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
             # study_visualizer.visualize_annual_mean_values()
-            study_visualizer.visualize_linear_margin_fit(only_first_max_stable=True)
+            study_visualizer.visualize_linear_margin_fit(only_first_max_stable=None)
 
 
 def complete_analysis(only_first_one=False):
diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
index 46a37e326865596dce051753a13fcb3c48c1ac2f..d138d7b17991ed8e59895098ff0cc27ce220a6f5 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
@@ -9,10 +9,13 @@ import pandas as pd
 import seaborn as sns
 
 from experiment.meteo_france_SCM_study.abstract_study import AbstractStudy
+from experiment.meteo_france_SCM_study.visualization.utils import create_adjusted_axes
 from experiment.utils import average_smoothing_with_sliding_window
 from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
     FullEstimatorInASingleStepWithSmoothMargin
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction
 from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearAllParametersAllDimsMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import CovarianceFunction
@@ -29,6 +32,7 @@ from utils import get_display_name_from_object_type, VERSION_TIME, float_to_str_
 BLOCK_MAXIMA_DISPLAY_NAME = 'block maxima '
 
 
+
 class StudyVisualizer(object):
 
     def __init__(self, study: AbstractStudy, show=True, save_to_file=False, only_one_graph=False, only_first_row=False,
@@ -202,7 +206,8 @@ class StudyVisualizer(object):
         # Counting the sum of 3-consecutive days of snowfall does not have any physical meaning,
         # as we are counting twice some days
         color_mean = 'g'
-        tuples_x_y = [(year, np.mean(data[:, massif_id])) for year, data in self.study.year_to_daily_time_serie_array.items()]
+        tuples_x_y = [(year, np.mean(data[:, massif_id])) for year, data in
+                      self.study.year_to_daily_time_serie_array.items()]
         x, y = list(zip(*tuples_x_y))
         x, y = average_smoothing_with_sliding_window(x, y, window_size_for_smoothing=self.window_size_for_smoothing)
         ax.plot(x, y, color=color_mean)
@@ -213,13 +218,14 @@ class StudyVisualizer(object):
     def visualize_brown_resnick_fit(self):
         pass
 
-
     def visualize_linear_margin_fit(self, only_first_max_stable=False):
-        default_covariance_function = CovarianceFunction.cauchy
+        default_covariance_function = CovarianceFunction.powexp
         plot_name = 'Full Likelihood with Linear marginals and max stable dependency structure'
         plot_name += '\n(with {} covariance structure when a covariance is needed)'.format(
             str(default_covariance_function).split('.')[-1])
         self.plot_name = plot_name
+
+        # Load max stable models
         max_stable_models = load_test_max_stable_models(default_covariance_function=default_covariance_function)
         if only_first_max_stable:
             # Keep only the BrownResnick model
@@ -227,9 +233,19 @@ class StudyVisualizer(object):
         if only_first_max_stable is None:
             max_stable_models = []
 
-
-        fig, axes = plt.subplots(len(max_stable_models) + 2, len(GevParams.SUMMARY_NAMES), figsize=self.figsize)
-        fig.subplots_adjust(hspace=self.subplot_space, wspace=self.subplot_space)
+        # Load axes (either a 2D or 3D array depending on self.coordinates)
+        nb_max_stable_models = len(max_stable_models) + 2
+        nb_summary_names = GevParams.NB_SUMMARY_NAMES
+        if self.coordinates.has_temporal_coordinates:
+            nb_times_steps = AbstractMarginFunction.VISUALIZATION_TEMPORAL_STEPS
+            # Create one plot for each max stable models
+            axes = []
+            for _ in range(nb_max_stable_models):
+                axes.append(create_adjusted_axes(nb_rows=nb_summary_names, nb_columns=nb_times_steps,
+                                                 figsize=self.figsize, subplot_space=self.subplot_space))
+        else:
+            axes = create_adjusted_axes(nb_rows=nb_max_stable_models, nb_columns=nb_summary_names,
+                                        figsize=self.figsize, subplot_space=self.subplot_space)
         margin_class = LinearAllParametersAllDimsMarginModel
 
         # Plot the margin fit independently
diff --git a/experiment/meteo_france_SCM_study/visualization/utils.py b/experiment/meteo_france_SCM_study/visualization/utils.py
index 970c2aa8059f5b303bc6e32532e71af8fe7d4d82..19101351bf06e4e4e9f35329e46fc898ec4451bf 100644
--- a/experiment/meteo_france_SCM_study/visualization/utils.py
+++ b/experiment/meteo_france_SCM_study/visualization/utils.py
@@ -35,5 +35,11 @@ def get_km_formatter():
     return tkr.FuncFormatter(numfmt)  # create your custom formatter function
 
 
+def create_adjusted_axes(nb_rows, nb_columns, figsize=(16,10), subplot_space=0.5):
+    fig, axes = plt.subplots(nb_rows, nb_columns, figsize=figsize)
+    fig.subplots_adjust(hspace=subplot_space, wspace=subplot_space)
+    return axes
+
+
 if __name__ == '__main__':
     example_plot_df()
diff --git a/experiment/simulation/abstract_simulation.py b/experiment/simulation/abstract_simulation.py
index 49973ae56631d93a1335d438dd9a574c6ed20220..22f6ee32c0bd8fbec2320441d06c5cd64529b856 100644
--- a/experiment/simulation/abstract_simulation.py
+++ b/experiment/simulation/abstract_simulation.py
@@ -127,7 +127,7 @@ class AbstractSimulation(object):
 
     @staticmethod
     def load_fig_and_axes():
-        fig, axes = plt.subplots(len(GevParams.SUMMARY_NAMES), 2)
+        fig, axes = plt.subplots(GevParams.NB_SUMMARY_NAMES, 2)
         fig.subplots_adjust(hspace=0.4, wspace=0.4)
         return fig, axes
 
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 81e42ecaf7b3883d0cd7d88094795cbf281351a8..2428f5fdb63189df3c7d4adce5ba09bfc47012d4 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -5,6 +5,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 
+from experiment.meteo_france_SCM_study.visualization.utils import create_adjusted_axes
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, imshow_shifted
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
@@ -17,6 +18,7 @@ class AbstractMarginFunction(object):
     AbstractMarginFunction maps points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)
     """
     VISUALIZATION_RESOLUTION = 100
+    VISUALIZATION_TEMPORAL_STEPS = 2
 
     def __init__(self, coordinates: AbstractCoordinates):
         self.coordinates = coordinates
@@ -30,9 +32,11 @@ class AbstractMarginFunction(object):
         self.color = 'skyblue'
         self.filter = None
         self.linewidth = 1
+        self.subplot_space = 1.0
 
-        self._grid_2D = None
+        self.temporal_step_to_grid_2D = {}
         self._grid_1D = None
+        self.title = None
 
         # Visualization limits
         self._visualization_x_limits = None
@@ -75,28 +79,37 @@ class AbstractMarginFunction(object):
         self.color = color
 
     def visualize_function(self, axes=None, show=True, dot_display=False, title=None):
+        self.title = title
         self.datapoint_display = dot_display
         if axes is None:
-            fig, axes = plt.subplots(1, len(GevParams.SUMMARY_NAMES))
-            fig.subplots_adjust(hspace=1.0, wspace=1.0)
+            if self.coordinates.has_temporal_coordinates:
+                axes = create_adjusted_axes(GevParams.NB_SUMMARY_NAMES, self.VISUALIZATION_TEMPORAL_STEPS)
+            else:
+                axes = create_adjusted_axes(1, GevParams.NB_SUMMARY_NAMES, subplot_space=self.subplot_space)
         self.visualization_axes = axes
-        for i, gev_value_name in enumerate(GevParams.SUMMARY_NAMES):
-            ax = axes[i]
+        assert len(axes) == GevParams.NB_SUMMARY_NAMES
+        for ax, gev_value_name in zip(axes, GevParams.SUMMARY_NAMES):
             self.visualize_single_param(gev_value_name, ax, show=False)
-            title_str = gev_value_name if title is None else title
-            ax.set_title(title_str)
+            self.set_title(ax, gev_value_name)
         if show:
             plt.show()
         return axes
 
+    def set_title(self, ax, gev_value_name):
+        if hasattr(ax, 'set_title'):
+            title_str = gev_value_name if self.title is None else self.title
+            ax.set_title(title_str)
+
     def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True):
         assert gev_value_name in GevParams.SUMMARY_NAMES
-        if self.coordinates.nb_coordinates_spatial == 1:
+        nb_coordinates_spatial = self.coordinates.nb_coordinates_spatial
+        has_temporal_coordinates = self.coordinates.has_temporal_coordinates
+        if nb_coordinates_spatial == 1 and not has_temporal_coordinates:
             self.visualize_1D(gev_value_name, ax, show)
-        elif self.coordinates.nb_coordinates_spatial == 2:
+        elif nb_coordinates_spatial == 2 and not has_temporal_coordinates:
             self.visualize_2D(gev_value_name, ax, show)
-        elif self.coordinates.nb_coordinates_spatial == 3:
-            self.visualize_3D(gev_value_name, ax, show)
+        elif nb_coordinates_spatial == 2 and has_temporal_coordinates:
+            self.visualize_2D_spatial_1D_temporal(gev_value_name, ax, show)
         else:
             raise NotImplementedError('Other visualization not yet implemented')
 
@@ -148,12 +161,12 @@ class AbstractMarginFunction(object):
 
     # Visualization 2D
 
-    def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
+    def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True, temporal_step=None):
         if ax is None:
             ax = plt.gca()
 
         # Special display
-        imshow_shifted(ax, gev_param_name, self.grid_2D[gev_param_name], self.visualization_extend, self.mask_2D)
+        imshow_shifted(ax, gev_param_name, self.grid_2D(temporal_step)[gev_param_name], self.visualization_extend, self.mask_2D)
 
         # X axis
         ax.set_xlabel('coordinate X')
@@ -185,19 +198,44 @@ class AbstractMarginFunction(object):
     def visualization_extend(self):
         return self.visualization_x_limits + self.visualization_y_limits
 
-    @cached_property
-    def grid_2D(self):
+    def grid_2D(self, temporal_step=None):
+        # Cache the results
+        if temporal_step not in self.temporal_step_to_grid_2D:
+            self.temporal_step_to_grid_2D[temporal_step] = self._grid_2D(temporal_step)
+        return self.temporal_step_to_grid_2D[temporal_step]
+
+    def _grid_2D(self, temporal_step=None):
         grid = []
         for xi in np.linspace(*self.visualization_x_limits, self.VISUALIZATION_RESOLUTION):
             for yj in np.linspace(*self.visualization_y_limits, self.VISUALIZATION_RESOLUTION):
-                grid.append(self.get_gev_params(np.array([xi, yj])).summary_dict)
-        grid = {value_name: np.array([g[value_name] for g in grid]).reshape([self.VISUALIZATION_RESOLUTION, self.VISUALIZATION_RESOLUTION])
+                # Build spatio temporal coordinate
+                coordinate = [xi, yj]
+                if temporal_step is not None:
+                    coordinate.append(temporal_step)
+                grid.append(self.get_gev_params(np.array(coordinate)).summary_dict)
+        grid = {value_name: np.array([g[value_name] for g in grid]).reshape(
+            [self.VISUALIZATION_RESOLUTION, self.VISUALIZATION_RESOLUTION])
                 for value_name in GevParams.SUMMARY_NAMES}
         return grid
 
     # Visualization 3D
 
-    def visualize_3D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
-        # Make the first/the last time step 2D visualization side by side
-        # self.visualize_2D(gev_param_name=gev_param_name, ax=ax, show=show)
-        pass
+    def visualize_2D_spatial_1D_temporal(self, gev_param_name=GevParams.LOC, axes=None, show=True,
+                                         add_future_temporal_steps=False):
+        if axes is None:
+            axes = create_adjusted_axes(self.VISUALIZATION_TEMPORAL_STEPS, 1)
+        assert len(axes) == self.VISUALIZATION_TEMPORAL_STEPS
+
+        # Build temporal_steps a list of time steps
+        future_temporal_steps = [10, 100] if add_future_temporal_steps else []
+        nb_past_temporal_step = self.VISUALIZATION_TEMPORAL_STEPS - len(future_temporal_steps)
+        start, stop = self.coordinates.df_temporal_range()
+        temporal_steps = list(np.linspace(start, stop, num=nb_past_temporal_step)) + future_temporal_steps
+        assert len(temporal_steps) == self.VISUALIZATION_TEMPORAL_STEPS
+
+        for ax, temporal_step in zip(axes, temporal_steps):
+            self.visualize_2D(gev_param_name, ax, show=False, temporal_step=temporal_step)
+            self.set_title(ax, gev_param_name)
+
+        if show:
+            plt.show()
diff --git a/extreme_estimator/margin_fits/gev/gev_params.py b/extreme_estimator/margin_fits/gev/gev_params.py
index 17081517fd2aae9ebef8a28e11cce8613d8d689e..56a044471381c2aa0dc1e982c93004f894cf935d 100644
--- a/extreme_estimator/margin_fits/gev/gev_params.py
+++ b/extreme_estimator/margin_fits/gev/gev_params.py
@@ -9,6 +9,7 @@ class GevParams(ExtremeParams):
     PARAM_NAMES = [ExtremeParams.LOC, ExtremeParams.SCALE, ExtremeParams.SHAPE]
     # Summary
     SUMMARY_NAMES = PARAM_NAMES + ExtremeParams.QUANTILE_NAMES
+    NB_SUMMARY_NAMES = len(SUMMARY_NAMES)
 
     def __init__(self, loc: float, scale: float, shape: float, block_size: int = None):
         super().__init__(loc, scale, shape)
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 4eafbd1c013519e2f0b5b18091ccddfb177df4af..9c43045068e599e24a03af4a4f25818d93d95aee 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -1,5 +1,5 @@
 import os.path as op
-from typing import List
+from typing import List, Tuple
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -178,12 +178,20 @@ class AbstractCoordinates(object):
     def nb_coordinates_temporal(self) -> int:
         return len(self.coordinates_temporal_names)
 
+    @property
+    def has_temporal_coordinates(self):
+        return self.nb_coordinates_temporal > 0
+
     def df_temporal_coordinates(self, split: Split = Split.all) -> pd.DataFrame:
         if self.nb_coordinates_temporal == 0:
             return pd.DataFrame()
         else:
             return self.df_coordinates(split).loc[:, self.coordinates_temporal_names].drop_duplicates()
 
+    def df_temporal_range(self, split: Split = Split.all) -> Tuple[float, float]:
+        df_temporal_coordinates = self.df_temporal_coordinates(split)
+        return float(df_temporal_coordinates.min()), float(df_temporal_coordinates.max()),
+
     #  Visualization
 
     @property
diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
index b7cd717275f3489a0a960ccf64f8813cc50f705d..a7e613cdc7792c4dae2358db99984d99400e6009 100644
--- a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
@@ -10,6 +10,7 @@ from extreme_estimator.extreme_models.margin_model.spline_margin_model import Co
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
+from test.test_utils import load_test_spatiotemporal_coordinates
 
 
 class TestVisualizationLinearMarginModel(unittest.TestCase):
@@ -30,28 +31,24 @@ class TestVisualizationLinearMarginModel(unittest.TestCase):
         self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
         # Assert that the grid correspond to what we expect in a simple case
         AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
-        grid = self.margin_model.margin_function_sample.grid_2D['loc']
+        grid = self.margin_model.margin_function_sample.grid_2D()['loc']
         true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
         self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
 
-    # def test_example_visualization_2D_spatio_temporal(self):
-    #     self.nb_steps = 2
-    #     coordinates = load_test_spatiotemporal_coordinates(nb_steps=self.nb_steps, nb_points=self.nb_points)[0]
-    #     self.margin_model = self.margin_model_class(coordinates)
-    #
-    #     # Load margin function from coef dict
-    #     coef_dict = {'locCoeff1': 0, 'locCoeff2': 1, 'scaleCoeff1': 0,
-    #                  'scaleCoeff2': 1, 'shapeCoeff1': 0,
-    #                  'shapeCoeff2': 1,
-    #                  'tempCoeffLoc1': 1, 'tempCoeffScale1': 1,
-    #                  'tempCoeffShape1': 1}
-    #     margin_function = LinearMarginFunction.from_coef_dict(coordinates,
-    #                                                           self.margin_model.margin_function_sample.gev_param_name_to_linear_dims,
-    #                                                           coef_dict)
-    #     self.margin_model.margin_function_sample = margin_function
-    #     self.margin_model.margin_function_sample.visualize_2D(show=True)
-    #
-    #     # Load
+    def test_example_visualization_2D_spatio_temporal(self):
+        self.nb_steps = 2
+        coordinates = load_test_spatiotemporal_coordinates(nb_steps=self.nb_steps, nb_points=self.nb_points)[1]
+        self.margin_model = self.margin_model_class(coordinates)
+        # Test to check loading of margin function from coef dict
+        # coef_dict = {'locCoeff1': 0, 'locCoeff2': 1, 'scaleCoeff1': 0,
+        #              'scaleCoeff2': 1, 'shapeCoeff1': 0,
+        #              'shapeCoeff2': 1,
+        #              'tempCoeffLoc1': 1, 'tempCoeffScale1': 1,
+        #              'tempCoeffShape1': 1}
+        # margin_function = LinearMarginFunction.from_coef_dict(coordinates,
+        #                                                       self.margin_model.margin_function_sample.gev_param_name_to_linear_dims,
+        #                                                       coef_dict)
+        # self.margin_model.margin_function_sample = margin_function
 
 
 class TestVisualizationSplineMarginModel(unittest.TestCase):