From a1a75f0111acf7a3a9cb20521434be7c58d68821 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 29 Nov 2018 17:53:29 +0100
Subject: [PATCH] [DATASET] rename observations field

---
 .../regression_margin/regression_margin.py    | 12 ++++----
 .../spatial_robustness/alps_msp_robustness.py |  2 +-
 .../unidimensional_robustness.py              |  2 +-
 .../estimator/abstract_estimator.py           |  2 +-
 extreme_estimator/estimator/full_estimator.py |  9 +++---
 .../estimator/margin_estimator.py             |  4 +--
 .../estimator/max_stable_estimator.py         |  7 ++---
 .../abstract_margin_function.py               | 26 +++++++++++------
 .../coordinates/abstract_coordinates.py       |  2 +-
 .../dataset/abstract_dataset.py               | 24 ++++++++--------
 .../dataset/simulation_dataset.py             | 20 ++++++-------
 .../__init__.py                               |  0
 .../abstract_spatio_temporal_observations.py} | 21 ++++----------
 .../alps_precipitation_observations.py        |  5 ++++
 .../annual_maxima_observations.py             |  5 ++--
 .../{dataset => }/spatio_temporal_split.py    | 28 +++++++++++++------
 .../alps_precipitation_observations.py        |  5 ----
 .../test_temporal_observations.py             |  4 +--
 .../test_rmaxstab_without_margin.py           |  2 +-
 19 files changed, 94 insertions(+), 86 deletions(-)
 rename spatio_temporal_dataset/{temporal_observations => spatio_temporal_observations}/__init__.py (100%)
 rename spatio_temporal_dataset/{temporal_observations/abstract_temporal_observations.py => spatio_temporal_observations/abstract_spatio_temporal_observations.py} (72%)
 create mode 100644 spatio_temporal_dataset/spatio_temporal_observations/alps_precipitation_observations.py
 rename spatio_temporal_dataset/{temporal_observations => spatio_temporal_observations}/annual_maxima_observations.py (90%)
 rename spatio_temporal_dataset/{dataset => }/spatio_temporal_split.py (71%)
 delete mode 100644 spatio_temporal_dataset/temporal_observations/alps_precipitation_observations.py

diff --git a/experiment/regression_margin/regression_margin.py b/experiment/regression_margin/regression_margin.py
index ee49fa93..48449e7c 100644
--- a/experiment/regression_margin/regression_margin.py
+++ b/experiment/regression_margin/regression_margin.py
@@ -13,9 +13,9 @@ import matplotlib.pyplot as plt
 
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
 
-nb_points = 2
-nb_obs = 10000
-nb_estimator = 1
+nb_points = 50
+nb_obs = 60
+nb_estimator = 2
 show = False
 
 coordinates = LinSpaceCoordinates.from_nb_points(nb_points=nb_points)
@@ -30,7 +30,7 @@ params_sample = {
     (GevParams.GEV_SCALE, 0): 1.0,
 }
 margin_model = ConstantMarginModel(coordinates=coordinates, params_sample=params_sample)
-margin_model_for_estimator_class = [LinearAllParametersAllDimsMarginModel][-1]
+margin_model_for_estimator_class = [LinearAllParametersAllDimsMarginModel, ConstantMarginModel][-1]
 max_stable_model = Smith()
 
 
@@ -47,8 +47,8 @@ for i in range(nb_estimator):
 
     if show and i == 0:
         # Plot a realization from the maxima gev (i.e the maxima obtained just by simulating the marginal law)
-        for maxima in np.transpose(dataset.maxima_frech):
-            plt.plot(coordinates.coordinates_values, maxima, 'o')
+        for maxima in np.transpose(dataset.maxima_frech()):
+            plt.plot(coordinates.coordinates_values(), maxima, 'o')
         plt.show()
 
     margin_function_sample = dataset.margin_model.margin_function_sample # type: LinearMarginFunction
diff --git a/experiment/robustness_plot/estimation_robustness/spatial_robustness/alps_msp_robustness.py b/experiment/robustness_plot/estimation_robustness/spatial_robustness/alps_msp_robustness.py
index c95238e8..920f2580 100644
--- a/experiment/robustness_plot/estimation_robustness/spatial_robustness/alps_msp_robustness.py
+++ b/experiment/robustness_plot/estimation_robustness/spatial_robustness/alps_msp_robustness.py
@@ -35,7 +35,7 @@ def multiple_spatial_robustness_alps():
         plot_row_item=MaxStableProcessPlot.NbStationItem,
         plot_label_item=MaxStableProcessPlot.MaxStableModelItem,
         nb_samples=nb_sample,
