From d130ce719df998819b4b465ffedb393998c5c969 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 26 Nov 2018 16:01:02 +0100
Subject: [PATCH] [MAX STABLE MODEL] add fitmaxstab for 1D coordinates (by
 building 2D coordinates from the 1D coordinates)

---
 .../margin_model/abstract_margin_model.py     |  1 +
 .../independent_margin_function.py            |  3 +-
 .../margin_model/smooth_margin_model.py       |  9 ++--
 .../abstract_max_stable_model.py              | 27 +++++++---
 .../max_stable_model/max_stable_fit.R         | 49 ++++++++++++++++---
 .../max_stable_model/max_stable_models.py     | 13 +++--
 .../coordinates/abstract_coordinates.py       |  4 ++
 .../test_estimator/test_full_estimators.py    | 42 +++++++---------
 .../test_estimator/test_margin_estimators.py  | 45 +++++++++--------
 .../test_max_stable_estimators.py             | 33 +++++++------
 .../test_dataset.py                           |  7 +--
 test/test_utils.py                            | 35 ++++++++++---
 12 files changed, 176 insertions(+), 92 deletions(-)

diff --git a/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py b/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py
index 5ca3d45b..f6cff25c 100644
--- a/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/abstract_margin_model.py
@@ -10,6 +10,7 @@ class AbstractMarginModel(AbstractModel):
 
     def __init__(self, coordinates: AbstractCoordinates, params_start_fit=None, params_sample=None):
         super().__init__(params_start_fit, params_sample)
+        assert isinstance(coordinates, AbstractCoordinates), type(coordinates)
         self.coordinates = coordinates
         self.margin_function_sample = None  # type: AbstractMarginFunction
         self.margin_function_start_fit = None  # type: AbstractMarginFunction
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
index 65adfcf3..24780e79 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/independent_margin_function.py
@@ -48,7 +48,8 @@ class LinearMarginFunction(IndependentMarginFunction):
 
         # Check the axes are well-defined with respect to the coordinates
         for axes in self.gev_param_name_to_linear_axes.values():
-            assert all([axis < np.ndim(coordinates.coordinates_values) for axis in axes])
+            for axis in axes:
+                assert axis < coordinates.nb_columns, "axis={}, nb_columns={}".format(axis, coordinates.nb_columns)
 
         # Build gev_parameter_to_param_function dictionary
         self.gev_param_name_to_param_function = {}  # type: Dict[str, ParamFunction]
diff --git a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
index 15f16c56..8d22b109 100644
--- a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
@@ -50,9 +50,10 @@ class LinearAllParametersAxis0MarginModel(LinearMarginModel):
                                        GevParams.GEV_SCALE: [0]})
 
 
-class LinearAllParametersAxis0And1MarginModel(LinearMarginModel):
+class LinearAllParametersAllAxisMarginModel(LinearMarginModel):
 
     def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
-        super().load_margin_functions({GevParams.GEV_SHAPE: [0, 1],
-                                       GevParams.GEV_LOC: [0, 1],
-                                       GevParams.GEV_SCALE: [0, 1]})
+        all_axis = list(range(self.coordinates.nb_columns))
+        super().load_margin_functions({GevParams.GEV_SHAPE: all_axis.copy(),
+                                       GevParams.GEV_LOC: all_axis.copy(),
+                                       GevParams.GEV_SCALE: all_axis.copy()})
diff --git a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
index 367b9b25..5ad1c115 100644
--- a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
+++ b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
@@ -3,9 +3,11 @@ from enum import Enum
 
 import numpy as np
 import rpy2
+from rpy2.rinterface import RRuntimeError
 import rpy2.robjects as robjects
 
 from extreme_estimator.extreme_models.abstract_model import AbstractModel
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
 class AbstractMaxStableModel(AbstractModel):
@@ -26,27 +28,37 @@ class AbstractMaxStableModel(AbstractModel):
             assert fit_marge_form_dict is not None
             assert margin_start_dict is not None
 
-        # Prepare the data and the coord objects
+        # Prepare the data
         data = np.transpose(maxima_frech)
