From 354097c7590796478f500755e77b70b7faa1b97c Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 15 Apr 2020 19:49:00 +0200
Subject: [PATCH] [contrasting project] add test_two_fold_estimation.py

---
 .../altitunal_fit/two_fold_estimation.py      | 32 ++++++++++++++++--
 .../slicer/abstract_slicer.py                 |  2 +-
 spatio_temporal_dataset/slicer/split.py       |  7 ++++
 .../test_two_fold_estimation.py               | 33 +++++++++++++++++++
 .../test_slicer.py                            | 13 +++++++-
 5 files changed, 83 insertions(+), 4 deletions(-)
 create mode 100644 test/test_projects/test_contrasting/test_two_fold_estimation.py

diff --git a/projects/contrasting_trends_in_snow_loads/altitunal_fit/two_fold_estimation.py b/projects/contrasting_trends_in_snow_loads/altitunal_fit/two_fold_estimation.py
index 8ac122df..05b1ddf9 100644
--- a/projects/contrasting_trends_in_snow_loads/altitunal_fit/two_fold_estimation.py
+++ b/projects/contrasting_trends_in_snow_loads/altitunal_fit/two_fold_estimation.py
@@ -1,6 +1,34 @@
+from typing import Tuple, Dict, List
+
+from projects.contrasting_trends_in_snow_loads.altitunal_fit.altitudes_studies import AltitudesStudies
+from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.slicer.split import invert_s_split
+
 
 class TwoFoldEstimation(object):
 
+    def __init__(self, studies: AltitudesStudies, nb_samples):
+        self.studies = studies
+        self.nb_samples = nb_samples
+
+    @property
+    def massif_name_to_list_two_fold_datasets(self) -> Dict[str, List[Tuple[AbstractDataset, AbstractDataset]]]:
+        d = {}
+        for massif_name in self.studies.study.all_massif_names():
+            l = []
+            for _ in range(self.nb_samples):
+                # Append to the list
+                l.append(self.two_fold_datasets(massif_name))
+            d[massif_name] = l
+        return d
 
-    def __init__(self):
-        pass
+    def two_fold_datasets(self, massif_name: str) -> Tuple[AbstractDataset, AbstractDataset]:
+        # Create split for the 1st fold
+        s_split_temporal = self.studies.random_s_split_temporal(train_split_ratio=0.5)
+        dataset_fold_1 = self.studies.spatio_temporal_dataset(massif_name=massif_name,
+                                                              s_split_temporal=s_split_temporal)
+        # Invert the s_split for the 2nd fold
+        s_split_temporal_inverted = invert_s_split(s_split_temporal)
+        dataset_fold_2 = self.studies.spatio_temporal_dataset(massif_name=massif_name,
+                                                              s_split_temporal=s_split_temporal_inverted)
+        return dataset_fold_1, dataset_fold_2
diff --git a/spatio_temporal_dataset/slicer/abstract_slicer.py b/spatio_temporal_dataset/slicer/abstract_slicer.py
index a674a205..cb6d98a5 100644
--- a/spatio_temporal_dataset/slicer/abstract_slicer.py
+++ b/spatio_temporal_dataset/slicer/abstract_slicer.py
@@ -26,7 +26,7 @@ class AbstractSlicer(object):
         if split is Split.all:
             return df
 
-        assert split in self.splits, "split:{}, slicer_type:{}".format(split, type(self))
+        assert split in self.splits, "Split and slicer_type do not correspond:\nsplit:{}, slicer_type:{}".format(split, type(self))
 
         # By default, some required splits are not defined
         # instead of crashing, we return all the data for all the split
diff --git a/spatio_temporal_dataset/slicer/split.py b/spatio_temporal_dataset/slicer/split.py
index 8b2eea4f..584f8c4a 100644
--- a/spatio_temporal_dataset/slicer/split.py
+++ b/spatio_temporal_dataset/slicer/split.py
@@ -44,6 +44,13 @@ TRAIN_SPLIT_STR = 'train_split'
 TEST_SPLIT_STR = 'test_split'
 
 
