diff --git a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
index b01ba71292b344905aae61e92211a6fee2943e35..c3933c2c471f4959fa6dcf2dda717c3a2ed84bcf 100644
--- a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
@@ -70,6 +70,10 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
     def extract_function_fitted(self):
         return self.extract_function_fitted_from_the_model_shape(self.linear_margin_model)
 
+    @property
+    def margin_function_fitted(self) -> LinearMarginFunction:
+        return super().margin_function_fitted
+
 
 class PointwiseAndThenUnitaryMsp(AbstractFullEstimator):
     pass
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
index dc564da1cdd09d16f947ec998a80bd08d9e0dc89..e65281cc3b291d794e8992530a2a664d39f70428 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
@@ -63,6 +63,10 @@ class LinearMarginFunction(ParametricMarginFunction):
     def mu1_temporal_trend(self):
         return self.coef_dict[LinearCoef.coef_template_str(ExtremeParams.LOC, AbstractCoordinates.COORDINATE_T).format(1)]
 
+    @property
+    def mu0(self):
+        return self.coef_dict[LinearCoef.coef_template_str(ExtremeParams.LOC, LinearCoef.INTERCEPT_NAME).format(1)]
+
     @property
     def form_dict(self) -> Dict[str, str]:
         form_dict = {}
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 91c85d25fa8b665dbcd2463b725fef1f913fd922..b3bf403983c30f8fbaf54e6d3e9bb30b117d4a25 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -204,6 +204,10 @@ class AbstractCoordinates(object):
         else:
             return self.df_coordinates(split).loc[:, self.coordinates_temporal_names].drop_duplicates()
 
+    @property
+    def nb_steps(self, split: Split = Split.all):
+        return len(self.df_temporal_coordinates(split))
+
     def df_temporal_range(self, split: Split = Split.all) -> Tuple[int, int]:
         df_temporal_coordinates = self.df_temporal_coordinates(split)
         return int(df_temporal_coordinates.min()), int(df_temporal_coordinates.max()),
diff --git a/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py b/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py
index 53e4af8af4832ce85bf131937c127c4bee618011..84b7feece54e3f123d2e07793bcb7626d70668c6 100644
--- a/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py
@@ -1,12 +1,19 @@
 import pandas as pd
 
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
+    AbstractSpatialCoordinates
 from spatio_temporal_dataset.coordinates.utils import get_index_with_spatio_temporal_index_suffix
 from spatio_temporal_dataset.slicer.spatio_temporal_slicer import SpatioTemporalSlicer
 
 
 class AbstractSpatioTemporalCoordinates(AbstractCoordinates):
 
+    def __init__(self, df: pd.DataFrame, slicer_class: type, s_split_spatial: pd.Series = None,
+                 s_split_temporal: pd.Series = None):
+        super().__init__(df, slicer_class, s_split_spatial, s_split_temporal)
+        self.spatial_coordinates = AbstractSpatialCoordinates.from_df(df=self.df_spatial_coordinates())
+
     @classmethod
     def from_df(cls, df: pd.DataFrame, train_split_ratio: float = None):
         assert cls.COORDINATE_T in df.columns
diff --git a/spatio_temporal_dataset/dataset/simulation_dataset.py b/spatio_temporal_dataset/dataset/simulation_dataset.py
index a2cbea15fe1ebf691d34cff0c0a83d96f23e75ad..c3e90638dc21c03214c83ab6af640c7895d79bf0 100644
--- a/spatio_temporal_dataset/dataset/simulation_dataset.py
+++ b/spatio_temporal_dataset/dataset/simulation_dataset.py
@@ -1,11 +1,13 @@
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.abstract_spatio_temporal_coordinates import \
+    AbstractSpatioTemporalCoordinates
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
     AbstractSpatioTemporalObservations
 from spatio_temporal_dataset.spatio_temporal_observations.annual_maxima_observations import \
-    MaxStableAnnualMaxima, MarginAnnualMaxima, FullAnnualMaxima
+    MaxStableAnnualMaxima, MarginAnnualMaxima, FullAnnualMaxima, FullSpatioTemporalAnnualMaxima
 
 
 class SimulatedDataset(AbstractDataset):
@@ -46,8 +48,16 @@ class FullSimulatedDataset(SimulatedDataset):
     def from_double_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
                              coordinates: AbstractCoordinates,
                              margin_model: AbstractMarginModel):
