diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index 0dfb164342809954da131b5b4923ba375ac9605d..857fef72b9d33d0f060412f036aba0f80a6ff04f 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -1,5 +1,6 @@
 import time
 
+from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearMarginModel
 from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \
     ParametricMarginFunction
 from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
@@ -51,10 +52,11 @@ class AbstractEstimator(object):
     def extract_function_fitted(self) -> AbstractMarginFunction:
         raise NotImplementedError
 
-    def extract_function_fitted_from_function_to_fit(self, margin_function_to_fit: ParametricMarginFunction):
+    def extract_function_fitted_from_the_model_shape(self, margin_model: LinearMarginModel):
         return LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates,
-                                                   gev_param_name_to_dims=margin_function_to_fit.gev_param_name_to_dims,
-                                                   coef_dict=self.result_from_fit.margin_coef_dict)
+                                                   gev_param_name_to_dims=margin_model.margin_function_start_fit.gev_param_name_to_dims,
+                                                   coef_dict=self.result_from_fit.margin_coef_dict,
+                                                   starting_point=margin_model.starting_point)
 
     # @property
     # def max_stable_fitted(self) -> AbstractMarginFunction:
diff --git a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
index 547b54f6781aa08a57fc45b8d74df67cb68de6ef..b01ba71292b344905aae61e92211a6fee2943e35 100644
--- a/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator/abstract_full_estimator.py
@@ -63,12 +63,12 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
             df_coordinates_spat=self.dataset.coordinates.df_spatial_coordinates(self.train_split),
             df_coordinates_temp=self.dataset.coordinates.df_temporal_coordinates(self.train_split),
             fit_marge=True,
-            fit_marge_form_dict=self.margin_function_start_fit.form_dict,
-            margin_start_dict=self.margin_function_start_fit.coef_dict
+            fit_marge_form_dict=self.linear_margin_model.margin_function_start_fit.form_dict,
+            margin_start_dict=self.linear_margin_model.margin_function_start_fit.coef_dict
         )
 
     def extract_function_fitted(self):
-        return self.extract_function_fitted_from_function_to_fit(self.margin_function_start_fit)
+        return self.extract_function_fitted_from_the_model_shape(self.linear_margin_model)
 
 
 class PointwiseAndThenUnitaryMsp(AbstractFullEstimator):
diff --git a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
index 7acdf0ce039a834935534e4d8250645b03b8d1b1..38f8e5c36f6f3a2ca39fc09a753aefaedd528aa3 100644
--- a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
@@ -5,6 +5,7 @@ from extreme_estimator.extreme_models.margin_model.margin_function.abstract_marg
     AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearMarginModel, \
     LinearAllParametersAllDimsMarginModel
+from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
@@ -35,6 +36,10 @@ class LinearMarginEstimator(AbstractMarginEstimator):
                                                                             df_coordinates_spat=df_coordinates_spat,
                                                                             df_coordinates_temp=df_coordinates_temp)
 
-    def extract_function_fitted(self):
-        return self.extract_function_fitted_from_function_to_fit(self.margin_model.margin_function_start_fit)
+    @property
+    def margin_function_fitted(self) -> LinearMarginFunction:
+        return super().margin_function_fitted
+
+    def extract_function_fitted(self) -> LinearMarginFunction:
+        return self.extract_function_fitted_from_the_model_shape(self.margin_model)
 
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
index c4d34c8d382cb58b3c71813dfad406a515466749..25350f19758727cef00a0844c8e92fa346319e65 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Union
 
 import numpy as np
 
@@ -17,7 +17,7 @@ class IndependentMarginFunction(AbstractMarginFunction):
     def __init__(self, coordinates: AbstractCoordinates):
         """Attribute 'gev_param_name_to_param_function' maps each GEV parameter to its corresponding function"""
         super().__init__(coordinates)
