From 72c192071c62df59d69ef9b528a56574013d5583 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 29 Nov 2018 16:50:00 +0100
Subject: [PATCH] [DATASET] add train_split_ratio argument for temporal
 observations. some fix and refactoring.

---
 .../estimator/abstract_estimator.py           |   2 +
 extreme_estimator/estimator/full_estimator.py |   6 +-
 .../estimator/max_stable_estimator.py         |   4 +-
 .../abstract_margin_function.py               |   2 +-
 .../margin_function/linear_margin_function.py |   2 +-
 .../abstract_max_stable_model.py              |  11 +-
 .../coordinates/abstract_coordinates.py       | 110 ++++++++++--------
 .../alps_station_2D_coordinates.py            |   2 +-
 .../alps_station_3D_coordinates.py            |   1 +
 .../generated_spatial_coordinates.py          |   8 +-
 .../transformed_coordinates.py                |   4 +-
 .../coordinates_1D.py                         |   8 +-
 .../dataset/abstract_dataset.py               |  12 +-
 .../dataset/simulation_dataset.py             |  15 ++-
 .../dataset/spatio_temporal_split.py          |  34 +++++-
 .../abstract_temporal_observations.py         |  35 +++---
 .../annual_maxima_observations.py             |  23 ++--
 17 files changed, 169 insertions(+), 110 deletions(-)

diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index 3e22c0e2..d5d36f14 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -1,6 +1,7 @@
 import time
 
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractEstimator(object):
@@ -14,6 +15,7 @@ class AbstractEstimator(object):
     def __init__(self, dataset: AbstractDataset):
         self.dataset = dataset  # type: AbstractDataset
         self.additional_information = dict()
+        self.train_split = SpatialTemporalSplit.train
 
     def fit(self):
         ts = time.time()
diff --git a/extreme_estimator/estimator/full_estimator.py b/extreme_estimator/estimator/full_estimator.py
index aac5229c..d60d1776 100644
--- a/extreme_estimator/estimator/full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator.py
@@ -42,7 +42,7 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
         # Estimate the margin parameters
         self.margin_estimator.fit()
         # Compute the maxima_frech
-        maxima_gev_train = self.dataset.maxima_gev(split=SpatialTemporalSplit.train)
+        maxima_gev_train = self.dataset.maxima_gev(split=self.train_split)
         maxima_frech = AbstractMarginModel.gev2frech(maxima_gev=maxima_gev_train,
                                                      coordinates_values=self.dataset.coordinates_values,
                                                      margin_function=self.margin_estimator.margin_function_fitted)
@@ -70,8 +70,8 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
     def _fit(self):
         # Estimate both the margin and the max-stable structure
         self.full_params_fitted = self.max_stable_model.fitmaxstab(
-            maxima_gev=self.dataset.maxima_gev(split=SpatialTemporalSplit.train),
-            df_coordinates=self.dataset.df_coordinates,
+            maxima_gev=self.dataset.maxima_gev(split=self.train_split),
+            df_coordinates=self.dataset.df_coordinates(split=self.train_split),
             fit_marge=True,
             fit_marge_form_dict=self.linear_margin_function_to_fit.form_dict,
             margin_start_dict=self.linear_margin_function_to_fit.coef_dict
diff --git a/extreme_estimator/estimator/max_stable_estimator.py b/extreme_estimator/estimator/max_stable_estimator.py
index dff88648..c1dac517 100644
--- a/extreme_estimator/estimator/max_stable_estimator.py
+++ b/extreme_estimator/estimator/max_stable_estimator.py
@@ -20,8 +20,8 @@ class MaxStableEstimator(AbstractMaxStableEstimator):
     def _fit(self):
         assert self.dataset.maxima_frech(split=SpatialTemporalSplit.train) is not None
         self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(
-            maxima_frech=self.dataset.maxima_frech(split=SpatialTemporalSplit.train),
-            df_coordinates=self.dataset.coordinates.df_coordinates)
+            maxima_frech=self.dataset.maxima_frech(split=self.train_split),
+            df_coordinates=self.dataset.df_coordinates(split=self.train_split))
 
     def _error(self, true_max_stable_params: dict):
         absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name])
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 293ae7a1..1ded02a9 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -73,7 +73,7 @@ class AbstractMarginFunction(object):
         # TODO: to avoid getting the value several times, I could cache the results
         if self.dot_display:
             resolution = len(self.coordinates)