-        assert coordinates.nb_coordinates <= 2, 'rmaxstable available only for 2D coordinates'
-        observations = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
-                                                             coordinates, margin_model)
+        assert coordinates.nb_coordinates <= 2 or \
+               coordinates.has_spatio_temporal_coordinates and coordinates.nb_coordinates == 3, \
+            'rmaxstable available only for 2D coordinates'
+        if coordinates.nb_coordinates <= 2:
+            observations = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
+                                                                 coordinates, margin_model)
+        else:
+            assert isinstance(coordinates, AbstractSpatioTemporalCoordinates)
+            observations = FullSpatioTemporalAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
+                                                                               coordinates, margin_model)
+
         return cls(observations=observations, coordinates=coordinates,
                    max_stable_model=max_stable_model, margin_model=margin_model)
diff --git a/spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py b/spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py
index ff12071c6c6df547124a2520feeedf3c28f1e71a..32817b32c910505b56148375f3521c6d81d931e8 100644
--- a/spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py
+++ b/spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py
@@ -3,6 +3,10 @@ import pandas as pd
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
+    AbstractSpatialCoordinates
+from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.abstract_spatio_temporal_coordinates import \
+    AbstractSpatioTemporalCoordinates
 from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations \
     import AbstractSpatioTemporalObservations
 
@@ -46,3 +50,21 @@ class FullAnnualMaxima(MaxStableAnnualMaxima):
                                                             coordinates_values=coordinates.coordinates_values())
         max_stable_annual_maxima.df_maxima_gev = pd.DataFrame(data=maxima_gev, index=coordinates.index)
         return max_stable_annual_maxima
