From 5ebb6469b93831984ad7890912e2757a35c29d3c Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 29 Nov 2018 13:53:29 +0100
Subject: [PATCH] [DATASET] add spatio_temporal_split. add test. refactor code
 accordingly.

---
 extreme_estimator/estimator/full_estimator.py |  8 ++-
 .../estimator/margin_estimator.py             |  9 ++-
 .../estimator/max_stable_estimator.py         |  6 +-
 .../coordinates/abstract_coordinates.py       | 67 ++++++++++---------
 .../dataset/abstract_dataset.py               | 20 +++---
 .../dataset/spatio_temporal_split.py          | 48 +++++++++++++
 .../abstract_temporal_observations.py         | 55 +++++++++------
 .../annual_maxima_observations.py             |  2 +-
 .../test_temporal_observations.py             | 13 ++++
 .../test_rmaxstab_with_margin.py              |  4 +-
 .../test_rmaxstab_without_margin.py           |  2 +-
 11 files changed, 158 insertions(+), 76 deletions(-)
 create mode 100644 spatio_temporal_dataset/dataset/spatio_temporal_split.py

diff --git a/extreme_estimator/estimator/full_estimator.py b/extreme_estimator/estimator/full_estimator.py
index 20eda73f..aac5229c 100644
--- a/extreme_estimator/estimator/full_estimator.py
+++ b/extreme_estimator/estimator/full_estimator.py
@@ -8,6 +8,7 @@ from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
 from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
 from extreme_estimator.estimator.max_stable_estimator import MaxStableEstimator
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractFullEstimator(AbstractEstimator):
@@ -41,11 +42,12 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
         # Estimate the margin parameters
         self.margin_estimator.fit()
         # Compute the maxima_frech
