diff --git a/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py b/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py index 308394641054f3cbc455cc3d0b6f3b4a89f9deb2..aad06c85267ce31a038a25cbba67cdcff4cb70b0 100644 --- a/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py +++ b/projects/contrasting_trends_in_snow_loads/altitunal_fit/altitudes_studies.py @@ -4,6 +4,7 @@ from collections import OrderedDict from cached_property import cached_property from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy +from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \ AbstractSpatialCoordinates from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.abstract_spatio_temporal_coordinates import \ @@ -15,8 +16,11 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset class AltitudesStudies(object): - def __init__(self, study_class, altitudes, transformation_class=None, **kwargs_study): - self.transformation_class = transformation_class + def __init__(self, study_class, altitudes, + spatial_transformation_class=None, temporal_transformation_class=None, + **kwargs_study): + self.spatial_transformation_class = spatial_transformation_class + self.temporal_transformation_class = temporal_transformation_class self.altitudes = altitudes self.altitude_to_study = OrderedDict() for altitude in self.altitudes: @@ -36,23 +40,30 @@ class AltitudesStudies(object): def temporal_coordinates(self): return ConsecutiveTemporalCoordinates.from_nb_temporal_steps(nb_temporal_steps=self.study.nb_years, start=self.study.year_min, - transformation_class=self.transformation_class) + transformation_class=self.spatial_transformation_class) @cached_property def spatial_coordinates(self): return AbstractSpatialCoordinates.from_list_x_coordinates(x_coordinates=self.altitudes, - transformation_class=self.transformation_class) + transformation_class=self.temporal_transformation_class) + + @cached_property + def _df_coordinates(self): + return AbstractSpatioTemporalCoordinates.get_df_from_spatial_and_temporal_coordinates(self.spatial_coordinates, + self.temporal_coordinates) + + def random_s_split_spatial(self, train_split_ratio): + return AbstractCoordinates.spatial_s_split_from_df(self._df_coordinates, train_split_ratio) def random_s_split_temporal(self, train_split_ratio): - return AbstractSpatioTemporalCoordinates.get_random_s_split_temporal( - spatial_coordinates=self.spatial_coordinates, - temporal_coordinates=self.temporal_coordinates, - train_split_ratio=train_split_ratio) - - def spatio_temporal_coordinates(self, slicer_class: type, s_split_spatial: pd.Series = None, - s_split_temporal: pd.Series = None): - return AbstractSpatioTemporalCoordinates(slicer_class=slicer_class, s_split_spatial=s_split_spatial, + return AbstractCoordinates.temporal_s_split_from_df(self._df_coordinates, train_split_ratio) + + def spatio_temporal_coordinates(self, s_split_spatial: pd.Series = None, s_split_temporal: pd.Series = None): + slicer_class = AbstractCoordinates.slicer_class_from_s_splits(s_split_spatial=s_split_spatial, + s_split_temporal=s_split_temporal) + return AbstractSpatioTemporalCoordinates(slicer_class=slicer_class, + s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal, - transformation_class=self.transformation_class, + transformation_class=self.spatial_transformation_class, spatial_coordinates=self.spatial_coordinates, temporal_coordinates=self.temporal_coordinates) diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py index ae69b412c953f4b82a3d7a9cfc6298a39907f7c4..7fde3c4ab682c25892b096bed94180167c89b9ed 100644 --- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py +++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py @@ -87,6 +87,12 @@ class AbstractCoordinates(object): s_split_spatial = df[cls.SPATIAL_SPLIT].copy() if cls.SPATIAL_SPLIT in df.columns else None s_split_temporal = df[cls.TEMPORAL_SPLIT].copy() if cls.TEMPORAL_SPLIT in df.columns else None + slicer_class = cls.slicer_class_from_s_splits(s_split_spatial, s_split_temporal) + + return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) + + @classmethod + def slicer_class_from_s_splits(cls, s_split_spatial, s_split_temporal): # Infer the slicer class if s_split_temporal is None and s_split_spatial is None: raise ValueError('Both split are unspecified') @@ -96,8 +102,7 @@ class AbstractCoordinates(object): slicer_class = SpatialSlicer else: slicer_class = SpatioTemporalSlicer - - return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) + return slicer_class @classmethod def from_df_and_slicer(cls, df: pd.DataFrame, slicer_class: type, train_split_ratio: float = None, @@ -106,13 +111,17 @@ class AbstractCoordinates(object): assert len(set(df.index)) == len(df), 'df indices are not unique' # Create a spatial split - s_split_spatial = s_split_from_df(df, cls.COORDINATE_X, cls.SPATIAL_SPLIT, train_split_ratio, True) + s_split_spatial = cls.spatial_s_split_from_df(df, train_split_ratio) # Create a temporal split s_split_temporal = cls.temporal_s_split_from_df(df, train_split_ratio) return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal, transformation_class=transformation_class) + @classmethod + def spatial_s_split_from_df(cls, df, train_split_ratio): + return s_split_from_df(df, cls.COORDINATE_X, cls.SPATIAL_SPLIT, train_split_ratio, True) + @classmethod def temporal_s_split_from_df(cls, df, train_split_ratio): return s_split_from_df(df, cls.COORDINATE_T, cls.TEMPORAL_SPLIT, train_split_ratio, False) diff --git a/test/test_projects/test_contrasting/test_altitudes_studies.py b/test/test_projects/test_contrasting/test_altitudes_studies.py index be3fe72b6f9d4fbae7184fadac582f492a66af89..b943fa66ce9c7c9f1eaa46f3ca666d3a4e8722c7 100644 --- a/test/test_projects/test_contrasting/test_altitudes_studies.py +++ b/test/test_projects/test_contrasting/test_altitudes_studies.py @@ -2,6 +2,7 @@ import unittest from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day from projects.contrasting_trends_in_snow_loads.altitunal_fit.altitudes_studies import AltitudesStudies +from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.slicer.split import Split, small_s_split_from_ratio from spatio_temporal_dataset.slicer.temporal_slicer import TemporalSlicer import pandas as pd @@ -15,15 +16,33 @@ class TestAltitudesStudies(unittest.TestCase): study_class = SafranSnowfall1Day self.studies = AltitudesStudies(study_class, altitudes, year_min=1959, year_max=1962) - def test_spatio_temporal_coordinates_with_temporal_split(self): + +class TestSpatioTemporalCoordinates(TestAltitudesStudies): + + def test_temporal_split(self): s_split_temporal = self.studies.random_s_split_temporal(train_split_ratio=0.75) - coordinates = self.studies.spatio_temporal_coordinates(slicer_class=TemporalSlicer, + coordinates = self.studies.spatio_temporal_coordinates(s_split_temporal=s_split_temporal) + self.assertEqual(coordinates.coordinates_values(split=Split.train_temporal).shape, (6, 2)) + self.assertEqual(coordinates.coordinates_values(split=Split.test_temporal).shape, (2, 2)) + + def test_spatial_split(self): + s_split_spatial = self.studies.random_s_split_spatial(train_split_ratio=0.5) + coordinates = self.studies.spatio_temporal_coordinates(s_split_spatial=s_split_spatial) + self.assertEqual(coordinates.coordinates_values(split=Split.train_spatial).shape, (4, 2)) + self.assertEqual(coordinates.coordinates_values(split=Split.test_spatial).shape, (4, 2)) + + def test_spatio_temporal_split(self): + s_split_spatial = self.studies.random_s_split_spatial(train_split_ratio=0.5) + s_split_temporal = self.studies.random_s_split_temporal(train_split_ratio=0.75) + coordinates = self.studies.spatio_temporal_coordinates(s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) - train_values = coordinates.coordinates_values(split=Split.train_temporal) - self.assertEqual(train_values.shape, (6, 2)) - test_values = coordinates.coordinates_values(split=Split.test_temporal) - self.assertEqual(test_values.shape, (2, 2)) + self.assertEqual(coordinates.coordinates_values(split=Split.train_spatiotemporal).shape, (3, 2)) + self.assertEqual(coordinates.coordinates_values(split=Split.test_spatiotemporal_spatial).shape, (3, 2)) + self.assertEqual(coordinates.coordinates_values(split=Split.test_spatiotemporal_temporal).shape, (1, 2)) + self.assertEqual(coordinates.coordinates_values(split=Split.test_spatiotemporal).shape, (1, 2)) + + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()