-            linspace = self.coordinates.coordinates_values[:, 0]
+            linspace = self.coordinates.coordinates_values()[:, 0]
             print('dot display')
         else:
             resolution = 100
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
index 2d50202e..c3b5b96b 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py
@@ -48,7 +48,7 @@ class LinearMarginFunction(IndependentMarginFunction):
             # Otherwise, we fit a LinearParamFunction
             else:
                 param_function = LinearParamFunction(linear_dims=self.gev_param_name_to_linear_dims[gev_param_name],
-                                                     coordinates=self.coordinates.coordinates_values,
+                                                     coordinates=self.coordinates.coordinates_values(),
                                                      linear_coef=linear_coef)
             # Add the param_function to the dictionary
             self.gev_param_name_to_param_function[gev_param_name] = param_function
diff --git a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
index f9e1235c..41303c9e 100644
--- a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
+++ b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
@@ -22,7 +22,8 @@ class AbstractMaxStableModel(AbstractModel):
     def cov_mod_param(self):
         return {'cov.mod': self.cov_mod}
 
-    def fitmaxstab(self, df_coordinates: pd.DataFrame, maxima_frech: np.ndarray=None, maxima_gev: np.ndarray=None, fit_marge=False,
+    def fitmaxstab(self, df_coordinates: pd.DataFrame, maxima_frech: np.ndarray = None, maxima_gev: np.ndarray = None,
+                   fit_marge=False,
                    fit_marge_form_dict=None, margin_start_dict=None) -> dict:
         assert isinstance(df_coordinates, pd.DataFrame)
         if fit_marge:
@@ -32,6 +33,10 @@ class AbstractMaxStableModel(AbstractModel):
         # Prepare the data
         maxima = maxima_gev if fit_marge else maxima_frech
         assert isinstance(maxima, np.ndarray)
+        assert len(df_coordinates) == len(maxima), 'Coordinates and observations sizes should match,' \
+                                                   'check that the same split was used for both objects \n,' \
+                                                   'df_coordinates size: {}, data size {}'.format(len(df_coordinates),
+                                                                                                  len(maxima))
         data = np.transpose(maxima)
 
         # Prepare the coord
@@ -75,12 +80,12 @@ class AbstractMaxStableModel(AbstractModel):
         fitted_values = {key: fitted_values.rx2(key)[0] for key in fitted_values.names}
         return fitted_values
 
-    def rmaxstab(self, nb_obs: int, coordinates: np.ndarray) -> np.ndarray:
+    def rmaxstab(self, nb_obs: int, coordinates_values: np.ndarray) -> np.ndarray:
         """
         Return an numpy of maxima. With rows being the stations and columns being the years of maxima
         """
         maxima_frech = np.array(
-            r.rmaxstab(nb_obs, coordinates, *list(self.cov_mod_param.values()), **self.params_sample))
+            r.rmaxstab(nb_obs, coordinates_values, *list(self.cov_mod_param.values()), **self.params_sample))
         return np.transpose(maxima_frech)
 
     def remove_unused_parameters(self, start_dict, coordinate_dim):
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index dc18803b..4cf3fd6f 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -1,11 +1,14 @@
 import os.path as op
 from typing import List
 
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 
+from spatio_temporal_dataset.dataset.spatio_temporal_split import s_split_from_ratio, TEST_SPLIT_STR, \
+    TRAIN_SPLIT_STR, train_ind_from_s_split, SpatialTemporalSplit
+
 
 class AbstractCoordinates(object):
     # Columns
@@ -14,45 +17,54 @@ class AbstractCoordinates(object):
     COORDINATE_Z = 'coord_z'
     COORDINATE_NAMES = [COORDINATE_X, COORDINATE_Y, COORDINATE_Z]
     COORDINATE_SPLIT = 'coord_split'
-    # Constants for the split column
-    TRAIN_SPLIT_STR = 'train_split'
-    TEST_SPLIT_STR = 'test_split'
 