-        main_title="Max stable analysis with {} years of temporal_observations".format(nb_observation),
+        main_title="Max stable analysis with {} years of spatio_temporal_observations".format(nb_observation),
         plot_png_filename=plot_name
     )
     # Load all the models
diff --git a/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py b/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py
index 34703d67..05ce69c6 100644
--- a/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py
+++ b/experiment/robustness_plot/estimation_robustness/unidimensional_robustness/unidimensional_robustness.py
@@ -35,7 +35,7 @@ def multiple_unidimensional_robustness():
         plot_row_item=MaxStableProcessPlot.NbStationItem,
         plot_label_item=MaxStableProcessPlot.MaxStableModelItem,
         nb_samples=nb_sample,
-        main_title="Max stable analysis with {} years of temporal_observations".format(nb_observation),
+        main_title="Max stable analysis with {} years of spatio_temporal_observations".format(nb_observation),
         plot_png_filename=plot_name
     )
     # Load all the models
diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index d5d36f14..ac58f74a 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -1,7 +1,7 @@
 import time
 
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
+from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractEstimator(object):
diff --git a/extreme_estimator/estimator/full_estimator.py b/extreme_estimator/estimator/full_estimator.py
index d60d1776..42d870ac 100644
--- a/extreme_estimator/estimator/full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator.py
@@ -1,14 +1,13 @@
+from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
+from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
+from extreme_estimator.estimator.max_stable_estimator import MaxStableEstimator
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
-from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
-from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
-from extreme_estimator.estimator.max_stable_estimator import MaxStableEstimator
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractFullEstimator(AbstractEstimator):
@@ -47,7 +46,7 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
                                                      coordinates_values=self.dataset.coordinates_values,
                                                      margin_function=self.margin_estimator.margin_function_fitted)
         # Update maxima frech field through the dataset object
-        self.dataset.set_maxima_frech(maxima_frech, split=SpatialTemporalSplit.train)
+        self.dataset.set_maxima_frech(maxima_frech, split=self.train_split)
         # Estimate the max stable parameters
         self.max_stable_estimator.fit()
 
diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py
index 6941fecf..86e8e42c 100644
--- a/extreme_estimator/estimator/margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator.py
@@ -1,10 +1,8 @@
-from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractMarginEstimator(AbstractEstimator):
@@ -33,7 +31,7 @@ class SmoothMarginEstimator(AbstractMarginEstimator):
         self.margin_model = margin_model
 
     def _fit(self):
-        maxima_gev = self.dataset.maxima_gev(split=SpatialTemporalSplit.train)
+        maxima_gev = self.dataset.maxima_gev(split=self.train_split)
         corodinate_values = self.dataset.coordinates_values
         self._margin_function_fitted = self.margin_model.fitmargin_from_maxima_gev(maxima_gev=maxima_gev,
                                                                                    coordinates_values=corodinate_values)
diff --git a/extreme_estimator/estimator/max_stable_estimator.py b/extreme_estimator/estimator/max_stable_estimator.py
index c1dac517..9fe8d2e8 100644
--- a/extreme_estimator/estimator/max_stable_estimator.py
+++ b/extreme_estimator/estimator/max_stable_estimator.py
@@ -1,9 +1,8 @@
+import numpy as np
+
 from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
-import numpy as np
-
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractMaxStableEstimator(AbstractEstimator):
@@ -18,7 +17,7 @@ class AbstractMaxStableEstimator(AbstractEstimator):
 class MaxStableEstimator(AbstractMaxStableEstimator):
 
     def _fit(self):
-        assert self.dataset.maxima_frech(split=SpatialTemporalSplit.train) is not None
+        assert self.dataset.maxima_frech(split=self.train_split) is not None
         self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(
             maxima_frech=self.dataset.maxima_frech(split=self.train_split),
             df_coordinates=self.dataset.df_coordinates(split=self.train_split))
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 1ded02a9..3a19e9cb 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -4,6 +4,7 @@ import numpy as np
 
 from extreme_estimator.gev_params import GevParams
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractMarginFunction(object):
@@ -11,8 +12,12 @@ class AbstractMarginFunction(object):
 
     def __init__(self, coordinates: AbstractCoordinates):
         self.coordinates = coordinates
+
+        # Visualization parameters
         self.visualization_axes = None
