diff --git a/experiment/fit_diagnosis/main_split.py b/experiment/fit_diagnosis/main_split.py
deleted file mode 100644
index c6c326198163b33665738e62b0b5704f5ff45bef..0000000000000000000000000000000000000000
--- a/experiment/fit_diagnosis/main_split.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import random
-
-from experiment.fit_diagnosis.split_curve import SplitCurve, LocFunction
-from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin
-from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
-from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \
-    LinearAllParametersAllDimsMarginModel
-from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
-from extreme_estimator.gev_params import GevParams
-from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates
-
-from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
-from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer
-from spatio_temporal_dataset.slicer.spatio_temporal_slicer import SpatioTemporalSlicer
-
-random.seed(42)
-
-
-def load_dataset():
-    nb_points = 50
-    nb_obs = 60
-    coordinates = LinSpaceCoordinates.from_nb_points(nb_points=nb_points, train_split_ratio=0.8)
-
-    # MarginModel Linear with respect to the shape (from 0.01 to 0.02)
-    params_sample = {
-        (GevParams.GEV_LOC, 0): 10,
-        (GevParams.GEV_SHAPE, 0): 1.0,
-        (GevParams.GEV_SCALE, 0): 1.0,
-    }
-    margin_model = ConstantMarginModel(coordinates=coordinates, params_sample=params_sample)
-    max_stable_model = Smith()
-
-    return FullSimulatedDataset.from_double_sampling(nb_obs=nb_obs, margin_model=margin_model,
-                                                     coordinates=coordinates,
-                                                     max_stable_model=max_stable_model)
-
-
-def full_estimator(dataset):
-    max_stable_model = Smith()
-    margin_model_for_estimator = LinearAllParametersAllDimsMarginModel(dataset.coordinates)
-    # full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model)
-    fast_estimator = SmoothMarginEstimator(dataset, margin_model_for_estimator)
-    return fast_estimator
-
-
-if __name__ == '__main__':
-    dataset = load_dataset()
-    dataset.slicer.summary()
-    full_estimator = full_estimator(dataset)
-    curve = SplitCurve(dataset, full_estimator, margin_functions=[LocFunction()])
-    curve.visualize()
diff --git a/experiment/fit_diagnosis/split_curve.py b/experiment/fit_diagnosis/split_curve.py
index bc408c5ea02254aae864730ca6faf807833e10c2..a7511791734f5ea4503025f602a2608b0342e377 100644
--- a/experiment/fit_diagnosis/split_curve.py
+++ b/experiment/fit_diagnosis/split_curve.py
@@ -1,4 +1,5 @@
 import numpy as np
+import matplotlib.cm as cm
 
 import matplotlib.pyplot as plt
 import seaborn as sns
@@ -7,42 +8,53 @@ from typing import Union, List
 
 from extreme_estimator.estimator.full_estimator import AbstractFullEstimator
 from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator
+from extreme_estimator.extreme_models.margin_model.margin_function.combined_margin_function import \
+    CombinedMarginFunction
 from extreme_estimator.extreme_models.margin_model.margin_function.utils import error_dict_between_margin_functions
 from extreme_estimator.gev_params import GevParams
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
 from spatio_temporal_dataset.slicer.split import Split, ALL_SPLITS_EXCEPT_ALL
 
 
-class MarginFunction(object):
+class SplitCurve(object):
 
-    def margin_function(self, gev_param: GevParams) -> float:
-        pass
+    def __init__(self, nb_fit: int = 1):
+        self.nb_fit = nb_fit
+        self.margin_function_fitted_all = None
 
+    def fit(self, show=True):
+        self.margin_function_fitted_all = []
 
-class LocFunction(MarginFunction):
+        for i in range(self.nb_fit):
+            # A new dataset with the same margin, but just the observations are resampled
+            self.dataset = self.load_dataset()
+            # Both split must be defined
+            assert not self.dataset.slicer.some_required_ind_are_not_defined
+            self.margin_function_sample = self.dataset.margin_model.margin_function_sample
 