+
+        # Prepare the coord
+        # In the one dimensional case, fitmaxstab isn't working
+        # therefore, we treat our 1D coordinate as 2D coordinate on the line y=x, and enforce iso=TRUE
+        fitmaxstab_with_one_dimensional_data = len(df_coordinates.columns) == 1
+        if fitmaxstab_with_one_dimensional_data:
+            assert AbstractCoordinates.COORDINATE_X in df_coordinates.columns
+            df_coordinates[AbstractCoordinates.COORDINATE_Y] = df_coordinates[AbstractCoordinates.COORDINATE_X]
+        # Give names to columns to enable a specification of the shape of each marginal parameter
         coord = robjects.vectors.Matrix(df_coordinates.values)
         coord.colnames = robjects.StrVector(list(df_coordinates.columns))
 
-        #  Prepare the fit params
+        #  Prepare the fit_params (a dictionary containing all additional parameters)
         fit_params = self.cov_mod_param.copy()
         start_dict = self.params_start_fit
-        # Remove the 'var' parameter from the start_dict in the 2D case, otherwise fitmaxstab crashes
-        if len(df_coordinates.columns) == 2 and 'var' in start_dict.keys():
-                start_dict.pop('var')
+        # Remove some parameters that should only be used either in 1D or 2D case, otherwise fitmaxstab crashes
+        start_dict = self.remove_unused_parameters(start_dict, fitmaxstab_with_one_dimensional_data)
         if fit_marge:
             start_dict.update(margin_start_dict)
             fit_params.update({k: robjects.Formula(v) for k, v in fit_marge_form_dict.items()})
+        if fitmaxstab_with_one_dimensional_data:
+            fit_params['iso'] = True
         fit_params['start'] = self.r.list(**start_dict)
         fit_params['fit.marge'] = fit_marge
 
         # Run the fitmaxstab in R
         try:
             res = self.r.fitmaxstab(data=data, coord=coord, **fit_params)  # type: robjects.ListVector
-        except rpy2.rinterface.RRuntimeError as error:
+        except RRuntimeError as error:
             raise Exception('Some R exception have been launched at RunTime: \n {}'.format(error.__repr__()))
         # todo: maybe if the convergence was not successful I could try other starting point several times
         # Retrieve the resulting fitted values
@@ -62,6 +74,9 @@ class AbstractMaxStableModel(AbstractModel):
             self.r.rmaxstab(nb_obs, coordinates, *list(self.cov_mod_param.values()), **self.params_sample))
         return np.transpose(maxima_frech)
 
+    def remove_unused_parameters(self, start_dict, coordinate_dim):
+        return start_dict
+
 
 class CovarianceFunction(Enum):
     whitmat = 0
diff --git a/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R b/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R
index 24f36d60..b7cf2f13 100644
--- a/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R
+++ b/extreme_estimator/extreme_models/max_stable_model/max_stable_fit.R
@@ -1,12 +1,9 @@
 library(SpatialExtremes)
 
 
-# Boolean for python call
-call_main = !exists("python_wrapping")
 