-    def __init__(self, df_coordinates: pd.DataFrame, s_split: pd.Series = None):
-        self.df_coordinates = df_coordinates  # type: pd.DataFrame
+    def __init__(self, df_coord: pd.DataFrame, s_split: pd.Series = None):
+        self.df_coord = df_coord  # type: pd.DataFrame
         self.s_split = s_split  # type: pd.Series
 
     # ClassMethod constructor
 
     @classmethod
-    def from_df(cls, df: pd.DataFrame):
+    def from_df(cls, df: pd.DataFrame, train_split_ratio: float = None):
         #  X and coordinates must be defined
         assert cls.COORDINATE_X in df.columns
-        df_coordinates = df.loc[:, cls.coordinates_columns(df)]
-        # Potentially, a split column can be specified
-        s_split = df[cls.COORDINATE_SPLIT] if cls.COORDINATE_SPLIT in df.columns else None
-        if s_split is not None:
-            assert s_split.isin([cls.TRAIN_SPLIT_STR, cls.TEST_SPLIT_STR])
-        return cls(df_coordinates=df_coordinates, s_split=s_split)
+        # Create a split based on the train_split_ratio
+        if train_split_ratio is not None:
+            assert cls.COORDINATE_SPLIT not in df.columns, "A split has already been defined"
+            s_split = s_split_from_ratio(length=len(df), train_split_ratio=train_split_ratio)
+            df[cls.COORDINATE_SPLIT] = s_split
+        # Potentially, a split column can be specified directly in df
+        if cls.COORDINATE_SPLIT not in df.columns:
+            df_coord = df
+            s_split = None
+        else:
+            df_coord = df.loc[:, cls.coordinates_columns(df)]
+            s_split = df[cls.COORDINATE_SPLIT]
+            assert s_split.isin([TRAIN_SPLIT_STR, TEST_SPLIT_STR]).all()
+        return cls(df_coord=df_coord, s_split=s_split)
 
     @classmethod
     def from_csv(cls, csv_path: str = None):
         assert csv_path is not None
         assert op.exists(csv_path)
         df = pd.read_csv(csv_path)
+        # Index correspond to the first column
+        index_column_name = df.columns[0]
+        assert index_column_name not in cls.coordinates_columns(df)
+        df.set_index(index_column_name, inplace=True)
         return cls.from_df(df)
 
     @classmethod
-    def from_nb_points(cls, nb_points: int, **kwargs):
+    def from_nb_points(cls, nb_points: int, train_split_ratio: float = None, **kwargs):
         # Call the default class method from csv
         coordinates = cls.from_csv()  # type: AbstractCoordinates
-        # Sample randomly nb_points coordinates
+        # Check that nb_points asked is not superior to the number of coordinates
         nb_coordinates = len(coordinates)
         if nb_points > nb_coordinates:
             raise Exception('Nb coordinates in csv: {} < Nb points desired: {}'.format(nb_coordinates, nb_points))
-        else:
-            df_sample = pd.DataFrame.sample(coordinates.df, n=nb_points)
-            return cls.from_df(df=df_sample)
+        # Sample randomly nb_points coordinates
+        df_sample = pd.DataFrame.sample(coordinates.df_merged, n=nb_points)
+        return cls.from_df(df=df_sample, train_split_ratio=train_split_ratio)
 
     @classmethod
     def coordinates_columns(cls, df_coord: pd.DataFrame) -> List[str]:
@@ -64,52 +76,48 @@ class AbstractCoordinates(object):
 
     @property
     def columns(self):
-        return self.coordinates_columns(df_coord=self.df_coordinates)
+        return self.coordinates_columns(df_coord=self.df_coord)
 
     @property
     def nb_columns(self):
         return len(self.columns)
 
     @property
-    def df(self) -> pd.DataFrame:
-        # Merged DataFrame of df_coord and s_split
-        return self.df_coordinates if self.s_split is None else self.df_coordinates.join(self.s_split)
+    def index(self):
+        return self.df_coord.index
 
-    def df_coordinates_split(self, split_str: str) -> pd.DataFrame:
-        assert self.s_split is not None
-        ind = self.s_split == split_str
-        return self.df_coordinates.loc[ind]
+    @property
+    def df_merged(self) -> pd.DataFrame:
+        # Merged DataFrame of df_coord and s_split
+        return self.df_coord if self.s_split is None else self.df_coord.join(self.s_split)
 