-        self.gev_param_name_to_param_function = None  # type: Dict[str, AbstractParamFunction]
+        self.gev_param_name_to_param_function = None  # type: Union[None, Dict[str, AbstractParamFunction]]
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Each GEV parameter is computed independently through its corresponding param_function"""
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
index db54afeb5459e346e8610de74e4c8bbff4926182..3341bf2efc732c8845cd63d1638ba23a0fc9dcea 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
@@ -1,13 +1,11 @@
-from typing import Dict, List
-
-import numpy as np
+from typing import Dict, List, Union
 
 from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \
     ParametricMarginFunction
 from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef
 from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
-from extreme_estimator.extreme_models.margin_model.param_function.param_function import ConstantParamFunction, \
-    AbstractParamFunction, LinearParamFunction
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction, \
+    LinearParamFunction
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
@@ -29,32 +27,15 @@ class LinearMarginFunction(ParametricMarginFunction):
     COEF_CLASS = LinearCoef
 
     def __init__(self, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
-                 gev_param_name_to_coef: Dict[str, AbstractCoef],
-                 starting_point=None):
-        # Starting point for the trend is the same for all the parameters
-        self.starting_point = starting_point
-        self.gev_param_name_to_coef = None  # type: Dict[str, LinearCoef]
-        super().__init__(coordinates, gev_param_name_to_dims, gev_param_name_to_coef)
-
-    # @classmethod
-    # def from_coef_dict(cls, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
-    #                    coef_dict: Dict[str, float]):
-    #     return super().from_coef_dict(coordinates, gev_param_name_to_dims, coef_dict)
+                 gev_param_name_to_coef: Dict[str, AbstractCoef], starting_point: Union[None, int] = None):
+        self.gev_param_name_to_coef = None  # type: Union[None, Dict[str, LinearCoef]]
+        super().__init__(coordinates, gev_param_name_to_dims, gev_param_name_to_coef, starting_point)
 
     def load_specific_param_function(self, gev_param_name) -> AbstractParamFunction:
         return LinearParamFunction(dims=self.gev_param_name_to_dims[gev_param_name],
                                    coordinates=self.coordinates.coordinates_values(),
                                    linear_coef=self.gev_param_name_to_coef[gev_param_name])
 
-    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
-        if self.starting_point is not None:
-            # Shift temporal coordinate to enable to model temporal trend with starting point
-            assert self.coordinates.has_temporal_coordinates
-            assert 0 <= self.coordinates.idx_temporal_coordinates < len(coordinate)
-            if coordinate[self.coordinates.idx_temporal_coordinates] < self.starting_point:
-                coordinate[self.coordinates.idx_temporal_coordinates] = self.starting_point
-        return super().get_gev_params(coordinate)
-
     @classmethod
     def idx_to_coefficient_name(cls, coordinates: AbstractCoordinates) -> Dict[int, str]:
         # Intercept correspond to the dimension 0
@@ -68,6 +49,15 @@ class LinearMarginFunction(ParametricMarginFunction):
     def coefficient_name_to_dim(cls, coordinates: AbstractCoordinates) -> Dict[int, str]:
         return {v: k for k, v in cls.idx_to_coefficient_name(coordinates).items()}
 
+    @property
+    def coef_dict(self) -> Dict[str, float]:
+        coef_dict = {}
+        for gev_param_name in GevParams.PARAM_NAMES:
+            dims = self.gev_param_name_to_dims.get(gev_param_name, [])
+            coef = self.gev_param_name_to_coef[gev_param_name]
+            coef_dict.update(coef.coef_dict(dims, self.idx_to_coefficient_name(self.coordinates)))
+        return coef_dict
+
     @property
     def form_dict(self) -> Dict[str, str]:
         form_dict = {}
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py
index 8a0e51ca55ea2e0123db14359a45815fa3e8800a..7476aa13eccb9107ea95e84dcd080165de1cedba 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/parametric_margin_function.py
@@ -1,4 +1,6 @@
-from typing import Dict, List
+from typing import Dict, List, Union
+
+import numpy as np
 
 from extreme_estimator.extreme_models.margin_model.margin_function.independent_margin_function import \
     IndependentMarginFunction
@@ -29,7 +31,9 @@ class ParametricMarginFunction(IndependentMarginFunction):
     COEF_CLASS = None
 
     def __init__(self, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
-                 gev_param_name_to_coef: Dict[str, AbstractCoef]):
+                 gev_param_name_to_coef: Dict[str, AbstractCoef], starting_point: Union[None, int] = None):
+        # Starting point for the trend is the same for all the parameters
+        self.starting_point = starting_point
         super().__init__(coordinates)
         self.gev_param_name_to_dims = gev_param_name_to_dims  # type: Dict[str, List[int]]
 
@@ -57,9 +61,19 @@ class ParametricMarginFunction(IndependentMarginFunction):
     def load_specific_param_function(self, gev_param_name) -> AbstractParamFunction:
         raise NotImplementedError
 
+    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
+        print('here get gev', self.starting_point)
+        if self.starting_point is not None:
+            # Shift temporal coordinate to enable to model temporal trend with starting point
+            assert self.coordinates.has_temporal_coordinates
+            assert 0 <= self.coordinates.idx_temporal_coordinates < len(coordinate)
+            if coordinate[self.coordinates.idx_temporal_coordinates] < self.starting_point:
+                coordinate[self.coordinates.idx_temporal_coordinates] = self.starting_point
+        return super().get_gev_params(coordinate)
+
     @classmethod
     def from_coef_dict(cls, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
-                       coef_dict: Dict[str, float]):
+                       coef_dict: Dict[str, float], starting_point: Union[None, int] = None):
         assert cls.COEF_CLASS is not None, 'a COEF_CLASS class attributes needs to be defined'
         gev_param_name_to_coef = {}
         for gev_param_name in GevParams.PARAM_NAMES:
@@ -67,16 +81,7 @@ class ParametricMarginFunction(IndependentMarginFunction):
             coef = cls.COEF_CLASS.from_coef_dict(coef_dict=coef_dict, gev_param_name=gev_param_name, dims=dims,
                                                  coordinates=coordinates)
             gev_param_name_to_coef[gev_param_name] = coef
-        return cls(coordinates, gev_param_name_to_dims, gev_param_name_to_coef)
-
-    @property
-    def coef_dict(self) -> Dict[str, float]:
-        coef_dict = {}
-        for gev_param_name in GevParams.PARAM_NAMES:
-            dims = self.gev_param_name_to_dims.get(gev_param_name, [])
-            coef = self.gev_param_name_to_coef[gev_param_name]
-            coef_dict.update(coef.coef_dict(dims, self.idx_to_coefficient_name(self.coordinates)))
-        return coef_dict
+        return cls(coordinates, gev_param_name_to_dims, gev_param_name_to_coef, starting_point)
 
     @property
     def form_dict(self) -> Dict[str, str]:
diff --git a/test/test_extreme_estimator/test_margin_fits/test_gev/test_gev_temporal_margin.py b/test/test_extreme_estimator/test_margin_fits/test_gev/test_gev_temporal_margin.py
new file mode 100644
index 0000000000000000000000000000000000000000..176d0a037d582dc5e317248605cc137909966a76
--- /dev/null
+++ b/test/test_extreme_estimator/test_margin_fits/test_gev/test_gev_temporal_margin.py
@@ -0,0 +1,77 @@
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
+from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \
+    NonStationaryStationModel
+from extreme_estimator.extreme_models.utils import r, set_seed_r
+from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
+from extreme_estimator.margin_fits.gev.ismev_gev_fit import IsmevGevFit
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_coordinates import \
+    AbstractTemporalCoordinates
+from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
+    AbstractSpatioTemporalObservations
+
+
+class TestGevTemporalMargin(unittest.TestCase):
+
+    def setUp(self) -> None:
+        set_seed_r()
+        r("""
+        N <- 50
+        loc = 0; scale = 1; shape <- 1
+        x_gev <- rgev(N, loc = loc, scale = scale, shape = shape)
+        start_loc = 0; start_scale = 1; start_shape = 1
+        """)
+        # Compute the stationary temporal margin with isMev
+        df = pd.DataFrame({AbstractCoordinates.COORDINATE_T: range(50)})
+        self.coordinates = AbstractTemporalCoordinates.from_df(df)
+        df2 = pd.DataFrame(data=np.array(r['x_gev']), index=df.index)
+        observations = AbstractSpatioTemporalObservations(df_maxima_gev=df2)
+        self.dataset = AbstractDataset(observations=observations, coordinates=self.coordinates)
+
+    def test_gev_temporal_margin_fit_stationary(self):
+        # Create estimator
+        margin_model = StationaryStationModel(self.coordinates)
+        estimator = LinearMarginEstimator(self.dataset, margin_model)
+        estimator.fit()
+        ref = {'loc': 0.0219, 'scale': 1.0347, 'shape': 0.8295}
+        for year in range(1, 3):
+            mle_params_estimated = estimator.margin_function_fitted.get_gev_params(np.array([year])).to_dict()
+            for key in ref.keys():
+                self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
+
+    def test_gev_temporal_margin_fit_nonstationary(self):
+        # Create estimator
+        margin_model = NonStationaryStationModel(self.coordinates)
+        estimator = LinearMarginEstimator(self.dataset, margin_model)
+        estimator.fit()
+        self.assertNotEqual(estimator.result_from_fit.margin_coef_dict['tempCoeffLoc1'], 0.0)
+        # Checks that parameters returned are indeed different
+        mle_params_estimated_year1 = estimator.margin_function_fitted.get_gev_params(np.array([1])).to_dict()
+        mle_params_estimated_year3 = estimator.margin_function_fitted.get_gev_params(np.array([3])).to_dict()
+        self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
+
+    def test_gev_temporal_margin_fit_nonstationary_with_start_point(self):
+        # Create estimator
+        margin_model = NonStationaryStationModel(self.coordinates, starting_point=3)
+        estimator = LinearMarginEstimator(self.dataset, margin_model)
+        estimator.fit()
+        self.assertNotEqual(estimator.result_from_fit.margin_coef_dict['tempCoeffLoc1'], 0.0)
+        # Checks starting point parameter are well passed
+        self.assertEqual(3, estimator.margin_function_fitted.starting_point)
+        # Checks that parameters returned are indeed different
+        mle_params_estimated_year1 = estimator.margin_function_fitted.get_gev_params(np.array([1])).to_dict()
+        mle_params_estimated_year3 = estimator.margin_function_fitted.get_gev_params(np.array([3])).to_dict()
+        self.assertEqual(mle_params_estimated_year1, mle_params_estimated_year3)
+        mle_params_estimated_year5 = estimator.margin_function_fitted.get_gev_params(np.array([5])).to_dict()
+        self.assertNotEqual(mle_params_estimated_year5, mle_params_estimated_year3)
+    # todo: create same test with a starting value, to check if the starting value is taken into account in the margin_function_fitted
+
+
+if __name__ == '__main__':
+    unittest.main()