diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index ac58f74a3678e9097c1ebcee9c4e33ba68d16c6d..7d5bef88046cb14538b19345682d502e6bf4acc5 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -1,7 +1,7 @@
 import time
 
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
-from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit
+from spatio_temporal_dataset.slicer.split import Split
 
 
 class AbstractEstimator(object):
@@ -15,7 +15,10 @@ class AbstractEstimator(object):
     def __init__(self, dataset: AbstractDataset):
         self.dataset = dataset  # type: AbstractDataset
         self.additional_information = dict()
-        self.train_split = SpatialTemporalSplit.train
+
+    @property
+    def train_split(self):
+        return self.dataset.train_split
 
     def fit(self):
         ts = time.time()
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 3a19e9cb959a6a43de97b2a02653ad0b64facd42..dd12340e387fae033de1b8ab02bfad4ab905c703 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -1,10 +1,11 @@
 import matplotlib.cm as cm
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 
 from extreme_estimator.gev_params import GevParams
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
-from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit
+from spatio_temporal_dataset.slicer.split import Split
 
 
 class AbstractMarginFunction(object):
@@ -16,13 +17,24 @@ class AbstractMarginFunction(object):
         # Visualization parameters
         self.visualization_axes = None
         self.datapoint_display = False
-        self.spatio_temporal_split = SpatialTemporalSplit.all
+        self.spatio_temporal_split = Split.all
         self.datapoint_marker = 'o'
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Main method that maps each coordinate to its GEV parameters"""
         pass
 
+    # Extraction function
+
+    @property
+    def gev_params_for_coordinates(self):
+        gev_params = [self.get_gev_params(coordinate).to_dict() for coordinate in self.coordinates.coordinates_values()]
+        gev_param_name_to_serie = {}
+        for gev_param_name in GevParams.GEV_PARAM_NAMES:
+            s = pd.Series(data=[p[gev_param_name] for p in gev_params], index=self.coordinates.index)
+            gev_param_name_to_serie[gev_param_name] = s
+        return gev_param_name_to_serie
+
     # Visualization function
 
     def set_datapoint_display_parameters(self, spatio_temporal_split, datapoint_marker):
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/utils.py b/extreme_estimator/extreme_models/margin_model/margin_function/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a653f012cb41996f5d80299aed3d3a0694ef605
--- /dev/null
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/utils.py
@@ -0,0 +1,20 @@
+import numpy as np
+
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
+from extreme_estimator.gev_params import GevParams
+
+
+def abs_error(s1, s2):
+    return (s1 - s2).abs().pow(2)
+
+
+def error_dict_between_margin_functions(margin1: AbstractMarginFunction, margin2: AbstractMarginFunction):
+    assert margin1.coordinates == margin2.coordinates
+    margin1_gev_params, margin2_gev_params = margin1.gev_params_for_coordinates, margin2.gev_params_for_coordinates
+    gev_param_name_to_error_serie = {}
+    for gev_param_name in GevParams.GEV_PARAM_NAMES:
+        serie1, serie2 = margin1_gev_params[gev_param_name], margin2_gev_params[gev_param_name]
+        error = abs_error(serie1, serie2)
+        gev_param_name_to_error_serie[gev_param_name] = error
+    return gev_param_name_to_error_serie
diff --git a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
index 41303c9e5c9ed8d81e7c918887a73bb843c0f5ae..5751dba1c828b95800e4812ab914f3981d4facc8 100644
--- a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
+++ b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
@@ -34,7 +34,7 @@ class AbstractMaxStableModel(AbstractModel):
         maxima = maxima_gev if fit_marge else maxima_frech
         assert isinstance(maxima, np.ndarray)
         assert len(df_coordinates) == len(maxima), 'Coordinates and observations sizes should match,' \
-                                                   'check that the same split was used for both objects \n,' \
+                                                   'check that the same split was used for both objects, \n' \
                                                    'df_coordinates size: {}, data size {}'.format(len(df_coordinates),
                                                                                                   len(maxima))
         data = np.transpose(maxima)
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index f32b25559b1d86f41b7e786776d75b6998d6d8df..8d7e3511a4d1486c6a426c77fcdfb30767d8fded 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -6,8 +6,11 @@ import numpy as np
 import pandas as pd
 from mpl_toolkits.mplot3d import Axes3D
 
-from spatio_temporal_dataset.spatio_temporal_split import s_split_from_ratio, TEST_SPLIT_STR, \
-    TRAIN_SPLIT_STR, train_ind_from_s_split, SpatialTemporalSplit
+from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer
+from spatio_temporal_dataset.slicer.spatio_temporal_slicer import SpatioTemporalSlicer
+from spatio_temporal_dataset.slicer.split import s_split_from_ratio, TEST_SPLIT_STR, \
+    TRAIN_SPLIT_STR, train_ind_from_s_split, Split
+from spatio_temporal_dataset.slicer.temporal_slicer import TemporalSlicer
 
 
 class AbstractCoordinates(object):
@@ -31,7 +34,7 @@ class AbstractCoordinates(object):
         # Create a split based on the train_split_ratio
         if train_split_ratio is not None:
             assert cls.COORDINATE_SPLIT not in df.columns, "A split has already been defined"
-            s_split = s_split_from_ratio(length=len(df), train_split_ratio=train_split_ratio)
+            s_split = s_split_from_ratio(index=df.index, train_split_ratio=train_split_ratio)
             df[cls.COORDINATE_SPLIT] = s_split
         # Potentially, a split column can be specified directly in df
         if cls.COORDINATE_SPLIT not in df.columns:
@@ -91,15 +94,21 @@ class AbstractCoordinates(object):
         # Merged DataFrame of df_coord and s_split
         return self.df_coord if self.s_split is None else self.df_coord.join(self.s_split)
 
-    def df_coordinates(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> pd.DataFrame:
-        if split is SpatialTemporalSplit.all or self.s_split is None:
+    def df_coordinates(self, split: Split = Split.all) -> pd.DataFrame:
+        if self.train_ind is None:
             return self.df_coord
-        elif split in [SpatialTemporalSplit.train, SpatialTemporalSplit.test_temporal]:
+        if split is Split.all:
+            return self.df_coord
+        if split in [Split.train_temporal, Split.test_temporal]:
+            return self.df_coord
+        elif split in [Split.train_spatial, Split.train_spatiotemporal, Split.test_spatiotemporal_temporal]:
             return self.df_coord.loc[self.train_ind]
-        else:
+        elif split in [Split.test_spatial, Split.test_spatiotemporal, Split.test_spatiotemporal_spatial]:
             return self.df_coord.loc[~self.train_ind]
+        else:
+            raise NotImplementedError('Unknown split: {}'.format(split))
 
-    def coordinates_values(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
+    def coordinates_values(self, split: Split = Split.all) -> np.ndarray:
         return self.df_coordinates(split).values
 
     @property
diff --git a/spatio_temporal_dataset/dataset/abstract_dataset.py b/spatio_temporal_dataset/dataset/abstract_dataset.py
index 6cbf1ae0eb4bd9ea6260a415b59420ad76e59b89..4c7a78865855abd391520018a02c5b4a25be34e9 100644
--- a/spatio_temporal_dataset/dataset/abstract_dataset.py
+++ b/spatio_temporal_dataset/dataset/abstract_dataset.py
@@ -1,22 +1,29 @@
 import os
-import numpy as np
 import os.path as op
+from typing import List
+
+import numpy as np
 import pandas as pd
 
-from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer
-from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.slicer.abstract_slicer import AbstractSlicer
+from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer
+from spatio_temporal_dataset.slicer.split import Split
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
+    AbstractSpatioTemporalObservations
 
 
 class AbstractDataset(object):
 
-    def __init__(self, observations: AbstractSpatioTemporalObservations, coordinates: AbstractCoordinates):
-        # is_same_index = spatio_temporal_observations.index == coordinates.index  # type: pd.Series
-        # assert is_same_index.all()
+    def __init__(self, observations: AbstractSpatioTemporalObservations, coordinates: AbstractCoordinates,
+                 slicer_class: type = SpatialSlicer):
+        assert pd.Index.equals(observations.index, coordinates.index)
+        assert isinstance(slicer_class, type)
         self.observations = observations
         self.coordinates = coordinates
-        self.spatio_temporal_slicer = SpatioTemporalSlicer(coordinates_train_ind=self.coordinates.train_ind,
-                                                           observations_train_ind=self.observations.train_ind)
+        self.slicer = slicer_class(coordinates_train_ind=self.coordinates.train_ind,
+                                   observations_train_ind=self.observations.train_ind)  # type: AbstractSlicer
+        assert isinstance(self.slicer, AbstractSlicer)
 
     @classmethod
     def from_csv(cls, csv_path: str):
@@ -38,18 +45,36 @@ class AbstractDataset(object):
         # todo: maybe I should add the split from the temporal observations
         return self.observations.df_maxima_gev.join(self.coordinates.df_merged)
 
-    def df_coordinates(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
+    def df_coordinates(self, split: Split = Split.all) -> pd.DataFrame:
         return self.coordinates.df_coordinates(split=split)
 
+    # Observation wrapper
+
+    def maxima_gev(self, split: Split = Split.all) -> np.ndarray:
+        return self.observations.maxima_gev(split, self.slicer)
+
+    def maxima_frech(self, split: Split = Split.all) -> np.ndarray:
+        return self.observations.maxima_frech(split, self.slicer)
+
+    def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: Split = Split.all):
+        self.observations.set_maxima_frech(maxima_frech_values, split, self.slicer)
+
+    # Coordinates wrapper
+
     @property
-    def coordinates_values(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
+    def coordinates_values(self, split: Split = Split.all) -> np.ndarray:
         return self.coordinates.coordinates_values(split=split)
 
-    def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
-        return self.observations.maxima_gev(split, self.spatio_temporal_slicer)
+    # Slicer wrapper
+
+    @property
+    def train_split(self) -> Split:
+        return self.slicer.train_split
 
-    def maxima_frech(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
-        return self.observations.maxima_frech(split, self.spatio_temporal_slicer)
+    @property
+    def test_split(self) -> Split:
+        return self.slicer.test_split
 
-    def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
-        self.observations.set_maxima_frech(maxima_frech_values, split, self.spatio_temporal_slicer)
\ No newline at end of file
+    @property
+    def splits(self) -> List[Split]:
+        return self.slicer.splits
diff --git a/spatio_temporal_dataset/dataset/simulation_dataset.py b/spatio_temporal_dataset/dataset/simulation_dataset.py
index 780d42b1b85eb9b717a3669a6487eb462c6c6289..11e8ed5657997858015ebe78b2875ccdf2390fcb 100644
--- a/spatio_temporal_dataset/dataset/simulation_dataset.py
+++ b/spatio_temporal_dataset/dataset/simulation_dataset.py
@@ -1,10 +1,12 @@
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
-from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
-from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
+from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
+    AbstractSpatioTemporalObservations
 from spatio_temporal_dataset.spatio_temporal_observations.annual_maxima_observations import \
-    MaxStableAnnualMaxima, AnnualMaxima, MarginAnnualMaxima, FullAnnualMaxima
+    MaxStableAnnualMaxima, MarginAnnualMaxima, FullAnnualMaxima
 
 
 class SimulatedDataset(AbstractDataset):
@@ -15,30 +17,33 @@ class SimulatedDataset(AbstractDataset):
 
     def __init__(self, observations: AbstractSpatioTemporalObservations,
                  coordinates: AbstractCoordinates,
+                 slicer_class: type = SpatialSlicer,
                  max_stable_model: AbstractMaxStableModel = None,
                  margin_model: AbstractMarginModel = None):
-        super().__init__(observations, coordinates)
+        super().__init__(observations, coordinates, slicer_class)
         assert margin_model is not None or max_stable_model is not None
         self.margin_model = margin_model  # type: AbstractMarginModel
-        self.max_stable_model = max_stable_model  #  type: AbstractMaxStableModel
+        self.max_stable_model = max_stable_model  # type: AbstractMaxStableModel
 
 
 class MaxStableDataset(SimulatedDataset):
 
     @classmethod
     def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel, coordinates: AbstractCoordinates,
-                      train_split_ratio: float = None):
+                      train_split_ratio: float = None, slicer_class: type = SpatialSlicer):
         observations = MaxStableAnnualMaxima.from_sampling(nb_obs, max_stable_model, coordinates, train_split_ratio)
-        return cls(observations=observations, coordinates=coordinates, max_stable_model=max_stable_model)
+        return cls(observations=observations, coordinates=coordinates, slicer_class=slicer_class,
+                   max_stable_model=max_stable_model)
 
 
 class MarginDataset(SimulatedDataset):
 
     @classmethod
     def from_sampling(cls, nb_obs: int, margin_model: AbstractMarginModel, coordinates: AbstractCoordinates,
-                      train_split_ratio: float = None):
+                      train_split_ratio: float = None, slicer_class: type = SpatialSlicer):
         observations = MarginAnnualMaxima.from_sampling(nb_obs, coordinates, margin_model, train_split_ratio)
-        return cls(observations=observations, coordinates=coordinates, margin_model=margin_model)
+        return cls(observations=observations, coordinates=coordinates, slicer_class=slicer_class,
+                   margin_model=margin_model)
 
 
 class FullSimulatedDataset(SimulatedDataset):
@@ -47,8 +52,9 @@ class FullSimulatedDataset(SimulatedDataset):
     def from_double_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
                              coordinates: AbstractCoordinates,
                              margin_model: AbstractMarginModel,
-                             train_split_ratio: float = None):
+                             train_split_ratio: float = None,
+                             slicer_class: type = SpatialSlicer):
         observations = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
                                                              coordinates, margin_model, train_split_ratio)
-        return cls(observations=observations, coordinates=coordinates, max_stable_model=max_stable_model,
-                   margin_model=margin_model)
+        return cls(observations=observations, coordinates=coordinates, slicer_class=slicer_class,
+                   max_stable_model=max_stable_model, margin_model=margin_model)
diff --git a/spatio_temporal_dataset/slicer/__init__.py b/spatio_temporal_dataset/slicer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/spatio_temporal_dataset/slicer/abstract_slicer.py b/spatio_temporal_dataset/slicer/abstract_slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45df076927b1b6fa2d71958e24cee383fc0a03d
--- /dev/null
+++ b/spatio_temporal_dataset/slicer/abstract_slicer.py
@@ -0,0 +1,78 @@
+from typing import Union, List
+
+import pandas as pd
+
+from spatio_temporal_dataset.slicer.split import Split
+
+
+class AbstractSlicer(object):
+
+    def __init__(self, coordinates_train_ind: Union[None, pd.Series], observations_train_ind: Union[None, pd.Series]):
+        self.index_train_ind = coordinates_train_ind  # type: Union[None, pd.Series]
+        self.column_train_ind = observations_train_ind  # type: Union[None, pd.Series]
+
+    @property
+    def train_split(self) -> Split:
+        pass
+
+    @property
+    def test_split(self) -> Split:
+        pass
+
+    @property
+    def splits(self) -> List[Split]:
+        pass
+
+
+    @property
+    def index_test_ind(self) -> pd.Series:
+        return ~self.index_train_ind
+
+    # todo: test should be the same as train when we don't care about that in the split
+    @property
+    def column_test_ind(self) -> pd.Series:
+        return ~self.column_train_ind
+
+    @property
+    def some_required_ind_are_not_defined(self):
+        pass
+
+    def summary(self):
+        print('Slicer summary: \n')
+        for s, global_name in [(self.index_train_ind, "Spatial"), (self.column_train_ind, "Temporal")]:
+            print(global_name + ' split')
+            if s is None:
+                print('Not handled by this slicer')
+            else:
+                for f, name in [(len, 'Total'), (sum, 'train')]:
+                    print("{}: {}".format(name, f(s)))
+                print('\n')
+
+    def loc_split(self, df: pd.DataFrame, split: Split):
+        # split should belong to the list of split accepted by the slicer
+        assert isinstance(split, Split)
+
+        if split is Split.all:
+            return df
+
+        assert split in self.splits, "split:{}, slicer_type:{}".format(split, type(self))
+
+        # By default, some required splits are not defined
+        # instead of crashing, we return all the data for all the split
+        # This is the default behavior, when the required splits has been defined
+        if self.some_required_ind_are_not_defined:
+            return df
+        else:
+            return self.specialized_loc_split(df=df, split=split)
+
+    def specialized_loc_split(self, df: pd.DataFrame, split: Split):
+        # This method should be defined in the child class
+        return None
+
+
+def slice(df: pd.DataFrame, split: Split = Split.all, slicer: AbstractSlicer = None) -> pd.DataFrame:
+    if slicer is None:
+        assert split is Split.all
+        return df
+    else:
+        return slicer.loc_split(df, split)
diff --git a/spatio_temporal_dataset/slicer/spatial_slicer.py b/spatio_temporal_dataset/slicer/spatial_slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c22a9a674d5f02137150fd010fe588dc7a31d4
--- /dev/null
+++ b/spatio_temporal_dataset/slicer/spatial_slicer.py
@@ -0,0 +1,36 @@
+from typing import List, Union
+
+import pandas as pd
+
+from spatio_temporal_dataset.slicer.abstract_slicer import AbstractSlicer
+from spatio_temporal_dataset.slicer.split import Split
+
+
+class SpatialSlicer(AbstractSlicer):
+    SPLITS = [Split.train_spatial, Split.test_spatial]
+
+    def __init__(self, coordinates_train_ind: Union[None, pd.Series], observations_train_ind: Union[None, pd.Series]):
+        super().__init__(coordinates_train_ind, None)
+
+    @property
+    def splits(self) -> List[Split]:
+        return self.SPLITS
+
+    @property
+    def train_split(self) -> Split:
+        return Split.train_spatial
+
+    @property
+    def test_split(self) -> Split:
+        return Split.test_spatial
+
+    @property
+    def some_required_ind_are_not_defined(self):
+        return self.index_train_ind is None
+
+    def specialized_loc_split(self, df: pd.DataFrame, split: Split):
+        assert pd.Index.equals(df.index, self.index_train_ind.index)
+        if split is Split.train_spatial:
+            return df.loc[self.index_train_ind, :]
+        elif split is Split.test_spatial:
+            return df.loc[self.index_test_ind, :]
diff --git a/spatio_temporal_dataset/slicer/spatio_temporal_slicer.py b/spatio_temporal_dataset/slicer/spatio_temporal_slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..47a71610020d763d057b01a4780b4fa541eb2cfb
--- /dev/null
+++ b/spatio_temporal_dataset/slicer/spatio_temporal_slicer.py
@@ -0,0 +1,41 @@
+from typing import List
+
+import pandas as pd
+
+from spatio_temporal_dataset.slicer.abstract_slicer import AbstractSlicer
+from spatio_temporal_dataset.slicer.split import Split
+
+
+class SpatioTemporalSlicer(AbstractSlicer):
+    SPLITS = [Split.train_spatiotemporal,
+                Split.test_spatiotemporal,
+                Split.test_spatiotemporal_spatial,
+                Split.test_spatiotemporal_temporal]
+
+    @property
+    def splits(self) -> List[Split]:
+        return self.SPLITS
+
+    @property
+    def train_split(self) -> Split:
+        return Split.train_spatiotemporal
+
+    @property
+    def test_split(self) -> Split:
+        return Split.test_spatiotemporal
+
+    @property
+    def some_required_ind_are_not_defined(self):
+        return self.index_train_ind is None or self.column_train_ind is None
+
+    def specialized_loc_split(self, df: pd.DataFrame, split: Split):
+        assert pd.Index.equals(df.columns, self.column_train_ind.index)
+        assert pd.Index.equals(df.index, self.index_train_ind.index)
+        if split is Split.train_spatiotemporal:
+            return df.loc[self.index_train_ind, self.column_train_ind]
+        elif split is Split.test_spatiotemporal:
+            return df.loc[self.index_test_ind, self.column_test_ind]
+        elif split is Split.test_spatiotemporal_spatial:
+            return df.loc[self.index_test_ind, self.column_train_ind]
+        elif split is Split.test_spatiotemporal_temporal:
+            return df.loc[self.index_train_ind, self.column_test_ind]
diff --git a/spatio_temporal_dataset/slicer/split.py b/spatio_temporal_dataset/slicer/split.py
new file mode 100644
index 0000000000000000000000000000000000000000..9771bbcaf3b0254e4e54092dd8aee6ea08a1d826
--- /dev/null
+++ b/spatio_temporal_dataset/slicer/split.py
@@ -0,0 +1,44 @@
+from enum import Enum
+
+import pandas as pd
+
+
+class Split(Enum):
+    all = 0
+    # SpatioTemporal splits
+    train_spatiotemporal = 1
+    test_spatiotemporal = 2
+    test_spatiotemporal_spatial = 3
+    test_spatiotemporal_temporal = 4
+    # Spatial splits
+    train_spatial = 5
+    test_spatial = 6
+    # Temporal splits
+    train_temporal = 7
+    test_temporal = 8
+
+
+ALL_SPLITS_EXCEPT_ALL = [split for split in Split if split is not Split.all]
+
+SPLIT_NAME = 'split'
+TRAIN_SPLIT_STR = 'train_split'
+TEST_SPLIT_STR = 'test_split'
+
+
+def train_ind_from_s_split(s_split):
+    if s_split is None:
+        return None
+    else:
+        return s_split.isin([TRAIN_SPLIT_STR])
+
+
+def s_split_from_ratio(index, train_split_ratio):
+    length = len(index)
+    assert 0 < train_split_ratio < 1
+    s = pd.Series(TEST_SPLIT_STR, index=index)
+    nb_points_train = int(length * train_split_ratio)
+    assert 0 < nb_points_train < length
+    train_ind = pd.Series.sample(s, n=nb_points_train).index
+    assert 0 < len(train_ind) < length, "number of training points:{} length:{}".format(len(train_ind), length)
+    s.loc[train_ind] = TRAIN_SPLIT_STR
+    return s
diff --git a/spatio_temporal_dataset/slicer/temporal_slicer.py b/spatio_temporal_dataset/slicer/temporal_slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bfd6d01f8e9e897863e1f038a25d412fdd2fb85
--- /dev/null
+++ b/spatio_temporal_dataset/slicer/temporal_slicer.py
@@ -0,0 +1,36 @@
+from typing import List, Union
+
+import pandas as pd
+
+from spatio_temporal_dataset.slicer.abstract_slicer import AbstractSlicer
+from spatio_temporal_dataset.slicer.split import Split
+
+
+class TemporalSlicer(AbstractSlicer):
+    SPLITS = [Split.train_temporal, Split.test_temporal]
+
+    def __init__(self, coordinates_train_ind: Union[None, pd.Series], observations_train_ind: Union[None, pd.Series]):
+        super().__init__(None, observations_train_ind)
+
+    @property
+    def splits(self) -> List[Split]:
+        return self.SPLITS
+
+    @property
+    def train_split(self) -> Split:
+        return Split.train_temporal
+
+    @property
+    def test_split(self) -> Split:
+        return Split.test_temporal
+
+    @property
+    def some_required_ind_are_not_defined(self):
+        return self.column_train_ind is None
+
+    def specialized_loc_split(self, df: pd.DataFrame, split: Split):
+        assert pd.Index.equals(df.columns, self.column_train_ind.index)
+        if split is Split.train_temporal:
+            return df.loc[:, self.column_train_ind]
+        elif split is Split.test_temporal:
+            return df.loc[:, self.column_test_ind]
diff --git a/spatio_temporal_dataset/slicer/utils.py b/spatio_temporal_dataset/slicer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py b/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
index 1393167781b1811869d5774126d1edbaf04c9e3b..40e0509926359e044078e75000f9e15ad25bf5ed 100644
--- a/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
+++ b/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
@@ -1,8 +1,9 @@
 import pandas as pd
 import numpy as np
 
-from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer, \
-    train_ind_from_s_split, TEST_SPLIT_STR, TRAIN_SPLIT_STR, s_split_from_ratio, spatio_temporal_slice
+from spatio_temporal_dataset.slicer.abstract_slicer import slice, AbstractSlicer
+from spatio_temporal_dataset.slicer.split import Split, \
+    train_ind_from_s_split, TEST_SPLIT_STR, TRAIN_SPLIT_STR, s_split_from_ratio
 
 
 class AbstractSpatioTemporalObservations(object):
@@ -22,30 +23,39 @@ class AbstractSpatioTemporalObservations(object):
             raise AttributeError('A split is already defined, there is no need to specify a ratio')
         elif s_split is not None or train_split_ratio is not None:
             if train_split_ratio:
-                s_split = s_split_from_ratio(length=self.nb_obs, train_split_ratio=train_split_ratio)
+                s_split = s_split_from_ratio(index=self._df_maxima.columns, train_split_ratio=train_split_ratio)
+            assert len(s_split) == len(self._df_maxima.columns)
             assert s_split.isin([TRAIN_SPLIT_STR, TEST_SPLIT_STR]).all()
         self.s_split = s_split
 
     @property
-    def nb_obs(self):
+    def _df_maxima(self) -> pd.DataFrame:
         if self.df_maxima_frech is not None:
-            return len(self.df_maxima_frech.columns)
+            return self.df_maxima_frech
         else:
-            return len(self.df_maxima_gev.columns)
+            return self.df_maxima_gev
+
+    @property
+    def index(self) -> pd.Index:
+        return self._df_maxima.index
+
+    @property
+    def nb_obs(self) -> int:
+        return len(self._df_maxima.columns)
 
     @classmethod
     def from_df(cls, df):
         pass
 
-    def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all, slicer: SpatioTemporalSlicer = None):
-        return spatio_temporal_slice(self.df_maxima_gev, split, slicer).values
+    def maxima_gev(self, split: Split = Split.all, slicer: AbstractSlicer = None) -> np.ndarray:
+        return slice(self.df_maxima_gev, split, slicer).values
 
-    def maxima_frech(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all, slicer: SpatioTemporalSlicer = None):
-        return spatio_temporal_slice(self.df_maxima_frech, split, slicer).values
+    def maxima_frech(self, split: Split = Split.all, slicer: AbstractSlicer = None) -> np.ndarray:
+        return slice(self.df_maxima_frech, split, slicer).values
 
-    def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
-                         slicer: SpatioTemporalSlicer = None):
-        df = spatio_temporal_slice(self.df_maxima_frech, split, slicer)
+    def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: Split = Split.all,
+                         slicer: AbstractSlicer = None):
+        df = slice(self.df_maxima_frech, split, slicer)
         df.loc[:] = maxima_frech_values
 
     @property
diff --git a/spatio_temporal_dataset/spatio_temporal_split.py b/spatio_temporal_dataset/spatio_temporal_split.py
deleted file mode 100644
index 3037b749cb72c6aefcb4004b853923ab7da75aae..0000000000000000000000000000000000000000
--- a/spatio_temporal_dataset/spatio_temporal_split.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from enum import Enum
-
-import pandas as pd
-
-
-class SpatialTemporalSplit(Enum):
-    all = 0
-    train = 1
-    test = 2
-    test_temporal = 3
-    test_spatial = 4
-
-
-class SpatioTemporalSlicer(object):
-
-    def __init__(self, coordinates_train_ind: pd.Series, observations_train_ind: pd.Series):
-        self.index_train_ind = coordinates_train_ind  # type: pd.Series
-        self.column_train_ind = observations_train_ind  # type: pd.Series
-        if self.ind_are_not_defined:
-            msg = "One split was not defined \n \n" \
-                  "index: \n {}  \n, column:\n {} \n".format(self.index_train_ind, self.column_train_ind)
-            assert self.index_train_ind is None and self.column_train_ind is None, msg
-
-    def summary(self):
-        print('SpatioTemporal split summary: \n')
-        for s, global_name in [(self.index_train_ind, "Spatial"), (self.column_train_ind, "Temporal")]:
-            print(global_name + ' split')
-            for f, name in [(len, 'Total'), (sum, 'train')]:
-                print("{}: {}".format(name, f(s)))
-            print('\n')
-
-    @property
-    def index_test_ind(self) -> pd.Series:
-        return ~self.index_train_ind
-
-    @property
-    def column_test_ind(self) -> pd.Series:
-        return ~self.column_train_ind
-
-    @property
-    def ind_are_not_defined(self):
-        return self.index_train_ind is None or self.column_train_ind is None
-
-    def loc_split(self, df: pd.DataFrame, split: SpatialTemporalSplit):
-        assert isinstance(split, SpatialTemporalSplit)
-        # By default, if one of the two split is not defined we return all the data
-        if self.ind_are_not_defined or split is SpatialTemporalSplit.all:
-            return df
-        assert pd.RangeIndex.equals(df.columns, self.column_train_ind.index)
-        assert pd.RangeIndex.equals(df.index, self.index_train_ind.index)
-        if split is SpatialTemporalSplit.train:
-            return df.loc[self.index_train_ind, self.column_train_ind]
-        elif split is SpatialTemporalSplit.test:
-            return df.loc[self.index_test_ind, self.column_test_ind]
-        elif split is SpatialTemporalSplit.test_spatial:
-            return df.loc[self.index_test_ind, self.column_train_ind]
-        elif split is SpatialTemporalSplit.test_temporal:
-            return df.loc[self.index_train_ind, self.column_test_ind]
-
-
-SPLIT_NAME = 'split'
-TRAIN_SPLIT_STR = 'train_split'
-TEST_SPLIT_STR = 'test_split'
-
-
-def train_ind_from_s_split(s_split):
-    if s_split is None:
-        return None
-    else:
-        return s_split.isin([TRAIN_SPLIT_STR])
-
-
-def s_split_from_ratio(length, train_split_ratio):
-    assert 0 < train_split_ratio < 1
-    s = pd.Series([TEST_SPLIT_STR for _ in range(length)])
-    nb_points_train = int(length * train_split_ratio)
-    train_ind = pd.Series.sample(s, n=nb_points_train).index
-    s.loc[train_ind] = TRAIN_SPLIT_STR
-    return s
-
-
-def spatio_temporal_slice(df: pd.DataFrame, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
-                          slicer: SpatioTemporalSlicer = None) -> pd.DataFrame:
-    if slicer is None:
-        assert split is SpatialTemporalSplit.all
-        return df
-    else:
-        return slicer.loc_split(df, split)
diff --git a/test/test_extreme_estimator/test_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
index 0c8aacbf41dd9a8482aec5b945b8b14baefd54d1..91b8c39fb052bed6916d3236f49252719987ce87 100644
--- a/test/test_extreme_estimator/test_estimator/test_full_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
@@ -8,8 +8,8 @@ from test.test_utils import load_test_max_stable_models, load_smooth_margin_mode
 
 class TestFullEstimators(unittest.TestCase):
     DISPLAY = False
-    nb_obs = 10
-    nb_points = 5
+    nb_obs = 3
+    nb_points = 2
 
     def setUp(self):
         super().setUp()
diff --git a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
index aaea5ab300e0715f931a6242bdea0003d97333fc..b464f80ba1c815ced3a273b3e8e5c69cab740d90 100644
--- a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
@@ -11,8 +11,8 @@ from test.test_utils import load_test_max_stable_models, load_test_1D_and_2D_coo
 
 class TestMaxStableEstimators(unittest.TestCase):
     DISPLAY = False
-    nb_points = 5
-    nb_obs = 10
+    nb_points = 2
+    nb_obs = 3
 
     def setUp(self):
         super().setUp()
diff --git a/test/test_spatio_temporal_dataset/test_slicer.py b/test/test_spatio_temporal_dataset/test_slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f9be82d34643acf0b36e6155748c2fe1b9f8b7a
--- /dev/null
+++ b/test/test_spatio_temporal_dataset/test_slicer.py
@@ -0,0 +1,147 @@
+import pandas as pd
+import numpy as np
+from rpy2.rinterface import RRuntimeError
+import unittest
+from itertools import product
+
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel
+from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import LinSpaceCoordinates
+from spatio_temporal_dataset.dataset.simulation_dataset import MaxStableDataset, FullSimulatedDataset
+from spatio_temporal_dataset.slicer.spatial_slicer import SpatialSlicer
+from spatio_temporal_dataset.slicer.spatio_temporal_slicer import SpatioTemporalSlicer
+from spatio_temporal_dataset.slicer.split import ALL_SPLITS_EXCEPT_ALL, Split
+from spatio_temporal_dataset.slicer.temporal_slicer import TemporalSlicer
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
+    AbstractSpatioTemporalObservations
+
+
+class TestSlicerForDataset(unittest.TestCase):
+
+    def __init__(self, methodName: str = ...) -> None:
+        super().__init__(methodName)
+        self.dataset = None
+
+    nb_spatial_points = 2
+    nb_temporal_obs = 2
+    complete_shape = (nb_spatial_points, nb_temporal_obs)
+
+    def load_dataset(self, slicer_class, split_ratio_spatial, split_ratio_temporal):
+        coordinates = LinSpaceCoordinates.from_nb_points(nb_points=self.nb_spatial_points,
+                                                         train_split_ratio=split_ratio_spatial)
+        return FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_temporal_obs,
+                                                         train_split_ratio=split_ratio_temporal,
+                                                         margin_model=ConstantMarginModel(coordinates=coordinates),
+                                                         coordinates=coordinates, max_stable_model=Smith(),
+                                                         slicer_class=slicer_class)
+
+    def get_shape(self, dataset, split):
+        return dataset.maxima_frech(split).shape
+
+    def test_spatiotemporal_slicer_for_dataset(self):
+        ind_tuple_to_observation_shape = {
+            (None, None): self.complete_shape,
+            (None, 0.5): self.complete_shape,
+            (0.5, None): self.complete_shape,
+            (0.5, 0.5): (1, 1),
+        }
+        self.check_shapes(ind_tuple_to_observation_shape, SpatioTemporalSlicer)
+
+    def test_spatial_slicer_for_dataset(self):
+        ind_tuple_to_observation_shape = {
+            (None, None): self.complete_shape,
+            (None, 0.5): self.complete_shape,
+            (0.5, None): (1, 2),
+            (0.5, 0.5): (1, 2),
+        }
+        self.check_shapes(ind_tuple_to_observation_shape, SpatialSlicer)
+
+    def test_temporal_slicer_for_dataset(self):
+        ind_tuple_to_observation_shape = {
+            (None, None): self.complete_shape,
+            (None, 0.5): (2, 1),
+            (0.5, None): self.complete_shape,
+            (0.5, 0.5): (2, 1),
+        }
+        self.check_shapes(ind_tuple_to_observation_shape, TemporalSlicer)
+
+    def check_shapes(self, ind_tuple_to_observation_shape, slicer_type):
+        for split_ratio, data_shape in ind_tuple_to_observation_shape.items():
+            dataset = self.load_dataset(slicer_type, *split_ratio)
+            self.assertEqual(self.complete_shape, self.get_shape(dataset, Split.all))
+            for split in ALL_SPLITS_EXCEPT_ALL:
+                if split in dataset.slicer.splits:
+                    self.assertEqual(data_shape, self.get_shape(dataset, split))
+                else:
+                    with self.assertRaises(AssertionError):
+                        self.get_shape(dataset, split)
+
+
+class TestSlicerForCoordinates(unittest.TestCase):
+
+    def nb_coordinates(self, coordinates: AbstractCoordinates, split):
+        return len(coordinates.coordinates_values(split))
+
+    def test_slicer_for_coordinates(self):
+        for split in Split:
+            coordinates1 = LinSpaceCoordinates.from_nb_points(nb_points=2, train_split_ratio=0.5)
+            if split in SpatialSlicer.SPLITS:
+                self.assertEqual(self.nb_coordinates(coordinates1, split), 1)
+            elif split in SpatioTemporalSlicer.SPLITS:
+                self.assertEqual(self.nb_coordinates(coordinates1, split), 1)
+            elif split in TemporalSlicer.SPLITS:
+                self.assertEqual(self.nb_coordinates(coordinates1, split), 2)
+            else:
+                self.assertEqual(self.nb_coordinates(coordinates1, split), 2)
+            coordinates2 = LinSpaceCoordinates.from_nb_points(nb_points=2)
+            self.assertEqual(self.nb_coordinates(coordinates2, split), 2)
+
+
+class TestSlicerForObservations(unittest.TestCase):
+
+    def load_observations(self, split_ratio_temporal):
+        df = pd.DataFrame.from_dict(
+            {
+                'year1': [1 for _ in range(4)],
+                'year2': [2 for _ in range(4)],
+
+            })
+        return AbstractSpatioTemporalObservations(df_maxima_frech=df, train_split_ratio=split_ratio_temporal)
+
+    def nb_obs(self, observations, split, slicer):
+        return len(np.transpose(observations.maxima_frech(split, slicer)))
+
+    def test_slicer_for_observations(self):
+        observations = self.load_observations(0.5)
+        # For the None Slicer, a slice should be returned only for split=SpatialTemporalSplit.all
+        # self.assertEqual(len(observations.maxima_frech(SpatialTemporalSplit.all, None)), 2)
+        self.assertEqual(2, self.nb_obs(observations, Split.all, None))
+        for split in ALL_SPLITS_EXCEPT_ALL:
+            with self.assertRaises(AssertionError):
+                observations.maxima_frech(split, None)
+        # For other slicers we try out all the possible combinations
+        slicer_type_to_size = {
+            SpatialSlicer: 2,
+            TemporalSlicer: 1,
+            SpatioTemporalSlicer: 1,
+        }
+        for slicer_type, size in slicer_type_to_size.items():
+            for coordinates_train_ind in [None, pd.Series([True, True, True, False])][::-1]:
+                slicer = slicer_type(coordinates_train_ind=coordinates_train_ind,
+                                     observations_train_ind=observations.train_ind)
+                self.assertEqual(2, self.nb_obs(observations, Split.all, slicer))
+                for split in ALL_SPLITS_EXCEPT_ALL:
+                    if split in slicer.splits:
+                        # By default for SpatioTemporalSlicer should slice if both train_ind are available
+                        # Otherwise if coordinates_train_ind is None, then it should return all the data
+                        if slicer_type is SpatioTemporalSlicer and coordinates_train_ind is None:
+                            size = 2
+                        self.assertEqual(size, self.nb_obs(observations, split, slicer))
+                    else:
+                        with self.assertRaises(AssertionError):
+                            observations.maxima_frech(split, slicer)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_spatio_temporal_dataset/test_temporal_observations.py b/test/test_spatio_temporal_dataset/test_spatio_temporal_observations.py
similarity index 92%
rename from test/test_spatio_temporal_dataset/test_temporal_observations.py
rename to test/test_spatio_temporal_dataset/test_spatio_temporal_observations.py
index f330476f06fa7c36e88d3d3aee8042fd4fa3cd83..ed6c4d4016b57798b9b3d5bd82660a2d01fa2b52 100644
--- a/test/test_spatio_temporal_dataset/test_temporal_observations.py
+++ b/test/test_spatio_temporal_dataset/test_spatio_temporal_observations.py
@@ -6,7 +6,7 @@ import pandas as pd
 from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
 
 
-class TestTemporalObservations(unittest.TestCase):
+class TestSpatioTemporalObservations(unittest.TestCase):
     DISPLAY = False
 
     def test_set_maxima_gev(self):