-if (call_main) {
-    set.seed(42)
-    n.obs = 50
+# rmaxstab with 2D data
+rmaxstab2D <- function (n.obs){
     n.site = 2
     coord <- matrix(rnorm(2*n.site, sd = sqrt(.2)), ncol = 2)
     colnames(coord) = c("E", "N")
@@ -27,9 +24,46 @@ if (call_main) {
     scale.form = scale ~ 1
     shape.form = shape ~ 1
 
-
     namedlist = list(cov11 = 1.0, cov12 = 1.2, cov22 = 2.2, locCoeff1=1.0, scaleCoeff1=1.0, shapeCoeff1=1.0)
     res = fitmaxstab(data=data, coord=coord, cov.mod="gauss", start=namedlist, fit.marge=TRUE, loc.form=loc.form, scale.form=scale.form,shape.form=shape.form)
+    print(res['fitted.values'])
+}
+
+# rmaxstab with 1D data
+rmaxstab1D <- function (n.obs){
+
+    # In one dimensional, we duplicate the coordinate
+    vec = rnorm(3, sd = sqrt(.2))
+    coord = cbind(vec, vec)
+    var = 1.0
+    data <- rmaxstab(n.obs, coord, "gauss", cov11 = var, cov12 = 0, cov22 = var)
+
+    print(class(coord))
+    print(colnames(coord))
+
+    loc.form = loc ~ 1
+    scale.form = scale ~ 1
+    shape.form = shape ~ 1
+
+    # GAUSS
+    namedlist = list(cov=1.0, locCoeff1=1.0, scaleCoeff1=1.0, shapeCoeff1=1.0)
+    res = fitmaxstab(data=data, coord=coord, cov.mod="gauss", start=namedlist, fit.marge=TRUE, loc.form=loc.form, scale.form=scale.form,shape.form=shape.form, iso=TRUE)
+
+    # BROWN
+    # namedlist = list(range = 3, smooth = 0.5, locCoeff1=1.0, scaleCoeff1=1.0, shapeCoeff1=1.0)
+    # res = fitmaxstab(data=data, coord=coord, cov.mod="brown", start=namedlist, fit.marge=TRUE, loc.form=loc.form, scale.form=scale.form,shape.form=shape.form, iso=TRUE)
+
+
+    print(res['fitted.values'])
+}
+
+# Boolean for python call
+call_main = !exists("python_wrapping")
+if (call_main) {
+    set.seed(42)
+    n.obs = 500
+    rmaxstab2D(n.obs)
+    rmaxstab1D(n.obs)
 
     # namedlist = list(cov11 = 1.0, cov12 = 1.2, cov22 = 2.2)
     # res = fitmaxstab(data=data, coord=coord, cov.mod="gauss", start=namedlist)
@@ -38,7 +72,8 @@ if (call_main) {
     #     print(name)
     #     print(res[name])
     # }
-    print(res['fitted.values'])
+
     # print(res['convergence'])
 
 }
+
diff --git a/extreme_estimator/extreme_models/max_stable_model/max_stable_models.py b/extreme_estimator/extreme_models/max_stable_model/max_stable_models.py
index 812a7a94..347f5ba3 100644
--- a/extreme_estimator/extreme_models/max_stable_model/max_stable_models.py
+++ b/extreme_estimator/extreme_models/max_stable_model/max_stable_models.py
@@ -17,6 +17,13 @@ class Smith(AbstractMaxStableModel):
         }
         self.default_params_sample = self.default_params_start_fit.copy()
 
+    def remove_unused_parameters(self, start_dict, fitmaxstab_with_one_dimensional_data):
+        if fitmaxstab_with_one_dimensional_data:
+            start_dict = {'cov': start_dict['var']}
+        else:
+            start_dict.pop('var')
+        return start_dict
+
 
 class BrownResnick(AbstractMaxStableModel):
 
@@ -47,7 +54,7 @@ class Geometric(AbstractMaxStableModelWithCovarianceFunction):
     def __init__(self, params_start_fit=None, params_sample=None, covariance_function: CovarianceFunction = None):
         super().__init__(params_start_fit, params_sample, covariance_function)
         self.cov_mod = 'g' + self.covariance_function.name
-        self.default_params_sample .update({'sigma2': 0.5})
+        self.default_params_sample.update({'sigma2': 0.5})
         self.default_params_start_fit = self.default_params_sample.copy()
 
 
@@ -56,7 +63,7 @@ class ExtremalT(AbstractMaxStableModelWithCovarianceFunction):
     def __init__(self, params_start_fit=None, params_sample=None, covariance_function: CovarianceFunction = None):
         super().__init__(params_start_fit, params_sample, covariance_function)
         self.cov_mod = 't' + self.covariance_function.name
-        self.default_params_sample .update({'DoF': 2})
+        self.default_params_sample.update({'DoF': 2})
         self.default_params_start_fit = self.default_params_sample.copy()
 
 
@@ -65,5 +72,5 @@ class ISchlather(AbstractMaxStableModelWithCovarianceFunction):
     def __init__(self, params_start_fit=None, params_sample=None, covariance_function: CovarianceFunction = None):
         super().__init__(params_start_fit, params_sample, covariance_function)
         self.cov_mod = 'i' + self.covariance_function.name
-        self.default_params_sample .update({'alpha': 0.5})
+        self.default_params_sample.update({'alpha': 0.5})
         self.default_params_start_fit = self.default_params_sample.copy()
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 10e5d9f9..68f997d6 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -43,6 +43,10 @@ class AbstractCoordinates(object):
     def columns(self):
         return self.coordinates_columns(df_coord=self.df_coordinates)
 
+    @property
+    def nb_columns(self):
+        return len(self.columns)
+
     @property
     def df(self) -> pd.DataFrame:
         # Merged DataFrame of df_coord and s_split
diff --git a/test/test_extreme_estimator/test_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
index 8c520aca..0c8aacbf 100644
--- a/test/test_extreme_estimator/test_estimator/test_full_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
@@ -1,40 +1,36 @@
 import unittest
 from itertools import product
 
-from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp, \
-    FullEstimatorInASingleStepWithSmoothMargin
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
-from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
-from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestSmoothMarginEstimator
-from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators
-from test.test_utils import load_test_max_stable_models, load_smooth_margin_models
+from test.test_utils import load_test_max_stable_models, load_smooth_margin_models, load_test_1D_and_2D_coordinates, \
+    load_test_full_estimators
 
 
 class TestFullEstimators(unittest.TestCase):
     DISPLAY = False
-    FULL_ESTIMATORS = [SmoothMarginalsThenUnitaryMsp, FullEstimatorInASingleStepWithSmoothMargin][:]
+    nb_obs = 10
+    nb_points = 5
 
     def setUp(self):
         super().setUp()
-        self.spatial_coordinates = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1)
+        self.spatial_coordinates = load_test_1D_and_2D_coordinates(nb_points=self.nb_points)
         self.max_stable_models = load_test_max_stable_models()
-        self.smooth_margin_models = load_smooth_margin_models(coordinates=self.spatial_coordinates)
 
     def test_full_estimators(self):
-        for margin_model, max_stable_model in product(self.smooth_margin_models, self.max_stable_models):
-            dataset = FullSimulatedDataset.from_double_sampling(nb_obs=10, margin_model=margin_model,
-                                                                coordinates=self.spatial_coordinates,
-                                                                max_stable_model=max_stable_model)
-
-            for estimator_class in self.FULL_ESTIMATORS:
-                estimator = estimator_class(dataset=dataset, margin_model=margin_model,
-                                            max_stable_model=max_stable_model)
-                estimator.fit()
-                if self.DISPLAY:
-                    print(type(margin_model))
-                    print(dataset.df_dataset.head())
-                    print(estimator.additional_information)
-            self.assertTrue(True)
+        for coordinates in self.spatial_coordinates:
+            smooth_margin_models = load_smooth_margin_models(coordinates=coordinates)
+            for margin_model, max_stable_model in product(smooth_margin_models, self.max_stable_models):
+                dataset = FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs, margin_model=margin_model,
+                                                                    coordinates=coordinates,
+                                                                    max_stable_model=max_stable_model)
+
+                for full_estimator in load_test_full_estimators(dataset, margin_model, max_stable_model):
+                    full_estimator.fit()
+                    if self.DISPLAY:
+                        print(type(margin_model))
+                        print(dataset.df_dataset.head())
+                        print(full_estimator.additional_information)
+        self.assertTrue(True)
 
 
 if __name__ == '__main__':
diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
index 2e8969bb..473b88fa 100644
--- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
@@ -1,40 +1,39 @@
 import unittest
 
-from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \
-    LinearShapeAxis0MarginModel, LinearShapeAxis0and1MarginModel, LinearAllParametersAxis0MarginModel, \
-    LinearAllParametersAxis0And1MarginModel
 from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
 from extreme_estimator.return_level_plot.spatial_2D_plot import Spatial2DPlot
-from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
-from test.test_utils import load_smooth_margin_models
+from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
+from test.test_utils import load_smooth_margin_models, load_test_1D_and_2D_coordinates
 
 
 class TestSmoothMarginEstimator(unittest.TestCase):
     DISPLAY = False
-    SMOOTH_MARGIN_ESTIMATORS = [SmoothMarginEstimator]
+    nb_points = 5
 
     def setUp(self):
         super().setUp()
-        self.spatial_coordinates = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1)
-        self.smooth_margin_models = load_smooth_margin_models(coordinates=self.spatial_coordinates)
+        self.coordinates = load_test_1D_and_2D_coordinates(nb_points=self.nb_points)
 
     def test_dependency_estimators(self):
-        for margin_model in self.smooth_margin_models:
-            dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model,
-                                                  coordinates=self.spatial_coordinates)
-            # Fit estimator
-            estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model)
-            estimator.fit()
-            # Map name to their margin functions
-            name_to_margin_function = {
-                'Ground truth margin function': dataset.margin_model.margin_function_sample,
-                'Estimated margin function': estimator.margin_function_fitted,
-            }
-            # Spatial Plot
-            if self.DISPLAY:
-                Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot()
-            self.assertTrue(True)
+        for coordinates in self.coordinates:
+            smooth_margin_models = load_smooth_margin_models(coordinates=coordinates)
+            for margin_model in smooth_margin_models:
+                dataset = MarginDataset.from_sampling(nb_obs=10,
+                                                      margin_model=margin_model,
+                                                      coordinates=coordinates)
+                # Fit estimator
+                estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model)
+                estimator.fit()
+                # Map name to their margin functions
+                name_to_margin_function = {
+                    'Ground truth margin function': dataset.margin_model.margin_function_sample,
+                    'Estimated margin function': estimator.margin_function_fitted,
+                }
+                # Spatial Plot
+                if self.DISPLAY:
+                    Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot()
+        self.assertTrue(True)
 
 
 if __name__ == '__main__':