+
+
+class FullSpatioTemporalAnnualMaxima(MaxStableAnnualMaxima):
+
+    @classmethod
+    def from_double_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
+                             coordinates: AbstractSpatioTemporalCoordinates, margin_model: AbstractMarginModel):
+        # Sample with the max stable spatially
+        spatial_coordinate = coordinates.spatial_coordinates
+        nb_total_obs = nb_obs * coordinates.nb_steps
+        max_stable_annual_maxima = super().from_sampling(nb_total_obs, max_stable_model, spatial_coordinate)
+        # Convert observation to a spatio temporal index
+        max_stable_annual_maxima.convert_to_spatio_temporal_index(coordinates)
+        #  Compute df_maxima_gev from df_maxima_frech
+        maxima_gev = margin_model.rmargin_from_maxima_frech(maxima_frech=max_stable_annual_maxima.maxima_frech(),
+                                                            coordinates_values=coordinates.coordinates_values())
+        max_stable_annual_maxima.df_maxima_gev = pd.DataFrame(data=maxima_gev, index=coordinates.index)
+        return max_stable_annual_maxima
diff --git a/test/test_extreme_estimator/test_extreme_models/test_max_stable_temporal.py b/test/test_extreme_estimator/test_extreme_models/test_max_stable_temporal.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ceac3ecec5aeff4b3f5faad213618e5d2a3d45
--- /dev/null
+++ b/test/test_extreme_estimator/test_extreme_models/test_max_stable_temporal.py
@@ -0,0 +1,116 @@
+import random
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
+    FullEstimatorInASingleStepWithSmoothMargin
+from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
+from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearNonStationaryLocationMarginModel, \
+    LinearStationaryMarginModel
+from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \
+    NonStationaryStationModel
+from extreme_estimator.extreme_models.utils import r, set_seed_r, set_seed_for_test
+from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
+from extreme_estimator.margin_fits.gev.ismev_gev_fit import IsmevGevFit
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_coordinates import \
+    AbstractTemporalCoordinates
+from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset, FullSimulatedDataset
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
+    AbstractSpatioTemporalObservations
+from test.test_utils import load_test_spatiotemporal_coordinates, load_smooth_margin_models, load_test_max_stable_models
+
+
+class TestMaxStableTemporal(unittest.TestCase):
+
+    def setUp(self) -> None:
+        set_seed_for_test(seed=42)
+        self.nb_points = 2
+        self.nb_steps = 50
+        self.nb_obs = 1
+        # Load some 2D spatial coordinates
+        self.coordinates = load_test_spatiotemporal_coordinates(nb_steps=self.nb_steps, nb_points=self.nb_points)[1]
+        self.smooth_margin_model = LinearNonStationaryLocationMarginModel(coordinates=self.coordinates,
+                                                                          starting_point=2)
+        self.max_stable_model = load_test_max_stable_models()[0]
+        self.dataset = FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs,
+                                                                 margin_model=self.smooth_margin_model,
+                                                                 coordinates=self.coordinates,
+                                                                 max_stable_model=self.max_stable_model)
+
+    def test_margin_fit_stationary(self):
+        # Create estimator
+        margin_model = LinearStationaryMarginModel(self.coordinates)
+        estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model,
+                                                               self.max_stable_model)
+        estimator.fit()
+        ref = {'loc': 1.2091156634312243, 'scale': 1.1210085591373455, 'shape': 0.9831957705294134}
+        for year in range(1, 3):
+            coordinate = np.array([0.0, 0.0, year])
+            mle_params_estimated = estimator.margin_function_fitted.get_gev_params(coordinate).to_dict()
+            for key in ref.keys():
+                self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
+
+    def test_margin_fit_nonstationary(self):
+        # Create estimator
+        margin_model = LinearNonStationaryLocationMarginModel(self.coordinates)
+        estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model,
+                                                               self.max_stable_model)
+        estimator.fit()
+        self.assertNotEqual(estimator.margin_function_fitted.mu1_temporal_trend, 0.0)
+        # Checks that parameters returned are indeed different
+        coordinate1 = np.array([0.0, 0.0, 1])
+        mle_params_estimated_year1 = estimator.margin_function_fitted.get_gev_params(coordinate1).to_dict()
+        coordinate3 = np.array([0.0, 0.0, 3])
+        mle_params_estimated_year3 = estimator.margin_function_fitted.get_gev_params(coordinate3).to_dict()
+        self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
+
+    def test_margin_fit_nonstationary_with_start_point(self):
+        # Create estimator
+        estimator = self.fit_non_stationary_estimator(starting_point=2)
+        # By default, estimator find the good margin
+        self.assertNotEqual(estimator.margin_function_fitted.mu1_temporal_trend, 0.0)
+        self.assertAlmostEqual(estimator.margin_function_fitted.mu1_temporal_trend,
+                               self.smooth_margin_model.margin_function_sample.mu1_temporal_trend,
+                               places=2)
+        # Checks starting point parameter are well passed
+        self.assertEqual(2, estimator.margin_function_fitted.starting_point)
+        # Checks that parameters returned are indeed different
+        coordinate1 = np.array([0.0, 0.0, 1])
+        mle_params_estimated_year1 = estimator.margin_function_fitted.get_gev_params(coordinate1).to_dict()
+        coordinate2 = np.array([0.0, 0.0, 2])
+        mle_params_estimated_year2 = estimator.margin_function_fitted.get_gev_params(coordinate2).to_dict()
+        self.assertEqual(mle_params_estimated_year1, mle_params_estimated_year2)
+        coordinate5 = np.array([0.0, 0.0, 5])
+        mle_params_estimated_year5 = estimator.margin_function_fitted.get_gev_params(coordinate5).to_dict()
+        self.assertNotEqual(mle_params_estimated_year5, mle_params_estimated_year2)
+
+    def fit_non_stationary_estimator(self, starting_point):
+        margin_model = LinearNonStationaryLocationMarginModel(self.coordinates, starting_point=starting_point)
+        estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model,
+                                                               self.max_stable_model)
+        estimator.fit()
+        return estimator
+
+    def test_two_different_starting_points(self):
+        # Create two different estimators
+        estimator1 = self.fit_non_stationary_estimator(starting_point=3)
+        estimator2 = self.fit_non_stationary_estimator(starting_point=20)
+
+        for starting_point in range(3, 20):
+            estimator = self.fit_non_stationary_estimator(starting_point=starting_point)
+            print(estimator.margin_function_fitted.starting_point)
+            print(estimator.margin_function_fitted.coef_dict)
+            print(estimator.margin_function_fitted.mu0)
+            print(estimator.margin_function_fitted.mu1_temporal_trend)
+
+        mu1_estimator1 = estimator1.margin_function_fitted.mu1_temporal_trend
+        mu1_estimator2 = estimator2.margin_function_fitted.mu1_temporal_trend
+        self.assertNotEqual(mu1_estimator1, mu1_estimator2)
+
+
+if __name__ == '__main__':
+    unittest.main()