-        maxima_frech = AbstractMarginModel.gev2frech(maxima_gev=self.dataset.maxima_gev,
+        maxima_gev_train = self.dataset.maxima_gev(split=SpatialTemporalSplit.train)
+        maxima_frech = AbstractMarginModel.gev2frech(maxima_gev=maxima_gev_train,
                                                      coordinates_values=self.dataset.coordinates_values,
                                                      margin_function=self.margin_estimator.margin_function_fitted)
         # Update maxima frech field through the dataset object
-        self.dataset.maxima_frech = maxima_frech
+        self.dataset.set_maxima_frech(maxima_frech, split=SpatialTemporalSplit.train)
         # Estimate the max stable parameters
         self.max_stable_estimator.fit()
 
@@ -68,7 +70,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
     def _fit(self):
         # Estimate both the margin and the max-stable structure
         self.full_params_fitted = self.max_stable_model.fitmaxstab(
-            maxima_gev=self.dataset.maxima_gev,
+            maxima_gev=self.dataset.maxima_gev(split=SpatialTemporalSplit.train),
             df_coordinates=self.dataset.df_coordinates,
             fit_marge=True,
             fit_marge_form_dict=self.linear_margin_function_to_fit.form_dict,
diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py
index 4453d868..6941fecf 100644
--- a/extreme_estimator/estimator/margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator.py
@@ -4,13 +4,14 @@ from extreme_estimator.extreme_models.margin_model.margin_function.abstract_marg
     AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
 
 
 class AbstractMarginEstimator(AbstractEstimator):
 
     def __init__(self, dataset: AbstractDataset):
         super().__init__(dataset)
-        assert self.dataset.maxima_gev is not None
+        assert self.dataset.maxima_gev() is not None
         self._margin_function_fitted = None
 
     @property
@@ -32,5 +33,7 @@ class SmoothMarginEstimator(AbstractMarginEstimator):
         self.margin_model = margin_model
 
     def _fit(self):
-        self._margin_function_fitted = self.margin_model.fitmargin_from_maxima_gev(maxima_gev=self.dataset.maxima_gev,
-                                                                                   coordinates_values=self.dataset.coordinates_values)
+        maxima_gev = self.dataset.maxima_gev(split=SpatialTemporalSplit.train)
+        corodinate_values = self.dataset.coordinates_values
+        self._margin_function_fitted = self.margin_model.fitmargin_from_maxima_gev(maxima_gev=maxima_gev,
+                                                                                   coordinates_values=corodinate_values)
diff --git a/extreme_estimator/estimator/max_stable_estimator.py b/extreme_estimator/estimator/max_stable_estimator.py
index d83809a3..dff88648 100644
--- a/extreme_estimator/estimator/max_stable_estimator.py
+++ b/extreme_estimator/estimator/max_stable_estimator.py
@@ -3,6 +3,8 @@ from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 import numpy as np
 
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit
+
 
 class AbstractMaxStableEstimator(AbstractEstimator):
 
@@ -16,9 +18,9 @@ class AbstractMaxStableEstimator(AbstractEstimator):
 class MaxStableEstimator(AbstractMaxStableEstimator):
 
     def _fit(self):
-        assert self.dataset.maxima_frech is not None
+        assert self.dataset.maxima_frech(split=SpatialTemporalSplit.train) is not None
         self.max_stable_params_fitted = self.max_stable_model.fitmaxstab(
-            maxima_frech=self.dataset.maxima_frech,
+            maxima_frech=self.dataset.maxima_frech(split=SpatialTemporalSplit.train),
             df_coordinates=self.dataset.coordinates.df_coordinates)
 
     def _error(self, true_max_stable_params: dict):
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 68f997d6..dc18803b 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -13,14 +13,16 @@ class AbstractCoordinates(object):
     COORDINATE_Y = 'coord_y'
     COORDINATE_Z = 'coord_z'
     COORDINATE_NAMES = [COORDINATE_X, COORDINATE_Y, COORDINATE_Z]
-    COORD_SPLIT = 'coord_split'
-    # Constants
+    COORDINATE_SPLIT = 'coord_split'
+    # Constants for the split column
     TRAIN_SPLIT_STR = 'train_split'
     TEST_SPLIT_STR = 'test_split'
 
     def __init__(self, df_coordinates: pd.DataFrame, s_split: pd.Series = None):
-        self.df_coordinates = df_coordinates
-        self.s_split = s_split
+        self.df_coordinates = df_coordinates  # type: pd.DataFrame
+        self.s_split = s_split  # type: pd.Series
+
+    # ClassMethod constructor
 
     @classmethod
     def from_df(cls, df: pd.DataFrame):
@@ -28,9 +30,30 @@ class AbstractCoordinates(object):
         assert cls.COORDINATE_X in df.columns
         df_coordinates = df.loc[:, cls.coordinates_columns(df)]
         # Potentially, a split column can be specified
-        s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None
+        s_split = df[cls.COORDINATE_SPLIT] if cls.COORDINATE_SPLIT in df.columns else None
+        if s_split is not None:
+            assert s_split.isin([cls.TRAIN_SPLIT_STR, cls.TEST_SPLIT_STR])
         return cls(df_coordinates=df_coordinates, s_split=s_split)
 
+    @classmethod
+    def from_csv(cls, csv_path: str = None):
+        assert csv_path is not None
+        assert op.exists(csv_path)
+        df = pd.read_csv(csv_path)
+        return cls.from_df(df)
+
+    @classmethod
+    def from_nb_points(cls, nb_points: int, **kwargs):
+        # Call the default class method from csv
+        coordinates = cls.from_csv()  # type: AbstractCoordinates
+        # Sample randomly nb_points coordinates
+        nb_coordinates = len(coordinates)
+        if nb_points > nb_coordinates:
+            raise Exception('Nb coordinates in csv: {} < Nb points desired: {}'.format(nb_coordinates, nb_points))
+        else:
+            df_sample = pd.DataFrame.sample(coordinates.df, n=nb_points)
+            return cls.from_df(df=df_sample)
+
     @classmethod
     def coordinates_columns(cls, df_coord: pd.DataFrame) -> List[str]:
         coord_columns = [cls.COORDINATE_X]
@@ -52,25 +75,6 @@ class AbstractCoordinates(object):
         # Merged DataFrame of df_coord and s_split
         return self.df_coordinates if self.s_split is None else self.df_coordinates.join(self.s_split)
 
-    @classmethod
-    def from_csv(cls, csv_path: str = None):
-        assert csv_path is not None
-        assert op.exists(csv_path)
-        df = pd.read_csv(csv_path)
-        return cls.from_df(df)
-
-    @classmethod
-    def from_nb_points(cls, nb_points: int, **kwargs):
-        # Call the default class method from csv
-        coordinates = cls.from_csv()  # type: AbstractCoordinates
-        # Sample randomly nb_points coordinates
-        nb_coordinates = len(coordinates)
-        if nb_points > nb_coordinates:
-            raise Exception('Nb coordinates in csv: {} < Nb points desired: {}'.format(nb_coordinates, nb_points))
-        else:
-            df_sample = pd.DataFrame.sample(coordinates.df, n=nb_points)
-            return cls.from_df(df=df_sample)
-
     def df_coordinates_split(self, split_str: str) -> pd.DataFrame:
         assert self.s_split is not None
         ind = self.s_split == split_str
@@ -92,16 +96,15 @@ class AbstractCoordinates(object):
         return self.df_coordinates.loc[:, self.COORDINATE_Y].values.copy()
 
     @property
-    def coordinates_train(self) -> np.ndarray:
-        return self._coordinates_values(df_coordinates=self.df_coordinates_split(self.TRAIN_SPLIT_STR))
-
-    @property
-    def coordinates_test(self) -> np.ndarray:
-        return self._coordinates_values(df_coordinates=self.df_coordinates_split(self.TEST_SPLIT_STR))
+    def index(self) -> pd.Series:
+        return self.df_coordinates.index
 
     @property
-    def index(self):
-        return self.df_coordinates.index
+    def train_ind(self) -> pd.Series:
+        if self.s_split is None:
+            return None
+        else:
+            return self.s_split.isin([self.TRAIN_SPLIT_STR])
 
     #  Visualization
 
diff --git a/spatio_temporal_dataset/dataset/abstract_dataset.py b/spatio_temporal_dataset/dataset/abstract_dataset.py
index f42453ef..f17d9610 100644
--- a/spatio_temporal_dataset/dataset/abstract_dataset.py
+++ b/spatio_temporal_dataset/dataset/abstract_dataset.py
@@ -2,6 +2,8 @@ import os
 import numpy as np
 import os.path as op
 import pandas as pd
+
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer
 from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
@@ -13,6 +15,8 @@ class AbstractDataset(object):
         # assert is_same_index.all()
         self.temporal_observations = temporal_observations
         self.coordinates = coordinates
+        self.spatio_temporal_slicer = SpatioTemporalSlicer(coordinate_train_ind=self.coordinates.train_ind,
+                                                           observation_train_ind=self.temporal_observations.train_ind)
 
     @classmethod
     def from_csv(cls, csv_path: str):
@@ -41,15 +45,11 @@ class AbstractDataset(object):
     def coordinates_values(self):
         return self.coordinates.coordinates_values
 
-    @property
-    def maxima_gev(self) -> np.ndarray:
-        return self.temporal_observations.maxima_gev
-
-    @property
-    def maxima_frech(self):
-        return self.temporal_observations.maxima_frech
+    def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
+        return self.temporal_observations.maxima_gev(split, self.spatio_temporal_slicer)
 
-    @maxima_frech.setter
-    def maxima_frech(self, maxima_frech_to_set):
-        self.temporal_observations.maxima_frech = maxima_frech_to_set
+    def maxima_frech(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all) -> np.ndarray:
+        return self.temporal_observations.maxima_frech(split, self.spatio_temporal_slicer)
 
+    def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: SpatialTemporalSplit = SpatialTemporalSplit.all):
+        self.temporal_observations.set_maxima_frech(maxima_frech_values, split, self.spatio_temporal_slicer)
\ No newline at end of file
diff --git a/spatio_temporal_dataset/dataset/spatio_temporal_split.py b/spatio_temporal_dataset/dataset/spatio_temporal_split.py
new file mode 100644
index 00000000..599e656d
--- /dev/null
+++ b/spatio_temporal_dataset/dataset/spatio_temporal_split.py
@@ -0,0 +1,61 @@
+from enum import Enum
+
+import pandas as pd
+
+
+class SpatialTemporalSplit(Enum):
+    """Subset of a (stations x time steps) observation table to select."""
+    all = 0  # every row and every column
+    train = 1  # train rows, train columns
+    test = 2  # test rows, test columns
+    test_temporal = 3  # train rows, test columns
+    test_spatial = 4  # test rows, train columns
+
+
+class SpatioTemporalSlicer(object):
+    """Select the part of a DataFrame matching a SpatialTemporalSplit.
+
+    Rows are selected with the coordinate indicator, columns with the observation indicator.
+    """
+
+    def __init__(self, coordinate_train_ind: pd.Series, observation_train_ind: pd.Series):
+        # Boolean indicators (True for train); both must be None when no split is defined
+        self.index_train_ind = coordinate_train_ind  # type: pd.Series
+        self.column_train_ind = observation_train_ind  # type: pd.Series
+        if self.ind_are_not_defined:
+            assert self.index_train_ind is None and self.column_train_ind is None, "One split was not defined"
+
+    @property
+    def index_test_ind(self) -> pd.Series:
+        # Complement of the train rows; only valid when the indicators are defined
+        return ~self.index_train_ind
+
+    @property
+    def column_test_ind(self) -> pd.Series:
+        # Complement of the train columns; only valid when the indicators are defined
+        return ~self.column_train_ind
+
+    @property
+    def ind_are_not_defined(self):
+        return self.index_train_ind is None or self.column_train_ind is None
+
+    def loc_split(self, df: pd.DataFrame, split: SpatialTemporalSplit):
+        """Return the rows and columns of df that belong to split.
+
+        df itself is returned when split is `all` or when no indicators are defined.
+        """
+        assert isinstance(split, SpatialTemporalSplit)
+        # By default, if one of the two split is not defined we return all the data
+        if self.ind_are_not_defined or split is SpatialTemporalSplit.all:
+            return df
+        # Index.equals gives a single bool; `==` is element-wise and is ambiguous inside an assert
+        assert df.columns.equals(self.column_train_ind.index)
+        assert df.index.equals(self.index_train_ind.index)
+        if split is SpatialTemporalSplit.train:
+            return df.loc[self.index_train_ind, self.column_train_ind]
+        elif split is SpatialTemporalSplit.test:
+            return df.loc[self.index_test_ind, self.column_test_ind]
+        elif split is SpatialTemporalSplit.test_spatial:
+            return df.loc[self.index_test_ind, self.column_train_ind]
+        elif split is SpatialTemporalSplit.test_temporal:
+            return df.loc[self.index_train_ind, self.column_test_ind]
diff --git a/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py b/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
index 34e1796a..8a19a123 100644
--- a/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
+++ b/spatio_temporal_dataset/temporal_observations/abstract_temporal_observations.py
@@ -1,14 +1,25 @@
 import pandas as pd
+import numpy as np
+
+from spatio_temporal_dataset.dataset.spatio_temporal_split import SpatialTemporalSplit, SpatioTemporalSlicer
 
 
 class AbstractTemporalObservations(object):
 
-    def __init__(self, df_maxima_frech: pd.DataFrame = None, df_maxima_gev: pd.DataFrame = None):
+    # Constants for the split column
+    TRAIN_SPLIT_STR = 'train_split'
+    TEST_SPLIT_STR = 'test_split'
+
+    def __init__(self, df_maxima_frech: pd.DataFrame = None, df_maxima_gev: pd.DataFrame = None,
+                 s_split: pd.Series = None):
         """
         Main attribute of the class is the DataFrame df_maxima
         Index are stations index
         Columns are the temporal moment of the maxima
         """
+        if s_split is not None:
+            assert s_split.isin([self.TRAIN_SPLIT_STR, self.TEST_SPLIT_STR])
+        self.s_split = s_split
         self.df_maxima_frech = df_maxima_frech
         self.df_maxima_gev = df_maxima_gev
 
@@ -16,29 +27,29 @@ class AbstractTemporalObservations(object):
     def from_df(cls, df):
         pass
 
-    @property
-    def maxima_gev(self):
-        return self.df_maxima_gev.values
+    @staticmethod
+    def df_maxima(df: pd.DataFrame, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
+                  slicer: SpatioTemporalSlicer = None):
+        if slicer is None:
+            assert split is SpatialTemporalSplit.all
+            return df
+        else:
+            return slicer.loc_split(df, split)
 
-    @property
-    def maxima_frech(self):
-        return self.df_maxima_frech.values
+    def maxima_gev(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all, slicer: SpatioTemporalSlicer = None):
+        return self.df_maxima(self.df_maxima_gev, split, slicer).values
 
-    @maxima_frech.setter
-    def maxima_frech(self, maxima_frech_to_set):
-        assert maxima_frech_to_set is not None
-        assert maxima_frech_to_set.shape == self.maxima_gev.shape
-        self.df_maxima_frech = pd.DataFrame(data=maxima_frech_to_set,
-                                            index=self.df_maxima_gev.index,
-                                            columns=self.df_maxima_gev.columns)
+    def maxima_frech(self, split: SpatialTemporalSplit = SpatialTemporalSplit.all, slicer: SpatioTemporalSlicer = None):
+        return self.df_maxima(self.df_maxima_frech, split, slicer).values
 
-    @property
-    def column_to_time_index(self):
-        pass
+    def set_maxima_frech(self, maxima_frech_values: np.ndarray, split: SpatialTemporalSplit = SpatialTemporalSplit.all,
+                         slicer: SpatioTemporalSlicer = None):
+        df = self.df_maxima(self.df_maxima_frech, split, slicer)
+        df.loc[:] = maxima_frech_values
 
     @property
-    def index(self):
-        return self.df_maxima_gev.index
-
-
-
+    def train_ind(self) -> pd.Series:
+        if self.s_split is None:
+            return None
+        else:
+            return self.s_split.isin([self.TRAIN_SPLIT_STR])
diff --git a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py b/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
index b159bbba..a197687c 100644
--- a/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
+++ b/spatio_temporal_dataset/temporal_observations/annual_maxima_observations.py
@@ -40,7 +40,7 @@ class FullAnnualMaxima(MaxStableAnnualMaxima):
                              coordinates: AbstractCoordinates, margin_model: AbstractMarginModel):
         max_stable_annual_maxima = super().from_sampling(nb_obs, max_stable_model, coordinates)
         #  Compute df_maxima_gev from df_maxima_frech