-        self.dot_display = False
+        self.datapoint_display = False
+        self.spatio_temporal_split = SpatialTemporalSplit.all
+        self.datapoint_marker = 'o'
 
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         """Main method that maps each coordinate to its GEV parameters"""
@@ -20,8 +25,13 @@ class AbstractMarginFunction(object):
 
     # Visualization function
 
+    def set_datapoint_display_parameters(self, spatio_temporal_split, datapoint_marker):
+        self.datapoint_display = True
+        self.spatio_temporal_split = spatio_temporal_split
+        self.datapoint_marker = datapoint_marker
+
     def visualize(self, axes=None, show=True, dot_display=False):
-        self.dot_display = dot_display
+        self.datapoint_display = dot_display
         if axes is None:
             fig, axes = plt.subplots(3, 1, sharex='col', sharey='row')
             fig.subplots_adjust(hspace=0.4, wspace=0.4, )
@@ -48,8 +58,8 @@ class AbstractMarginFunction(object):
         gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
         if ax is None:
             ax = plt.gca()
-        if self.dot_display:
-            ax.plot(linspace, grid[:, gev_param_idx], 'o')
+        if self.datapoint_display:
+            ax.plot(linspace, grid[:, gev_param_idx], self.datapoint_marker)
         else:
             ax.plot(linspace, grid[:, gev_param_idx])
 
@@ -66,15 +76,15 @@ class AbstractMarginFunction(object):
         imshow_method = ax.imshow
         imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
                       interpolation='nearest', cmap=cm.gist_rainbow)
+        # todo: add dot display in 2D
         if show:
             plt.show()
 
     def get_grid_1D(self, x):
         # TODO: to avoid getting the value several times, I could cache the results
-        if self.dot_display:
-            resolution = len(self.coordinates)
-            linspace = self.coordinates.coordinates_values()[:, 0]
-            print('dot display')
+        if self.datapoint_display:
+            linspace = self.coordinates.coordinates_values(self.spatio_temporal_split)[:, 0]
+            resolution = len(linspace)
         else:
             resolution = 100
             linspace = np.linspace(x.min(), x.max(), resolution)
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 4cf3fd6f..f32b2555 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -6,7 +6,7 @@ import numpy as np
 import pandas as pd
 from mpl_toolkits.mplot3d import Axes3D
 
-from spatio_temporal_dataset.dataset.spatio_temporal_split import s_split_from_ratio, TEST_SPLIT_STR, \
+from spatio_temporal_dataset.spatio_temporal_split import s_split_from_ratio, TEST_SPLIT_STR, \
     TRAIN_SPLIT_STR, train_ind_from_s_split, SpatialTemporalSplit
 
 
diff --git a/spatio_temporal_dataset/dataset/abstract_dataset.py b/spatio_temporal_dataset/dataset/abstract_dataset.py
index ca1feef9..6cbf1ae0 100644
--- a/spatio_temporal_dataset/dataset/abstract_dataset.py
+++ b/spatio_temporal_dataset/dataset/abstract_dataset.py
@@ -3,26 +3,26 @@ import numpy as np
 import os.path as op
 import pandas as pd
 
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer
-from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
+from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
 class AbstractDataset(object):
 
-    def __init__(self, temporal_observations: AbstractTemporalObservations, coordinates: AbstractCoordinates):
-        # is_same_index = temporal_observations.index == coordinates.index  # type: pd.Series
+    def __init__(self, observations: AbstractSpatioTemporalObservations, coordinates: AbstractCoordinates):
+        # is_same_index = spatio_temporal_observations.index == coordinates.index  # type: pd.Series
         # assert is_same_index.all()
-        self.temporal_observations = temporal_observations
+        self.observations = observations
         self.coordinates = coordinates
-        self.spatio_temporal_slicer = SpatioTemporalSlicer(coordinate_train_ind=self.coordinates.train_ind,
-                                                           observation_train_ind=self.temporal_observations.train_ind)
+        self.spatio_temporal_slicer = SpatioTemporalSlicer(coordinates_train_ind=self.coordinates.train_ind,
+                                                           observations_train_ind=self.observations.train_ind)
 
     @classmethod
     def from_csv(cls, csv_path: str):
         assert op.exists(csv_path)
         df = pd.read_csv(csv_path)
-        temporal_maxima = AbstractTemporalObservations.from_df(df)
+        temporal_maxima = AbstractSpatioTemporalObservations.from_df(df)
         coordinates = AbstractCoordinates.from_df(df)
         return cls(temporal_maxima, coordinates)
 