+def invert_s_split(s_split):
+    ind = ind_train_from_s_split(s_split)
+    s_split.loc[ind] = TEST_SPLIT_STR
+    s_split.loc[~ind] = TRAIN_SPLIT_STR
+    return s_split
+
+
 def ind_train_from_s_split(s_split):
     if s_split is None:
         return None
diff --git a/test/test_projects/test_contrasting/test_two_fold_estimation.py b/test/test_projects/test_contrasting/test_two_fold_estimation.py
new file mode 100644
index 00000000..8af825f7
--- /dev/null
+++ b/test/test_projects/test_contrasting/test_two_fold_estimation.py
@@ -0,0 +1,33 @@
+import unittest
+import numpy as np
+
+from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
+from projects.contrasting_trends_in_snow_loads.altitunal_fit.altitudes_studies import AltitudesStudies
+from projects.contrasting_trends_in_snow_loads.altitunal_fit.two_fold_estimation import TwoFoldEstimation
+from spatio_temporal_dataset.slicer.split import Split
+
+
+class TestAltitudesStudies(unittest.TestCase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        altitudes = [900, 1200]
+        study_class = SafranSnowfall1Day
+        studies = AltitudesStudies(study_class, altitudes, year_min=1959, year_max=1962)
+        self.two_fold_estimation = TwoFoldEstimation(studies, nb_samples=2)
+
+    def test_dataset_sizes(self):
+        dataset1, dataset2 = self.two_fold_estimation.two_fold_datasets('Vercors')
+        np.testing.assert_equal(dataset1.maxima_gev(Split.train_temporal), dataset2.maxima_gev(Split.test_temporal))
+        np.testing.assert_equal(dataset1.maxima_gev(Split.test_temporal), dataset2.maxima_gev(Split.train_temporal))
+
+    def test_crash(self):
+        dataset1, _ = self.two_fold_estimation.two_fold_datasets('Vercors')
+        with self.assertRaises(AssertionError):
+            dataset1.maxima_gev(split=Split.train_spatiotemporal)
+        with self.assertRaises(AssertionError):
+            dataset1.maxima_gev(split=Split.train_spatial)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_spatio_temporal_dataset/test_slicer.py b/test/test_spatio_temporal_dataset/test_slicer.py
index 0af73b92..119f667c 100644
--- a/test/test_spatio_temporal_dataset/test_slicer.py
+++ b/test/test_spatio_temporal_dataset/test_slicer.py
@@ -1,3 +1,5 @@
+import numpy as np
+import pandas as pd
 from typing import List
 
 import unittest
@@ -6,11 +8,20 @@ from extreme_fit.model.margin_model.linear_margin_model.linear_margin_model impo
 from extreme_fit.model.max_stable_model.max_stable_models import Smith
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
-from spatio_temporal_dataset.slicer.split import ALL_SPLITS_EXCEPT_ALL, Split
+from spatio_temporal_dataset.slicer.split import ALL_SPLITS_EXCEPT_ALL, Split, small_s_split_from_ratio, invert_s_split
 from test.test_utils import load_test_1D_and_2D_spatial_coordinates, load_test_spatiotemporal_coordinates, \
     load_test_temporal_coordinates
 
 
+class TestSplitFunctions(unittest.TestCase):
+
+    def test_inversion(self):
+        index = pd.Index([0, 1])
+        s_split = small_s_split_from_ratio(index=index, train_split_ratio=0.5)
+        inverted_s_split = invert_s_split(s_split.copy())
+        np.testing.assert_equal(inverted_s_split.iloc[::-1].values, s_split.values)
+
+
 class TestSlicerForDataset(unittest.TestCase):
 
     def __init__(self, methodName: str = ...) -> None:
-- 
GitLab