-    def margin_function(self, gev_param: GevParams) -> float:
-        return gev_param.location
+            print('Fitting {}/{}...'.format(i + 1, self.nb_fit))
+            self.estimator = self.load_estimator(self.dataset)
+            # Fit the estimator and get the margin_function
+            self.estimator.fit()
+            self.margin_function_fitted_all.append(self.estimator.margin_function_fitted)
 
+        # Individual error dict
+        self.error_dict_all = [error_dict_between_margin_functions(self.margin_function_sample, m)
+                               for m in self.margin_function_fitted_all]
 
-class SplitCurve(object):
+        # Mean margin
+        self.mean_margin_function_fitted = CombinedMarginFunction.from_margin_functions(self.margin_function_fitted_all)
+        self.mean_error_dict = error_dict_between_margin_functions(self.margin_function_sample,
+                                                                   self.mean_margin_function_fitted)
 
-    def __init__(self, dataset: FullSimulatedDataset, estimator: Union[AbstractFullEstimator, AbstractMarginEstimator],
-                 margin_functions: List[MarginFunction]):
-        # Dataset is already loaded and will not be modified
-        self.dataset = dataset
-        # Both split must be defined
-        assert not self.dataset.slicer.some_required_ind_are_not_defined
-        self.margin_function_sample = self.dataset.margin_model.margin_function_sample
+        if show:
+            self.visualize()
 
-        self.estimator = estimator
-        # Fit the estimator and get the margin_function
-        self.estimator.fit()
-        # todo: potentially I will do the fit several times, and retrieve the mean error
-        # there is a big variablility so it would be really interesting to average, to make real
-        self.margin_function_fitted = estimator.margin_function_fitted
+    def load_dataset(self):
+        pass
 
-        self.error_dict = error_dict_between_margin_functions(self.margin_function_sample, self.margin_function_fitted)
+    def load_estimator(self, dataset):
+        pass
 
     @property
     def main_title(self):
@@ -58,17 +70,39 @@ class SplitCurve(object):
         plt.show()
 
     def margin_graph(self, ax, gev_value_name):
-        # Display the fitted curve
-        self.margin_function_fitted.visualize_single_param(gev_value_name, ax, show=False)
+        # Create bins of data, each with an associated color corresponding to its error
+
+        data = self.mean_error_dict[gev_value_name].values
+        nb_bins = 10
+        limits = np.linspace(data.min(), data.max(), num=nb_bins + 1)
+        limits[-1] += 0.01
+        colors = cm.binary(limits)
+
         # Display train/test points
         for split, marker in [(self.dataset.train_split, 'o'), (self.dataset.test_split, 'x')]:
-            self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker)
-            self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False)
+            for left_limit, right_limit, color in zip(limits[:-1], limits[1:], colors):
+                # Find for the split the index
+                data_ind = self.mean_error_dict[gev_value_name].loc[
+                    self.dataset.coordinates.coordinates_index(split)].values
+                data_filter = np.logical_and(left_limit <= data_ind, data_ind < right_limit)
+
+                self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker,
+                                                                             filter=data_filter, color=color)
+                self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False)
+
+        # Display the individual fitted curve
+        self.mean_margin_function_fitted.color = 'lightskyblue'
+        for m in self.margin_function_fitted_all:
+            m.visualize_single_param(gev_value_name, ax, show=False)
+        # Display the mean fitted curve
+        self.mean_margin_function_fitted.color = 'blue'
+        self.mean_margin_function_fitted.visualize_single_param(gev_value_name, ax, show=False)
 
     def score_graph(self, ax, gev_value_name):
         # todo: for the moment only the train/test are interresting (the spatio temporal isn"t working yet)
+
         sns.set_style('whitegrid')
-        s = self.error_dict[gev_value_name]
+        s = self.mean_error_dict[gev_value_name]
         for split in self.dataset.splits:
             ind = self.dataset.coordinates_index(split)
             data = s.loc[ind].values