-    def _coordinates_values(self, df_coordinates: pd.DataFrame) -> np.ndarray:
-        return df_coordinates.loc[:, self.coordinates_columns(df_coordinates)].values
+    def df_coordinates(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> pd.DataFrame:
+        if split is SpatialTemporalSplit.all or self.s_split is None:
+            return self.df_coord
+        elif split in [SpatialTemporalSplit.train, SpatialTemporalSplit.test_temporal]:
+            return self.df_coord.loc[self.train_ind]
+        else:
+            return self.df_coord.loc[~self.train_ind]
 
-    @property
-    def coordinates_values(self) -> np.ndarray:
-        return self._coordinates_values(df_coordinates=self.df_coordinates)
+    def coordinates_values(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
+        return self.df_coordinates(split).values
 
     @property
     def x_coordinates(self) -> np.ndarray:
-        return self.df_coordinates.loc[:, self.COORDINATE_X].values.copy()
+        return self.df_coord[self.COORDINATE_X].values.copy()
 
     @property
     def y_coordinates(self) -> np.ndarray:
-        return self.df_coordinates.loc[:, self.COORDINATE_Y].values.copy()
-
-    @property
-    def index(self) -> pd.Series:
-        return self.df_coordinates.index
+        return self.df_coord[self.COORDINATE_Y].values.copy()
 
     @property
     def train_ind(self) -> pd.Series:
-        if self.s_split is None:
-            return None
-        else:
-            return self.s_split.isin([self.TRAIN_SPLIT_STR])
+        return train_ind_from_s_split(s_split=self.s_split)
 
     #  Visualization
 
     def visualize(self):
-        nb_coordinates_columns = len(self.coordinates_columns(self.df_coordinates))
+        nb_coordinates_columns = len(self.coordinates_columns(self.df_coord))
         if nb_coordinates_columns == 1:
             self.visualization_1D()
         elif nb_coordinates_columns == 2:
@@ -118,21 +126,23 @@ class AbstractCoordinates(object):
             self.visualization_3D()
 
     def visualization_1D(self):
-        assert len(self.coordinates_columns(self.df_coordinates)) >= 1
-        x = self.coordinates_values[:]
+        assert len(self.coordinates_columns(self.df_coord)) >= 1
+        x = self.coordinates_values()[:]
         y = np.zeros(len(x))
         plt.scatter(x, y)
         plt.show()
 
     def visualization_2D(self):
-        assert len(self.coordinates_columns(self.df_coordinates)) >= 2
-        x, y = self.coordinates_values[:, 0], self.coordinates_values[:, 1]
+        assert len(self.coordinates_columns(self.df_coord)) >= 2
+        coordinates_values = self.coordinates_values()
+        x, y = coordinates_values[:, 0], coordinates_values[:, 1]
         plt.scatter(x, y)
         plt.show()
 
     def visualization_3D(self):
-        assert len(self.coordinates_columns(self.df_coordinates)) == 3
-        x, y, z = self.coordinates_values[:, 0], self.coordinates_values[:, 1], self.coordinates_values[:, 2]
+        assert len(self.coordinates_columns(self.df_coord)) == 3
+        coordinates_values = self.coordinates_values()
+        x, y, z = coordinates_values[:, 0], coordinates_values[:, 1], coordinates_values[:, 2]
         fig = plt.figure()
         ax = fig.add_subplot(111, projection='3d')  # type: Axes3D
         ax.scatter(x, y, z, marker='^')
@@ -141,10 +151,10 @@ class AbstractCoordinates(object):
     #  Magic Methods
 
     def __len__(self):
-        return len(self.df_coordinates)
+        return len(self.df_coord)
 
     def __mul__(self, other: float):
-        self.df_coordinates *= other
+        self.df_coord *= other
         return self
 
     def __rmul__(self, other):
diff --git a/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_2D_coordinates.py b/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_2D_coordinates.py
index 629e24f9..a69b43f9 100644
--- a/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_2D_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_2D_coordinates.py
@@ -10,7 +10,7 @@ class AlpsStation2DCoordinates(AlpsStation3DCoordinates):
     def from_csv(cls, csv_file='coord-lambert2'):
         # Remove the Z coordinates from df_coord
         spatial_coordinates = super().from_csv(csv_file)  # type: AlpsStation3DCoordinates
-        spatial_coordinates.df_coordinates.drop(cls.COORDINATE_Z, axis=1, inplace=True)
+        spatial_coordinates.df_coord.drop(cls.COORDINATE_Z, axis=1, inplace=True)
         return spatial_coordinates
 
 
diff --git a/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_3D_coordinates.py b/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_3D_coordinates.py
index d5ef2ad1..6d00efb4 100644
--- a/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_3D_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatial_coordinates/alps_station_3D_coordinates.py
@@ -43,5 +43,6 @@ class AlpsStation3DCoordinatesWithAnisotropy(AlpsStation3DCoordinates):
     @classmethod
     def from_csv(cls, csv_file='coord-lambert2'):
         coord = super().from_csv(csv_file)
+        print(coord)
         return TransformedCoordinates.from_coordinates(coordinates=coord,
                                                        transformation_function=AnisotropyTransformation())
diff --git a/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py b/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py
index 267d7fee..fc600e26 100644
--- a/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatial_coordinates/generated_spatial_coordinates.py
@@ -10,13 +10,13 @@ import matplotlib.pyplot as plt
 class CircleCoordinates(AbstractCoordinates):
 
     @classmethod
-    def from_nb_points(cls, nb_points, max_radius=1.0):
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0):
         # Sample uniformly inside the circle
         angles = np.array(r.runif(nb_points, max=2 * math.pi))
         radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius)))
         df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles),
                                      cls.COORDINATE_Y: radius * np.sin(angles)})