diff --git a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
index a4f1b8c7..aaea5ab3 100644
--- a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
@@ -5,32 +5,33 @@ from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model
 from extreme_estimator.estimator.max_stable_estimator import MaxStableEstimator
 from spatio_temporal_dataset.dataset.simulation_dataset import MaxStableDataset
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
-from test.test_utils import load_test_max_stable_models
+from test.test_utils import load_test_max_stable_models, load_test_1D_and_2D_coordinates, \
+    load_test_max_stable_estimators
 
 
 class TestMaxStableEstimators(unittest.TestCase):
     DISPLAY = False
-
-    MAX_STABLE_ESTIMATORS = [MaxStableEstimator]
+    nb_points = 5
+    nb_obs = 10
 
     def setUp(self):
         super().setUp()
-        self.spatial_coord = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1)
+        self.coordinates = load_test_1D_and_2D_coordinates(nb_points=self.nb_points)
         self.max_stable_models = load_test_max_stable_models()
 
     def test_max_stable_estimators(self):
-        for max_stable_model in self.max_stable_models:
-            dataset = MaxStableDataset.from_sampling(nb_obs=10,
-                                                     max_stable_model=max_stable_model,
-                                                     coordinates=self.spatial_coord)
-
-            for estimator_class in self.MAX_STABLE_ESTIMATORS:
-                estimator = estimator_class(dataset=dataset, max_stable_model=max_stable_model)
-                estimator.fit()
-                if self.DISPLAY:
-                    print(type(max_stable_model))
-                    print(dataset.df_dataset.head())
-                    print(estimator.additional_information)
+        for coordinates in self.coordinates:
+            for max_stable_model in self.max_stable_models:
+                dataset = MaxStableDataset.from_sampling(nb_obs=self.nb_obs,
+                                                         max_stable_model=max_stable_model,
+                                                         coordinates=coordinates)
+
+                for max_stable_estimator in load_test_max_stable_estimators(dataset, max_stable_model):
+                    max_stable_estimator.fit()
+                    if self.DISPLAY:
+                        print(type(max_stable_model))
+                        print(dataset.df_dataset.head())
+                        print(max_stable_estimator.additional_information)
         self.assertTrue(True)
 
 
diff --git a/test/test_spatio_temporal_dataset/test_dataset.py b/test/test_spatio_temporal_dataset/test_dataset.py
index c182eb0f..90282276 100644
--- a/test/test_spatio_temporal_dataset/test_dataset.py
+++ b/test/test_spatio_temporal_dataset/test_dataset.py
@@ -3,7 +3,8 @@ import unittest
 from itertools import product
 
 from spatio_temporal_dataset.dataset.simulation_dataset import MaxStableDataset
-from test.test_utils import load_test_max_stable_models, load_test_coordinates
+from test.test_utils import load_test_max_stable_models, load_test_coordinates, load_test_3D_coordinates, \
+    load_test_1D_and_2D_coordinates
 
 
 class TestDataset(unittest.TestCase):
