From d61e48a92d184fad21bd780748361d570a8a3794 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 15 Apr 2020 18:11:48 +0200
Subject: [PATCH] [contrasting project] add test for the spatio temporal
 coordinates for several splits

---
 .../altitunal_fit/altitudes_studies.py        | 37 ++++++++++++-------
 .../coordinates/abstract_coordinates.py       | 15 ++++++--
 .../test_altitudes_studies.py                 | 33 +++++++++++++----
 3 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py b/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py
index 30839464..aad06c85 100644
--- a/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py
+++ b/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py
@@ -4,6 +4,7 @@ from collections import OrderedDict
 from cached_property import cached_property
 
 from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
     AbstractSpatialCoordinates
 from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.abstract_spatio_temporal_coordinates import \
@@ -15,8 +16,11 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 class AltitudesStudies(object):
 
-    def __init__(self, study_class, altitudes, transformation_class=None, **kwargs_study):
-        self.transformation_class = transformation_class
+    def __init__(self, study_class, altitudes,
+                 spatial_transformation_class=None, temporal_transformation_class=None,
+                 **kwargs_study):
+        self.spatial_transformation_class = spatial_transformation_class
+        self.temporal_transformation_class = temporal_transformation_class
         self.altitudes = altitudes
         self.altitude_to_study = OrderedDict()
         for altitude in self.altitudes:
@@ -36,23 +40,30 @@ class AltitudesStudies(object):
     def temporal_coordinates(self):
         return ConsecutiveTemporalCoordinates.from_nb_temporal_steps(nb_temporal_steps=self.study.nb_years,
                                                                      start=self.study.year_min,
-                                                                     transformation_class=self.transformation_class)
+                                                                     transformation_class=self.spatial_transformation_class)
 
     @cached_property
     def spatial_coordinates(self):
         return AbstractSpatialCoordinates.from_list_x_coordinates(x_coordinates=self.altitudes,
-                                                                  transformation_class=self.transformation_class)
+                                                                  transformation_class=self.temporal_transformation_class)
+
+    @cached_property
+    def _df_coordinates(self):
+        return AbstractSpatioTemporalCoordinates.get_df_from_spatial_and_temporal_coordinates(self.spatial_coordinates,
+                                                                                              self.temporal_coordinates)
+
+    def random_s_split_spatial(self, train_split_ratio):
+        return AbstractCoordinates.spatial_s_split_from_df(self._df_coordinates, train_split_ratio)
 
     def random_s_split_temporal(self, train_split_ratio):
-        return AbstractSpatioTemporalCoordinates.get_random_s_split_temporal(
-            spatial_coordinates=self.spatial_coordinates,
-            temporal_coordinates=self.temporal_coordinates,
-            train_split_ratio=train_split_ratio)
-
-    def spatio_temporal_coordinates(self, slicer_class: type, s_split_spatial: pd.Series = None,
-                                    s_split_temporal: pd.Series = None):
-        return AbstractSpatioTemporalCoordinates(slicer_class=slicer_class, s_split_spatial=s_split_spatial,
+        return AbstractCoordinates.temporal_s_split_from_df(self._df_coordinates, train_split_ratio)
+
+    def spatio_temporal_coordinates(self, s_split_spatial: pd.Series = None, s_split_temporal: pd.Series = None):
+        slicer_class = AbstractCoordinates.slicer_class_from_s_splits(s_split_spatial=s_split_spatial,
+                                                                      s_split_temporal=s_split_temporal)
+        return AbstractSpatioTemporalCoordinates(slicer_class=slicer_class,
+                                                 s_split_spatial=s_split_spatial,
                                                  s_split_temporal=s_split_temporal,
-                                                 transformation_class=self.transformation_class,
+                                                 transformation_class=self.spatial_transformation_class,
                                                  spatial_coordinates=self.spatial_coordinates,
                                                  temporal_coordinates=self.temporal_coordinates)
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index ae69b412..7fde3c4a 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -87,6 +87,12 @@ class AbstractCoordinates(object):
         s_split_spatial = df[cls.SPATIAL_SPLIT].copy() if cls.SPATIAL_SPLIT in df.columns else None
         s_split_temporal = df[cls.TEMPORAL_SPLIT].copy() if cls.TEMPORAL_SPLIT in df.columns else None
 
+        slicer_class = cls.slicer_class_from_s_splits(s_split_spatial, s_split_temporal)
+
+        return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal)
+
+    @classmethod
+    def slicer_class_from_s_splits(cls, s_split_spatial, s_split_temporal):
         # Infer the slicer class
         if s_split_temporal is None and s_split_spatial is None:
             raise ValueError('Both split are unspecified')