-        return cls.from_df(df)
+        return cls.from_df(df, train_split_ratio)
 
     def visualization_2D(self):
         r = 1.0
@@ -30,6 +30,6 @@ class CircleCoordinates(AbstractCoordinates):
 class CircleCoordinatesRadius2(CircleCoordinates):
 
     @classmethod
-    def from_nb_points(cls, nb_points, max_radius=1.0):
-        return 2 * super().from_nb_points(nb_points, max_radius)
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0):
+        return 2 * super().from_nb_points(nb_points, train_split_ratio, max_radius)
 
diff --git a/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py b/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py
index 4cc302d9..5022b244 100644
--- a/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py
@@ -7,8 +7,8 @@ class TransformedCoordinates(AbstractCoordinates):
     @classmethod
     def from_coordinates(cls, coordinates: AbstractCoordinates,
                          transformation_function: AbstractTransformation):
-        df_coordinates_transformed = coordinates.df_coordinates.copy()
+        df_coordinates_transformed = coordinates.df_coord.copy()
         df_coordinates_transformed = transformation_function.transform(df_coord=df_coordinates_transformed)
-        return cls(df_coordinates=df_coordinates_transformed, s_split=coordinates.s_split)
+        return cls(df_coord=df_coordinates_transformed, s_split=coordinates.s_split)
 
 
diff --git a/spatio_temporal_dataset/coordinates/unidimensional_coordinates/coordinates_1D.py b/spatio_temporal_dataset/coordinates/unidimensional_coordinates/coordinates_1D.py
index 34dbacc3..64c56d52 100644
--- a/spatio_temporal_dataset/coordinates/unidimensional_coordinates/coordinates_1D.py
+++ b/spatio_temporal_dataset/coordinates/unidimensional_coordinates/coordinates_1D.py
@@ -13,17 +13,17 @@ class AbstractUniDimensionalCoordinates(AbstractCoordinates):
 class LinSpaceCoordinates(AbstractUniDimensionalCoordinates):
 
     @classmethod
-    def from_nb_points(cls, nb_points, start=-1.0, end=1.0):
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, start=-1.0, end=1.0):
         axis_coordinates = np.linspace(start, end, nb_points)
         df = pd.DataFrame.from_dict({cls.COORDINATE_X: axis_coordinates})
-        return cls.from_df(df)
+        return cls.from_df(df, train_split_ratio)
 
 
 class UniformCoordinates(AbstractUniDimensionalCoordinates):
 
     @classmethod