diff --git a/experiment/fit_diagnosis/split_curve_example.py b/experiment/fit_diagnosis/split_curve_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..790ca10e730cc88b85d9fc4f18c6881f2159234a
--- /dev/null
+++ b/experiment/fit_diagnosis/split_curve_example.py
@@ -0,0 +1,55 @@
+from typing import Union
+
+from experiment.fit_diagnosis.split_curve import SplitCurve
+from extreme_estimator.estimator.full_estimator import AbstractFullEstimator
+from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator, ConstantMarginEstimator
+from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
+
+import random
+
+from experiment.fit_diagnosis.split_curve import SplitCurve
+from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin
+from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \
+    LinearAllParametersAllDimsMarginModel
+from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
+from extreme_estimator.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates
+from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
+
+
+class SplitCurveExample(SplitCurve):
+
+    def __init__(self, nb_fit: int = 1):
+        super().__init__(nb_fit)
+        self.nb_points = 50
+        self.nb_obs = 60
+        self.coordinates = LinSpaceCoordinates.from_nb_points(nb_points=self.nb_points, train_split_ratio=0.8)
+        # MarginModel Linear with respect to the shape (from 0.01 to 0.02)
+        params_sample = {
+            (GevParams.GEV_LOC, 0): 10,
+            (GevParams.GEV_SHAPE, 0): 1.0,
+            (GevParams.GEV_SCALE, 0): 1.0,
+        }
+        self.margin_model = ConstantMarginModel(coordinates=self.coordinates, params_sample=params_sample)
+        self.max_stable_model = Smith()
+
+    def load_dataset(self):
+        return FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs, margin_model=self.margin_model,
+                                                         coordinates=self.coordinates,
+                                                         max_stable_model=self.max_stable_model)
+
+    def load_estimator(self, dataset):
+        max_stable_model = Smith()
+        margin_model_for_estimator = LinearAllParametersAllDimsMarginModel(dataset.coordinates)
+        estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model)
+        # estimator = SmoothMarginEstimator(dataset, margin_model_for_estimator)
+        return estimator
+
+
+
+
+
+if __name__ == '__main__':
+    curve = SplitCurveExample(nb_fit=2)
+    curve.fit()
diff --git a/experiment/return_level_plot/__init__.py b/experiment/return_level_plot/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/experiment/return_level_plot/spatial_2D_plot.py b/experiment/return_level_plot/spatial_2D_plot.py
deleted file mode 100644
index 8347b0e90cf475097612577e6dc9296a8d698542..0000000000000000000000000000000000000000
--- a/experiment/return_level_plot/spatial_2D_plot.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from itertools import product
-from typing import List, Dict
-
-import matplotlib.pyplot as plt
-
-from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
-    AbstractMarginFunction
-from extreme_estimator.gev_params import GevParams
-
-plt.style.use('seaborn-white')
-
-
-class Spatial2DPlot(object):
-
-    def __init__(self, name_to_margin_function: Dict[str, AbstractMarginFunction]):
-        self.name_to_margin_function = name_to_margin_function # type: Dict[str, AbstractMarginFunction]
-        self.grid_columns = GevParams.GEV_PARAM_NAMES
-
-    def plot(self):
-        nb_grid_rows, nb_grid_columns = len(self.name_to_margin_function), len(self.grid_columns)
-        fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row')
-        fig.subplots_adjust(hspace=0.4, wspace=0.4)
-        margin_function: AbstractMarginFunction
-        for i, (name, margin_function) in enumerate(self.name_to_margin_function.items()):
-            for j, param_name in enumerate(self.grid_columns):
-                ax = axes[i, j] if nb_grid_rows > 1 else axes[j]
-                margin_function.visualize_2D(gev_param_name=param_name, ax=ax)
-                ax.set_title("{} for {}".format(param_name, name))
-        fig.suptitle('Spatial2DPlot')
-        plt.show()
diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py
index e4268ddede6f64a3fd22dbcb3bad14dc50ab4e0f..829cd65c7305c51cfa8421ba8392a98eb3b89250 100644
--- a/extreme_estimator/estimator/margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator.py
@@ -22,6 +22,17 @@ class PointWiseMarginEstimator(AbstractMarginEstimator):
     pass
 
 