@@ -36,7 +36,7 @@ class AbstractDataset(object):
     def df_dataset(self) -> pd.DataFrame:
         # Merge dataframes with the maxima and with the coordinates
         # todo: maybe I should add the split from the temporal observations
-        return self.temporal_observations.df_maxima_gev.join(self.coordinates.df_merged)
+        return self.observations.df_maxima_gev.join(self.coordinates.df_merged)
 
     def df_coordinates(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
         return self.coordinates.df_coordinates(split=split)
@@ -46,10 +46,10 @@ class AbstractDataset(object):
         return self.coordinates.coordinates_values(split=split)
 
     def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
-        return self.temporal_observations.maxima_gev(split, self.spatio_temporal_slicer)
+        return self.observations.maxima_gev(split, self.spatio_temporal_slicer)
 
     def maxima_frech(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
-        return self.temporal_observations.maxima_frech(split, self.spatio_temporal_slicer)
+        return self.observations.maxima_frech(split, self.spatio_temporal_slicer)
 
     def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
-        self.temporal_observations.set_maxima_frech(maxima_frech_values, split, self.spatio_temporal_slicer)
\ No newline at end of file
+        self.observations.set_maxima_frech(maxima_frech_values, split, self.spatio_temporal_slicer)
\ No newline at end of file
diff --git a/spatio_temporal_dataset/dataset/simulation_dataset.py b/spatio_temporal_dataset/dataset/simulation_dataset.py
index c7140a7e..780d42b1 100644
--- a/spatio_temporal_dataset/dataset/simulation_dataset.py
+++ b/spatio_temporal_dataset/dataset/simulation_dataset.py
@@ -2,8 +2,8 @@ from extreme_estimator.extreme_models.margin_model.abstract_margin_model import
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
-from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
-from spatio_temporal_dataset.temporal_observations.annual_maxima_observations import \
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
+from spatio_temporal_dataset.spatio_temporal_observations.annual_maxima_observations import \
     MaxStableAnnualMaxima, AnnualMaxima, MarginAnnualMaxima, FullAnnualMaxima
 
 
@@ -13,11 +13,11 @@ class SimulatedDataset(AbstractDataset):
         -the max_stable_model AND/OR marginal_model that was used for sampling
     """
 
-    def __init__(self, temporal_observations: AbstractTemporalObservations,
+    def __init__(self, observations: AbstractSpatioTemporalObservations,
                  coordinates: AbstractCoordinates,
                  max_stable_model: AbstractMaxStableModel = None,
                  margin_model: AbstractMarginModel = None):
-        super().__init__(temporal_observations, coordinates)
+        super().__init__(observations, coordinates)
         assert margin_model is not None or max_stable_model is not None
         self.margin_model = margin_model  # type: AbstractMarginModel
         self.max_stable_model = max_stable_model  #  type: AbstractMaxStableModel
@@ -28,8 +28,8 @@ class MaxStableDataset(SimulatedDataset):
     @classmethod
     def from_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel, coordinates: AbstractCoordinates,
                       train_split_ratio: float = None):
-        temporal_obs = MaxStableAnnualMaxima.from_sampling(nb_obs, max_stable_model, coordinates, train_split_ratio)
-        return cls(temporal_observations=temporal_obs, coordinates=coordinates, max_stable_model=max_stable_model)
+        observations = MaxStableAnnualMaxima.from_sampling(nb_obs, max_stable_model, coordinates, train_split_ratio)
+        return cls(observations=observations, coordinates=coordinates, max_stable_model=max_stable_model)
 
 
 class MarginDataset(SimulatedDataset):
@@ -37,8 +37,8 @@ class MarginDataset(SimulatedDataset):
     @classmethod
     def from_sampling(cls, nb_obs: int, margin_model: AbstractMarginModel, coordinates: AbstractCoordinates,
                       train_split_ratio: float = None):
-        temporal_obs = MarginAnnualMaxima.from_sampling(nb_obs, coordinates, margin_model, train_split_ratio)
-        return cls(temporal_observations=temporal_obs, coordinates=coordinates, margin_model=margin_model)
+        observations = MarginAnnualMaxima.from_sampling(nb_obs, coordinates, margin_model, train_split_ratio)
+        return cls(observations=observations, coordinates=coordinates, margin_model=margin_model)
 
 
 class FullSimulatedDataset(SimulatedDataset):
@@ -48,7 +48,7 @@ class FullSimulatedDataset(SimulatedDataset):
                              coordinates: AbstractCoordinates,
                              margin_model: AbstractMarginModel,
                              train_split_ratio: float = None):
-        temporal_obs = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
+        observations = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
                                                              coordinates, margin_model, train_split_ratio)
-        return cls(temporal_observations=temporal_obs, coordinates=coordinates, max_stable_model=max_stable_model,
+        return cls(observations=observations, coordinates=coordinates, max_stable_model=max_stable_model,
                    margin_model=margin_model)
diff --git a/spatio_temporal_dataset/temporal_observations/__init__.py b/spatio_temporal_dataset/spatio_temporal_observations/__init__.py
similarity index 100%
rename from spatio_temporal_dataset/temporal_observations/__init__.py
rename to spatio_temporal_dataset/spatio_temporal_observations/__init__.py
diff --git a/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py b/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
similarity index 72%
rename from spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
rename to spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
index 86926ce1..13931677 100644
--- a/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
+++ b/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
@@ -1,11 +1,11 @@
 import pandas as pd
 import numpy as np
 
-from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer, \
-    train_ind_from_s_split, TEST_SPLIT_STR, TRAIN_SPLIT_STR, s_split_from_ratio
+from spatio_temporal_dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer, \
+    train_ind_from_s_split, TEST_SPLIT_STR, TRAIN_SPLIT_STR, s_split_from_ratio, spatio_temporal_slice
 
 
-class AbstractTemporalObservations(object):
+class AbstractSpatioTemporalObservations(object):
 
     def __init__(self, df_maxima_frech: pd.DataFrame = None, df_maxima_gev: pd.DataFrame = None,
                  s_split: pd.Series = None, train_split_ratio: float = None):
@@ -37,24 +37,15 @@ class AbstractTemporalObservations(object):
     def from_df(cls, df):
         pass
 
-    @staticmethod
-    def df_maxima(df: pd.DataFrame, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
-                  slicer: SpatioTemporalSlicer = None) -> pd.DataFrame:
-        if slicer is None:
-            assert split is SpatialTemporalSplit.all
-            return df
-        else:
-            return slicer.loc_split(df, split)
-
     def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all, slicer: SpatioTemporalSlicer = None):
-        return self.df_maxima(self.df_maxima_gev, split, slicer).values
+        return spatio_temporal_slice(self.df_maxima_gev, split, slicer).values
 
     def maxima_frech(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all, slicer: SpatioTemporalSlicer = None):
-        return self.df_maxima(self.df_maxima_frech, split, slicer).values
+        return spatio_temporal_slice(self.df_maxima_frech, split, slicer).values
 
     def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
                          slicer: SpatioTemporalSlicer = None):
-        df = self.df_maxima(self.df_maxima_frech, split, slicer)
+        df = spatio_temporal_slice(self.df_maxima_frech, split, slicer)
         df.loc[:] = maxima_frech_values
 
     @property
diff --git a/spatio_temporal_dataset/spatio_temporal_observations/alps_precipitation_observations.py b/spatio_temporal_dataset/spatio_temporal_observations/alps_precipitation_observations.py
new file mode 100644
index 00000000..4311bc5f
--- /dev/null
+++ b/spatio_temporal_dataset/spatio_temporal_observations/alps_precipitation_observations.py
@@ -0,0 +1,5 @@
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
+
+
+class AlpsPrecipitationObservations(AbstractSpatioTemporalObservations):
+    pass
\ No newline at end of file
diff --git a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py b/spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py
similarity index 90%
rename from spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
rename to spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py
index ecb86156..bd751dba 100644
--- a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
+++ b/spatio_temporal_dataset/spatio_temporal_observations/annual_maxima_observations.py
@@ -3,11 +3,10 @@ import pandas as pd
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import AbstractMaxStableModel
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
-from spatio_temporal_dataset.dataset.spatio_temporal_split import s_split_from_ratio
-from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
 
 
-class AnnualMaxima(AbstractTemporalObservations):
+class AnnualMaxima(AbstractSpatioTemporalObservations):
     """
     Index are stations index
     Columns are the annual of the maxima
diff --git a/spatio_temporal_dataset/dataset/spatio_temporal_split.py b/spatio_temporal_dataset/spatio_temporal_split.py
similarity index 71%
rename from spatio_temporal_dataset/dataset/spatio_temporal_split.py
rename to spatio_temporal_dataset/spatio_temporal_split.py
index 3d7d210a..3037b749 100644
--- a/spatio_temporal_dataset/dataset/spatio_temporal_split.py
+++ b/spatio_temporal_dataset/spatio_temporal_split.py
@@ -13,14 +13,22 @@ class SpatialTemporalSplit(Enum):
 
 class SpatioTemporalSlicer(object):
 
-    def __init__(self, coordinate_train_ind: pd.Series, observation_train_ind: pd.Series):
-        self.index_train_ind = coordinate_train_ind  # type: pd.Series
-        self.column_train_ind = observation_train_ind  # type: pd.Series
+    def __init__(self, coordinates_train_ind: pd.Series, observations_train_ind: pd.Series):
+        self.index_train_ind = coordinates_train_ind  # type: pd.Series
+        self.column_train_ind = observations_train_ind  # type: pd.Series
         if self.ind_are_not_defined:
             msg = "One split was not defined \n \n" \
                   "index: \n {}  \n, column:\n {} \n".format(self.index_train_ind, self.column_train_ind)
             assert self.index_train_ind is None and self.column_train_ind is None, msg
 
+    def summary(self):
+        print('SpatioTemporal split summary: \n')
+        for s, global_name in [(self.index_train_ind, "Spatial"), (self.column_train_ind, "Temporal")]:
+            print(global_name + ' split')
+            for f, name in [(len, 'Total'), (sum, 'train')]:
+                print("{}: {}".format(name, f(s)))
+            print('\n')
+
     @property
     def index_test_ind(self) -> pd.Series:
         return ~self.index_train_ind
@@ -56,11 +64,6 @@ TEST_SPLIT_STR = 'test_split'
 
 
 def train_ind_from_s_split(s_split):
-    """
-
-    :param s_split:
-    :return:
-    """
     if s_split is None:
         return None
     else:
@@ -74,3 +77,12 @@ def s_split_from_ratio(length, train_split_ratio):
     train_ind = pd.Series.sample(s, n=nb_points_train).index
     s.loc[train_ind] = TRAIN_SPLIT_STR
     return s
+
+
+def spatio_temporal_slice(df: pd.DataFrame, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
+                          slicer: SpatioTemporalSlicer = None) -> pd.DataFrame:
+    if slicer is None:
+        assert split is SpatialTemporalSplit.all
+        return df
+    else:
+        return slicer.loc_split(df, split)
diff --git a/spatio_temporal_dataset/temporal_observations/alps_precipitation_observations.py b/spatio_temporal_dataset/temporal_observations/alps_precipitation_observations.py
deleted file mode 100644
index 80d6a8e0..00000000
--- a/spatio_temporal_dataset/temporal_observations/alps_precipitation_observations.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
-
-
-class AlpsPrecipitationObservations(AbstractTemporalObservations):
-    pass
\ No newline at end of file
diff --git a/test/test_spatio_temporal_dataset/test_temporal_observations.py b/test/test_spatio_temporal_dataset/test_temporal_observations.py
index 204b9606..f330476f 100644
--- a/test/test_spatio_temporal_dataset/test_temporal_observations.py
+++ b/test/test_spatio_temporal_dataset/test_temporal_observations.py
@@ -3,7 +3,7 @@ import numpy as np
 
 import pandas as pd
 
-from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import AbstractSpatioTemporalObservations
 
 
 class TestTemporalObservations(unittest.TestCase):
@@ -11,7 +11,7 @@ class TestTemporalObservations(unittest.TestCase):
 
     def test_set_maxima_gev(self):
         df = pd.DataFrame.from_dict({'ok': [2, 5]})
-        temporal_observation = AbstractTemporalObservations(df_maxima_frech=df)
+        temporal_observation = AbstractSpatioTemporalObservations(df_maxima_frech=df)
         example = np.array([[3], [6]])
         temporal_observation.set_maxima_frech(maxima_frech_values=example)
         maxima_frech = temporal_observation.maxima_frech()
diff --git a/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py b/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py
index 3d522b81..dd803e54 100644
--- a/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py
+++ b/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py
@@ -7,7 +7,7 @@ from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model
 from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Schlather
 from extreme_estimator.extreme_models.utils import r
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
-from spatio_temporal_dataset.temporal_observations.annual_maxima_observations import MaxStableAnnualMaxima
+from spatio_temporal_dataset.spatio_temporal_observations.annual_maxima_observations import MaxStableAnnualMaxima
 from test.test_unitary.test_unitary_abstract import TestUnitaryAbstract
 
 
-- 
GitLab