-    def from_nb_points(cls, nb_points, start=-1.0, end=1.0):
+    def from_nb_points(cls, nb_points, train_split_ratio: float = None, start=-1.0, end=1.0):
         # Sample uniformly inside the circle
         axis_coordinates = np.array(r.runif(nb_points, min=start, max=end))
         df = pd.DataFrame.from_dict({cls.COORDINATE_X: axis_coordinates})
-        return cls.from_df(df)
+        return cls.from_df(df, train_split_ratio)
diff --git a/spatio_temporal_dataset/dataset/abstract_dataset.py b/spatio_temporal_dataset/dataset/abstract_dataset.py
index f17d9610..ca1feef9 100644
--- a/spatio_temporal_dataset/dataset/abstract_dataset.py
+++ b/spatio_temporal_dataset/dataset/abstract_dataset.py
@@ -35,15 +35,15 @@ class AbstractDataset(object):
     @property
     def df_dataset(self) -> pd.DataFrame:
         # Merge dataframes with the maxima and with the coordinates
-        return self.temporal_observations.df_maxima_gev.join(self.coordinates.df_coordinates)
+        # todo: maybe I should add the split from the temporal observations
+        return self.temporal_observations.df_maxima_gev.join(self.coordinates.df_merged)
 
-    @property
-    def df_coordinates(self):
-        return self.coordinates.df_coordinates
+    def df_coordinates(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
+        return self.coordinates.df_coordinates(split=split)
 
     @property
-    def coordinates_values(self):
-        return self.coordinates.coordinates_values
+    def coordinates_values(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
+        return self.coordinates.coordinates_values(split=split)
 
     def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
         return self.temporal_observations.maxima_gev(split, self.spatio_temporal_slicer)
diff --git a/spatio_temporal_dataset/dataset/simulation_dataset.py b/spatio_temporal_dataset/dataset/simulation_dataset.py
index a635baa7..c7140a7e 100644
--- a/spatio_temporal_dataset/dataset/simulation_dataset.py
+++ b/spatio_temporal_dataset/dataset/simulation_dataset.py
@@ -26,16 +26,18 @@ class SimulatedDataset(AbstractDataset):
 class MaxStableDataset(SimulatedDataset):
 
     @classmethod
-    def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel, coordinates: AbstractCoordinates):
-        temporal_obs = MaxStableAnnualMaxima.from_sampling(nb_obs, max_stable_model, coordinates)
+    def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel, coordinates: AbstractCoordinates,
+                      train_split_ratio: float = None):
+        temporal_obs = MaxStableAnnualMaxima.from_sampling(nb_obs, max_stable_model, coordinates, train_split_ratio)
         return cls(temporal_observations=temporal_obs, coordinates=coordinates, max_stable_model=max_stable_model)
 
 
 class MarginDataset(SimulatedDataset):
 
     @classmethod
-    def from_sampling(cls, nb_obs: int, margin_model: AbstractMarginModel, coordinates: AbstractCoordinates):
-        temporal_obs = MarginAnnualMaxima.from_sampling(nb_obs, coordinates, margin_model)
+    def from_sampling(cls, nb_obs: int, margin_model: AbstractMarginModel, coordinates: AbstractCoordinates,
+                      train_split_ratio: float = None):
+        temporal_obs = MarginAnnualMaxima.from_sampling(nb_obs, coordinates, margin_model, train_split_ratio)
         return cls(temporal_observations=temporal_obs, coordinates=coordinates, margin_model=margin_model)
 
 
@@ -44,8 +46,9 @@ class FullSimulatedDataset(SimulatedDataset):
     @classmethod
     def from_double_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
                              coordinates: AbstractCoordinates,
-                             margin_model: AbstractMarginModel):
+                             margin_model: AbstractMarginModel,
+                             train_split_ratio: float = None):
         temporal_obs = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
-                                                             coordinates, margin_model)
+                                                             coordinates, margin_model, train_split_ratio)
         return cls(temporal_observations=temporal_obs, coordinates=coordinates, max_stable_model=max_stable_model,
                    margin_model=margin_model)