+class ConstantMarginEstimator(AbstractMarginEstimator):
+
+    def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel):
+        super().__init__(dataset)
+        assert isinstance(margin_model, LinearMarginModel)
+        self.margin_model = margin_model
+
+    def _fit(self):
+        self._margin_function_fitted = self.margin_model.margin_function_start_fit
+
+
 class SmoothMarginEstimator(AbstractMarginEstimator):
     """# with different type of marginals: cosntant, linear...."""
 
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 0d65ba0e0436e619da477d35665a40d676fdf937..1072d4a5d4749f32efffa3b6639fce674221ad13 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -21,6 +21,8 @@ class AbstractMarginFunction(object):
         self.datapoint_display = False
         self.spatio_temporal_split = Split.all
         self.datapoint_marker = 'o'
+        self.color = 'skyblue'
+        self.filter = None
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Main method that maps each coordinate to its GEV parameters"""
@@ -40,21 +42,23 @@ class AbstractMarginFunction(object):
 
     # Visualization function
 
-    def set_datapoint_display_parameters(self, spatio_temporal_split, datapoint_marker):
+    def set_datapoint_display_parameters(self, spatio_temporal_split, datapoint_marker, filter=None, color=None):
         self.datapoint_display = True
         self.spatio_temporal_split = spatio_temporal_split
         self.datapoint_marker = datapoint_marker
+        self.filter = filter
+        self.color = color
 
     def visualize(self, axes=None, show=True, dot_display=False):
         self.datapoint_display = dot_display
         if axes is None:
-            fig, axes = plt.subplots(3, 1, sharex='col', sharey='row')
-            fig.subplots_adjust(hspace=0.4, wspace=0.4, )
+            fig, axes = plt.subplots(1, len(GevParams.GEV_VALUE_NAMES))
+            fig.subplots_adjust(hspace=1.0, wspace=1.0)
         self.visualization_axes = axes
-        for i, gev_param_name in enumerate(GevParams.GEV_PARAM_NAMES):
+        for i, gev_value_name in enumerate(GevParams.GEV_VALUE_NAMES):
             ax = axes[i]
-            self.visualize_single_param(gev_param_name, ax, show=False)
-            title_str = gev_param_name
+            self.visualize_single_param(gev_value_name, ax, show=False)
+            title_str = gev_value_name
             ax.set_title(title_str)
         if show:
             plt.show()
@@ -68,45 +72,32 @@ class AbstractMarginFunction(object):
         else:
             raise NotImplementedError('3D Margin visualization not yet implemented')
 
+    # Visualization 1D
+
     def visualize_1D(self, gev_value_name=GevParams.GEV_LOC, ax=None, show=True):
         x = self.coordinates.x_coordinates
         grid, linspace = self.get_grid_values_1D(x)
         if ax is None:
             ax = plt.gca()
         if self.datapoint_display:
-            ax.plot(linspace, grid[gev_value_name], self.datapoint_marker)
+            ax.plot(linspace, grid[gev_value_name], self.datapoint_marker, color=self.color)
         else:
-            ax.plot(linspace, grid[gev_value_name])
+            ax.plot(linspace, grid[gev_value_name], color=self.color)
         # X axis
-        ax.set_xlabel('coordinate')
+        ax.set_xlabel('coordinate X')
         plt.setp(ax.get_xticklabels(), visible=True)
         ax.xaxis.set_tick_params(labelbottom=True)
-        # Y axis
-        ax.set_ylabel(gev_value_name)
-        plt.setp(ax.get_yticklabels(), visible=True)
-        ax.yaxis.set_tick_params(labelbottom=True)
 
         if show:
             plt.show()
 
-    def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, ax=None, show=True):
-        x = self.coordinates.x_coordinates
-        y = self.coordinates.y_coordinates
-        grid = self.get_grid_2D(x, y)
-        gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
-        if ax is None:
-            ax = plt.gca()
-        imshow_method = ax.imshow
-        imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
-                      interpolation='nearest', cmap=cm.gist_rainbow)
-        # todo: add dot display in 2D
-        if show:
-            plt.show()
-
     def get_grid_values_1D(self, x):
         # TODO: to avoid getting the value several times, I could cache the results
         if self.datapoint_display:
+            # todo: keep only the index of interest here
             linspace = self.coordinates.coordinates_values(self.spatio_temporal_split)[:, 0]
+            if self.filter is not None:
+                linspace = linspace[self.filter]
             resolution = len(linspace)
         else:
             resolution = 100
@@ -119,10 +110,35 @@ class AbstractMarginFunction(object):
         grid = {gev_param: [g[gev_param] for g in grid] for gev_param in GevParams.GEV_VALUE_NAMES}
         return grid, linspace
 
+    # Visualization 2D
+
+    def visualize_2D(self, gev_value_name=GevParams.GEV_LOC, ax=None, show=True):
+        x = self.coordinates.x_coordinates
+        y = self.coordinates.y_coordinates
+        grid = self.get_grid_2D(x, y)
+        if ax is None:
+            ax = plt.gca()
+        imshow_method = ax.imshow
+        imshow_method(grid[gev_value_name], extent=(x.min(), x.max(), y.min(), y.max()),
+                      interpolation='nearest', cmap=cm.gist_rainbow)
+        # X axis
+        ax.set_xlabel('coordinate X')
+        plt.setp(ax.get_xticklabels(), visible=True)
+        ax.xaxis.set_tick_params(labelbottom=True)
+        # Y axis
+        ax.set_ylabel('coordinate Y')
+        plt.setp(ax.get_yticklabels(), visible=True)
+        ax.yaxis.set_tick_params(labelbottom=True)
+        # todo: add dot display in 2D
+        if show:
+            plt.show()
+
     def get_grid_2D(self, x, y):
         resolution = 100
-        grid = np.zeros([resolution, resolution, 3])
+        grid = []
         for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
             for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)):
-                grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array()
+                grid.append(self.get_gev_params(np.array([xi, yj])).value_dict)
+        grid = {value_name: np.array([g[value_name] for g in grid]).reshape([resolution, resolution])
+                for value_name in GevParams.GEV_VALUE_NAMES}
         return grid
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/combined_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/combined_margin_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d149cb78fea9822f2c0826a8a75f06409dc180d
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/combined_margin_function.py
@@ -0,0 +1,30 @@
+from typing import List
+
+import numpy as np
+
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
+from extreme_estimator.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+
+
+class CombinedMarginFunction(AbstractMarginFunction):
+
+    def __init__(self, coordinates: AbstractCoordinates, margin_functions: List[AbstractMarginFunction]):
+        super().__init__(coordinates)
+        self.margin_functions = margin_functions  # type: List[AbstractMarginFunction]
+
+    def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
+        gev_params_list = [margin_function.get_gev_params(coordinate) for margin_function in self.margin_functions]
+        mean_gev_params = np.mean(np.array([gev_param.to_array() for gev_param in gev_params_list]), axis=0)
+        gev_param = GevParams(*mean_gev_params)
+        return gev_param
+
+    @classmethod
+    def from_margin_functions(cls, margin_functions: List[AbstractMarginFunction]):
+        assert len(margin_functions) > 0
+        assert all([isinstance(margin_function, AbstractMarginFunction) for margin_function in margin_functions])
+        common_coordinates = set([margin_function.coordinates for margin_function in margin_functions])
+        assert len(common_coordinates) == 1
+        coordinates = common_coordinates.pop()
+        return cls(coordinates, margin_functions)
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/utils.py b/extreme_estimator/extreme_models/margin_model/margin_function/utils.py
index 2ea29cc1f5368de65e1deb493ed6ba7162a378d0..4f044b987f4cb0778dfaa87d86e05f95949332c4 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/utils.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/utils.py
@@ -20,10 +20,8 @@ def error_dict_between_margin_functions(reference: AbstractMarginFunction, fitte
     assert reference.coordinates == fitted.coordinates
     reference_values = reference.gev_value_name_to_serie
     fitted_values = fitted.gev_value_name_to_serie
-
     gev_param_name_to_error_serie = {}
-    for value_name in GevParams.GEV_VALUE_NAMES:
-        print(value_name)
-        error = relative_abs_error(reference_values[value_name], fitted_values[value_name])
-        gev_param_name_to_error_serie[value_name] = error
+    for gev_value_name in GevParams.GEV_VALUE_NAMES:
+        error = relative_abs_error(reference_values[gev_value_name], fitted_values[gev_value_name])
+        gev_param_name_to_error_serie[gev_value_name] = error
     return gev_param_name_to_error_serie
diff --git a/test/test_experiment/test_split_curve.py b/test/test_experiment/test_split_curve.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b9b1d69093ec2ff3269f51008e50bb868ccd8ac
--- /dev/null
+++ b/test/test_experiment/test_split_curve.py
@@ -0,0 +1,56 @@
+import unittest
+from typing import Union
+
+from experiment.fit_diagnosis.split_curve import SplitCurve
+from extreme_estimator.estimator.full_estimator import AbstractFullEstimator
+from extreme_estimator.estimator.margin_estimator import AbstractMarginEstimator, ConstantMarginEstimator
+from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
+
+import random
+
+from experiment.fit_diagnosis.split_curve import SplitCurve
+from extreme_estimator.estimator.full_estimator import FullEstimatorInASingleStepWithSmoothMargin
+from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \
+    LinearAllParametersAllDimsMarginModel
+from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
+from extreme_estimator.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates
+from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
+
+
+class TestSplitCurve(unittest.TestCase):
+    DISPLAY = False
+
+    class SplitCurveFastForTest(SplitCurve):
+
+        def __init__(self, nb_fit: int = 1):
+            super().__init__(nb_fit)
+            self.nb_points = 50
+            self.nb_obs = 60
+            self.coordinates = LinSpaceCoordinates.from_nb_points(nb_points=self.nb_points, train_split_ratio=0.8)
+            # MarginModel Linear with respect to the shape (from 0.01 to 0.02)
+            params_sample = {
+                (GevParams.GEV_LOC, 0): 10,
+                (GevParams.GEV_SHAPE, 0): 1.0,
+                (GevParams.GEV_SCALE, 0): 1.0,
+            }
+            self.margin_model = ConstantMarginModel(coordinates=self.coordinates, params_sample=params_sample)
+            self.max_stable_model = Smith()
+
+        def load_dataset(self):
+            return FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs, margin_model=self.margin_model,
+                                                             coordinates=self.coordinates,
+                                                             max_stable_model=self.max_stable_model)
+
+        def load_estimator(self, dataset):
+            # todo: create a test from that example
+            return ConstantMarginEstimator(dataset, LinearAllParametersAllDimsMarginModel(dataset.coordinates))
+
+    def test_split_curve(self):
+        s = self.SplitCurveFastForTest(nb_fit=2)
+        s.fit(show=self.DISPLAY)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
index cd7b0a9e1d3c91da72e25e4f54d20abd28d6a60d..f9a51a737613b4fe6bd4855cf9ba475b9c3b9b2c 100644
--- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
@@ -1,7 +1,6 @@
 import unittest
 
 from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
-from experiment.return_level_plot.spatial_2D_plot import Spatial2DPlot
 from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
 from test.test_utils import load_smooth_margin_models, load_test_1D_and_2D_coordinates
 
@@ -17,21 +16,15 @@ class TestSmoothMarginEstimator(unittest.TestCase):
     def test_dependency_estimators(self):
         for coordinates in self.coordinates:
             smooth_margin_models = load_smooth_margin_models(coordinates=coordinates)
-            for margin_model in smooth_margin_models:
+            for margin_model in smooth_margin_models[1:]:
                 dataset = MarginDataset.from_sampling(nb_obs=10,
                                                       margin_model=margin_model,
                                                       coordinates=coordinates)
                 # Fit estimator
                 estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model)
                 estimator.fit()
-                # Map name to their margin functions
-                name_to_margin_function = {
-                    'Ground truth margin function': dataset.margin_model.margin_function_sample,
-                    'Estimated margin function': estimator.margin_function_fitted,
-                }
-                # Spatial Plot
-                if self.DISPLAY:
-                    Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot()
+                # Plot
+                margin_model.margin_function_sample.visualize(show=self.DISPLAY)
         self.assertTrue(True)