-        maxima_gev = margin_model.rmargin_from_maxima_frech(maxima_frech=max_stable_annual_maxima.maxima_frech,
+        maxima_gev = margin_model.rmargin_from_maxima_frech(maxima_frech=max_stable_annual_maxima.maxima_frech(),
                                                             coordinates_values=coordinates.coordinates_values)
         max_stable_annual_maxima.df_maxima_gev = pd.DataFrame(data=maxima_gev, index=coordinates.index)
         return max_stable_annual_maxima
diff --git a/test/test_spatio_temporal_dataset/test_temporal_observations.py b/test/test_spatio_temporal_dataset/test_temporal_observations.py
index 3e0657da..204b9606 100644
--- a/test/test_spatio_temporal_dataset/test_temporal_observations.py
+++ b/test/test_spatio_temporal_dataset/test_temporal_observations.py
@@ -1,9 +1,22 @@
 import unittest
+import numpy as np
+
+import pandas as pd
+
+from spatio_temporal_dataset.temporal_observations.abstract_temporal_observations import AbstractTemporalObservations
 
 
 class TestTemporalObservations(unittest.TestCase):
     DISPLAY = False
 
+    def test_set_maxima_gev(self):
+        df = pd.DataFrame.from_dict({'ok': [2, 5]})
+        temporal_observation = AbstractTemporalObservations(df_maxima_frech=df)
+        example = np.array([[3], [6]])
+        temporal_observation.set_maxima_frech(maxima_frech_values=example)
+        maxima_frech = temporal_observation.maxima_frech()
+        self.assertTrue(np.equal(example, maxima_frech).all(), msg="{} {}".format(example, maxima_frech))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_unitary/test_rmaxstab/test_rmaxstab_with_margin.py b/test/test_unitary/test_rmaxstab/test_rmaxstab_with_margin.py
index fac7b393..1db26310 100644
--- a/test/test_unitary/test_rmaxstab/test_rmaxstab_with_margin.py
+++ b/test/test_unitary/test_rmaxstab/test_rmaxstab_with_margin.py
@@ -51,7 +51,7 @@ class TestRMaxStabWithMarginConstant(TestUnitaryAbstract):
     @property
     def python_output(self):
         dataset = self.python_code()
-        return np.sum(dataset.maxima_gev)
+        return np.sum(dataset.maxima_gev())
 
     def test_rmaxstab_with_constant_margin(self):
         self.compare()
@@ -96,7 +96,7 @@ class TestRMaxStabWithLinearMargin(TestUnitaryAbstract):
     @property
     def python_output(self):
         dataset = self.python_code()
-        return np.sum(dataset.maxima_gev)
+        return np.sum(dataset.maxima_gev())
 
     def test_rmaxstab_with_linear_margin(self):
         self.compare()
diff --git a/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py b/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py
index c97a0726..3d522b81 100644
--- a/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py
+++ b/test/test_unitary/test_rmaxstab/test_rmaxstab_without_margin.py
@@ -37,7 +37,7 @@ class TestRMaxStab(TestUnitaryAbstract):
         coordinates, max_stable_model = self.python_code()
         m = MaxStableAnnualMaxima.from_sampling(nb_obs=40, max_stable_model=max_stable_model, coordinates=coordinates)
         # TODO: understand why the array are not in the same order
-        return np.sum(m.maxima_frech)
+        return np.sum(m.maxima_frech())
 
     def test_rmaxstab(self):
         self.compare()
-- 
GitLab