@@ -12,7 +13,7 @@ class TestDataset(unittest.TestCase):
 
     def test_max_stable_dataset_R1_and_R2(self):
         max_stable_models = load_test_max_stable_models()[:]
-        coordinatess = load_test_coordinates(self.nb_points)[:-1]
+        coordinatess = load_test_1D_and_2D_coordinates(self.nb_points)
         for coordinates, max_stable_model in product(coordinatess, max_stable_models):
             MaxStableDataset.from_sampling(nb_obs=self.nb_obs,
                                            max_stable_model=max_stable_model,
@@ -23,7 +24,7 @@ class TestDataset(unittest.TestCase):
         """Test to warn me when spatialExtremes handles R3"""
         with self.assertRaises(RRuntimeError):
             smith_process = load_test_max_stable_models()[0]
-            coordinates = load_test_coordinates(self.nb_points)[-1]
+            coordinates = load_test_3D_coordinates(nb_points=self.nb_points)[0]
             MaxStableDataset.from_sampling(nb_obs=self.nb_obs,
                                            max_stable_model=smith_process,
                                            coordinates=coordinates)
diff --git a/test/test_utils.py b/test/test_utils.py
index 2170af02..84ec34ee 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,4 +1,7 @@
-from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAxis0And1MarginModel, \
+from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp, \
+    FullEstimatorInASingleStepWithSmoothMargin
+from extreme_estimator.estimator.max_stable_estimator import MaxStableEstimator
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAllAxisMarginModel, \
     ConstantMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import \
     AbstractMaxStableModelWithCovarianceFunction, CovarianceFunction
@@ -16,12 +19,24 @@ In this case, unit test (at least on the constructor) must be ensured in the tes
 """
 
 TEST_MAX_STABLE_MODEL = [Smith, BrownResnick, Schlather, Geometric, ExtremalT, ISchlather]
-TEST_COORDINATES = [UniformCoordinates, CircleCoordinates, AlpsStation3DCoordinatesWithAnisotropy]
-MARGIN_TYPES = [ConstantMarginModel, LinearAllParametersAxis0And1MarginModel][:]
+TEST_1D_AND_2D_COORDINATES = [UniformCoordinates, CircleCoordinates]
+TEST_3D_COORDINATES = [AlpsStation3DCoordinatesWithAnisotropy]
+TEST_MARGIN_TYPES = [ConstantMarginModel, LinearAllParametersAllAxisMarginModel][:]
+TEST_MAX_STABLE_ESTIMATOR = [MaxStableEstimator]
+TEST_FULL_ESTIMATORS = [SmoothMarginalsThenUnitaryMsp, FullEstimatorInASingleStepWithSmoothMargin][:]
+
+
+def load_test_full_estimators(dataset, margin_model, max_stable_model):
+    return [full_estimator(dataset=dataset, margin_model=margin_model, max_stable_model=max_stable_model) for
+            full_estimator in TEST_FULL_ESTIMATORS]
+
+
+def load_test_max_stable_estimators(dataset, max_stable_model):
+    return [max_stable_estimator(dataset, max_stable_model) for max_stable_estimator in TEST_MAX_STABLE_ESTIMATOR]
 
 
 def load_smooth_margin_models(coordinates):
-    return [margin_class(coordinates=coordinates) for margin_class in MARGIN_TYPES]
+    return [margin_class(coordinates=coordinates) for margin_class in TEST_MARGIN_TYPES]
 
 
 def load_test_max_stable_models():
@@ -36,5 +51,13 @@ def load_test_max_stable_models():
     return max_stable_models
 
 
-def load_test_coordinates(nb_points):
-    return [coordinate_class.from_nb_points(nb_points=nb_points) for coordinate_class in TEST_COORDINATES]
+def load_test_coordinates(nb_points, coordinate_types):
+    return [coordinate_class.from_nb_points(nb_points=nb_points) for coordinate_class in coordinate_types]
+
+
+def load_test_1D_and_2D_coordinates(nb_points):
+    return load_test_coordinates(nb_points, TEST_1D_AND_2D_COORDINATES)
+
+
+def load_test_3D_coordinates(nb_points):
+    return load_test_coordinates(nb_points, TEST_3D_COORDINATES)
-- 
GitLab