diff --git a/spatio_temporal_dataset/dataset/spatio_temporal_split.py b/spatio_temporal_dataset/dataset/spatio_temporal_split.py
index 599e656d..3d7d210a 100644
--- a/spatio_temporal_dataset/dataset/spatio_temporal_split.py
+++ b/spatio_temporal_dataset/dataset/spatio_temporal_split.py
@@ -17,7 +17,9 @@ class SpatioTemporalSlicer(object):
         self.index_train_ind = coordinate_train_ind  # type: pd.Series
         self.column_train_ind = observation_train_ind  # type: pd.Series
         if self.ind_are_not_defined:
-            assert self.index_train_ind is None and self.column_train_ind is None, "One split was not defined"
+            msg = "One split was not defined \n \n" \
+                  "index: \n {}  \n, column:\n {} \n".format(self.index_train_ind, self.column_train_ind)
+            assert self.index_train_ind is None and self.column_train_ind is None, msg
 
     @property
     def index_test_ind(self) -> pd.Series:
@@ -36,8 +38,8 @@ class SpatioTemporalSlicer(object):
         # By default, if one of the two split is not defined we return all the data
         if self.ind_are_not_defined or split is SpatialTemporalSplit.all:
             return df
-        assert df.columns == self.column_train_ind.index
-        assert df.index == self.index_train_ind.index
+        assert pd.RangeIndex.equals(df.columns, self.column_train_ind.index)
+        assert pd.RangeIndex.equals(df.index, self.index_train_ind.index)
         if split is SpatialTemporalSplit.train:
             return df.loc[self.index_train_ind, self.column_train_ind]
         elif split is SpatialTemporalSplit.test:
@@ -46,3 +48,29 @@ class SpatioTemporalSlicer(object):
             return df.loc[self.index_test_ind, self.column_train_ind]
         elif split is SpatialTemporalSplit.test_temporal:
             return df.loc[self.index_train_ind, self.column_test_ind]
+
+
+SPLIT_NAME = 'split'
+TRAIN_SPLIT_STR = 'train_split'
+TEST_SPLIT_STR = 'test_split'
+
+
+def train_ind_from_s_split(s_split):
+    """
+
+    :param s_split:
+    :return:
+    """
+    if s_split is None:
+        return None
+    else:
+        return s_split.isin([TRAIN_SPLIT_STR])
+
+
+def s_split_from_ratio(length, train_split_ratio):
+    assert 0 < train_split_ratio < 1
+    s = pd.Series([TEST_SPLIT_STR for _ in range(length)])
+    nb_points_train = int(length * train_split_ratio)
+    train_ind = pd.Series.sample(s, n=nb_points_train).index
+    s.loc[train_ind] = TRAIN_SPLIT_STR
+    return s
diff --git a/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py b/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
index 8a19a123..86926ce1 100644
--- a/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
+++ b/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
@@ -1,35 +1,45 @@
 import pandas as pd
 import numpy as np
 
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer, \
+    train_ind_from_s_split, TEST_SPLIT_STR, TRAIN_SPLIT_STR, s_split_from_ratio
 
 
 class AbstractTemporalObservations(object):
 
-    # Constants for the split column
-    TRAIN_SPLIT_STR = 'train_split'
-    TEST_SPLIT_STR = 'test_split'
-
     def __init__(self, df_maxima_frech: pd.DataFrame = None, df_maxima_gev: pd.DataFrame = None,
-                 s_split: pd.Series = None):
+                 s_split: pd.Series = None, train_split_ratio: float = None):
         """
         Main attribute of the class is the DataFrame df_maxima
         Index are stations index
         Columns are the temporal moment of the maxima
         """
-        if s_split is not None:
-            assert s_split.isin([self.TRAIN_SPLIT_STR, self.TEST_SPLIT_STR])
-        self.s_split = s_split
+        assert df_maxima_frech is not None or df_maxima_gev is not None
         self.df_maxima_frech = df_maxima_frech
         self.df_maxima_gev = df_maxima_gev
 