@@ -96,8 +102,7 @@ class AbstractCoordinates(object):
             slicer_class = SpatialSlicer
         else:
             slicer_class = SpatioTemporalSlicer
-
-        return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal)
+        return slicer_class
 
     @classmethod
     def from_df_and_slicer(cls, df: pd.DataFrame, slicer_class: type, train_split_ratio: float = None,
@@ -106,13 +111,17 @@ class AbstractCoordinates(object):
         assert len(set(df.index)) == len(df), 'df indices are not unique'
 
         # Create a spatial split
-        s_split_spatial = s_split_from_df(df, cls.COORDINATE_X, cls.SPATIAL_SPLIT, train_split_ratio, True)
+        s_split_spatial = cls.spatial_s_split_from_df(df, train_split_ratio)
         # Create a temporal split
         s_split_temporal = cls.temporal_s_split_from_df(df, train_split_ratio)
 
         return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal,
                    transformation_class=transformation_class)
 
+    @classmethod
+    def spatial_s_split_from_df(cls, df, train_split_ratio):
+        return s_split_from_df(df, cls.COORDINATE_X, cls.SPATIAL_SPLIT, train_split_ratio, True)
+
     @classmethod
     def temporal_s_split_from_df(cls, df, train_split_ratio):
         return s_split_from_df(df, cls.COORDINATE_T, cls.TEMPORAL_SPLIT, train_split_ratio, False)
diff --git a/test/test_projects/test_contrasting/test_altitudes_studies.py b/test/test_projects/test_contrasting/test_altitudes_studies.py
index be3fe72b..b943fa66 100644
--- a/test/test_projects/test_contrasting/test_altitudes_studies.py
+++ b/test/test_projects/test_contrasting/test_altitudes_studies.py
@@ -2,6 +2,7 @@ import unittest
 
 from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
 from projects.contrasting_trends_in_snow_loads.altitunal_fit.altitudes_studies import AltitudesStudies
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.slicer.split import Split, small_s_split_from_ratio
 from spatio_temporal_dataset.slicer.temporal_slicer import TemporalSlicer
 import pandas as pd
@@ -15,15 +16,33 @@ class TestAltitudesStudies(unittest.TestCase):
         study_class = SafranSnowfall1Day
         self.studies = AltitudesStudies(study_class, altitudes, year_min=1959, year_max=1962)
 
-    def test_spatio_temporal_coordinates_with_temporal_split(self):
+
+class TestSpatioTemporalCoordinates(TestAltitudesStudies):
+
+    def test_temporal_split(self):
         s_split_temporal = self.studies.random_s_split_temporal(train_split_ratio=0.75)
-        coordinates = self.studies.spatio_temporal_coordinates(slicer_class=TemporalSlicer,
+        coordinates = self.studies.spatio_temporal_coordinates(s_split_temporal=s_split_temporal)
+        self.assertEqual(coordinates.coordinates_values(split=Split.train_temporal).shape, (6, 2))
+        self.assertEqual(coordinates.coordinates_values(split=Split.test_temporal).shape, (2, 2))
+
+    def test_spatial_split(self):
+        s_split_spatial = self.studies.random_s_split_spatial(train_split_ratio=0.5)
+        coordinates = self.studies.spatio_temporal_coordinates(s_split_spatial=s_split_spatial)
+        self.assertEqual(coordinates.coordinates_values(split=Split.train_spatial).shape, (4, 2))
+        self.assertEqual(coordinates.coordinates_values(split=Split.test_spatial).shape, (4, 2))
+
+    def test_spatio_temporal_split(self):
+        s_split_spatial = self.studies.random_s_split_spatial(train_split_ratio=0.5)
+        s_split_temporal = self.studies.random_s_split_temporal(train_split_ratio=0.75)
+        coordinates = self.studies.spatio_temporal_coordinates(s_split_spatial=s_split_spatial,
                                                                s_split_temporal=s_split_temporal)
-        train_values = coordinates.coordinates_values(split=Split.train_temporal)
-        self.assertEqual(train_values.shape, (6, 2))
-        test_values = coordinates.coordinates_values(split=Split.test_temporal)
-        self.assertEqual(test_values.shape, (2, 2))
+        self.assertEqual(coordinates.coordinates_values(split=Split.train_spatiotemporal).shape, (3, 2))
+        self.assertEqual(coordinates.coordinates_values(split=Split.test_spatiotemporal_spatial).shape, (3, 2))
+        self.assertEqual(coordinates.coordinates_values(split=Split.test_spatiotemporal_temporal).shape, (1, 2))
+        self.assertEqual(coordinates.coordinates_values(split=Split.test_spatiotemporal).shape, (1, 2))
+
+
 
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
-- 
GitLab