+        if s_split is not None and train_split_ratio is not None:
+            raise AttributeError('A split is already defined, there is no need to specify a ratio')
+        elif s_split is not None or train_split_ratio is not None:
+            if train_split_ratio:
+                s_split = s_split_from_ratio(length=self.nb_obs, train_split_ratio=train_split_ratio)
+            assert s_split.isin([TRAIN_SPLIT_STR, TEST_SPLIT_STR]).all()
+        self.s_split = s_split
+
+    @property
+    def nb_obs(self):
+        if self.df_maxima_frech is not None:
+            return len(self.df_maxima_frech.columns)
+        else:
+            return len(self.df_maxima_gev.columns)
+
     @classmethod
     def from_df(cls, df):
         pass
 
     @staticmethod
     def df_maxima(df: pd.DataFrame, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
-                  slicer: SpatioTemporalSlicer = None):
+                  slicer: SpatioTemporalSlicer = None) -> pd.DataFrame:
         if slicer is None:
             assert split is SpatialTemporalSplit.all
             return df
@@ -49,7 +59,4 @@ class AbstractTemporalObservations(object):
 
     @property
     def train_ind(self) -> pd.Series:
-        if self.s_split is None:
-            return None
-        else:
-            return self.s_split.isin([self.TRAIN_SPLIT_STR])
+        return train_ind_from_s_split(s_split=self.s_split)
diff --git a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py b/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
index a197687c..ecb86156 100644
--- a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
+++ b/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
@@ -3,6 +3,7 @@ import pandas as pd
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.dataset.spatio_temporal_split import s_split_from_ratio
 from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
 
 
@@ -18,29 +19,31 @@ class MarginAnnualMaxima(AnnualMaxima):
 
     @classmethod
     def from_sampling(cls, nb_obs: int, coordinates: AbstractCoordinates,
-                      margin_model: AbstractMarginModel):
-        maxima_gev = margin_model.rmargin_from_nb_obs(nb_obs=nb_obs, coordinates_values=coordinates.coordinates_values)
+                      margin_model: AbstractMarginModel, train_split_ratio: float = None):
+        maxima_gev = margin_model.rmargin_from_nb_obs(nb_obs=nb_obs, coordinates_values=coordinates.coordinates_values())
         df_maxima_gev = pd.DataFrame(data=maxima_gev, index=coordinates.index)
-        return cls(df_maxima_gev=df_maxima_gev)
+        return cls(df_maxima_gev=df_maxima_gev, train_split_ratio=train_split_ratio)
 
 
-class MaxStableAnnualMaxima(AbstractTemporalObservations):
+class MaxStableAnnualMaxima(AnnualMaxima):
 
     @classmethod
-    def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel, coordinates: AbstractCoordinates):
-        maxima_frech = max_stable_model.rmaxstab(nb_obs=nb_obs, coordinates=coordinates.coordinates_values)
+    def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel, coordinates: AbstractCoordinates,
+                      train_split_ratio: float = None):
+        maxima_frech = max_stable_model.rmaxstab(nb_obs=nb_obs, coordinates_values=coordinates.coordinates_values())
         df_maxima_frech = pd.DataFrame(data=maxima_frech, index=coordinates.index)
-        return cls(df_maxima_frech=df_maxima_frech)
+        return cls(df_maxima_frech=df_maxima_frech, train_split_ratio=train_split_ratio)
 
 
 class FullAnnualMaxima(MaxStableAnnualMaxima):
 
     @classmethod
     def from_double_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
-                             coordinates: AbstractCoordinates, margin_model: AbstractMarginModel):
-        max_stable_annual_maxima = super().from_sampling(nb_obs, max_stable_model, coordinates)
+                             coordinates: AbstractCoordinates, margin_model: AbstractMarginModel,
+                             train_split_ratio: float = None):
+        max_stable_annual_maxima = super().from_sampling(nb_obs, max_stable_model, coordinates, train_split_ratio)
         #  Compute df_maxima_gev from df_maxima_frech
         maxima_gev = margin_model.rmargin_from_maxima_frech(maxima_frech=max_stable_annual_maxima.maxima_frech(),
-                                                            coordinates_values=coordinates.coordinates_values)
+                                                            coordinates_values=coordinates.coordinates_values())
         max_stable_annual_maxima.df_maxima_gev = pd.DataFrame(data=maxima_gev, index=coordinates.index)
         return max_stable_annual_maxima
-- 
GitLab