from typing import List

import pandas as pd

from spatio_temporal_dataset.slicer.abstract_slicer import AbstractSlicer
from spatio_temporal_dataset.slicer.split import Split


class SpatioTemporalSlicer(AbstractSlicer):
    """Slicer that selects from a DataFrame along both its rows and its columns.

    Rows are picked with ``index_train_ind`` / ``index_test_ind`` and columns
    with ``column_train_ind`` / ``column_test_ind``. These attributes are not
    defined in this class (presumably set up by ``AbstractSlicer`` -- TODO
    confirm). They appear to be label-indexed selectors, likely boolean
    Series, usable with ``DataFrame.loc`` (assumption -- verify).

    Crossing the row and column selections gives the four splits:

    * ``train_spatiotemporal``:          train rows x train columns
    * ``test_spatiotemporal``:           test rows  x test columns
    * ``test_spatiotemporal_spatial``:   test rows  x train columns
    * ``test_spatiotemporal_temporal``:  train rows x test columns
    """

    SPLITS = [Split.train_spatiotemporal,
              Split.test_spatiotemporal,
              Split.test_spatiotemporal_spatial,
              Split.test_spatiotemporal_temporal]

    @property
    def splits(self) -> List[Split]:
        """Return every split this slicer can produce (the class-level ``SPLITS`` list)."""
        return self.SPLITS

    @property
    def train_split(self) -> Split:
        """Return the split used for training: ``Split.train_spatiotemporal``."""
        return Split.train_spatiotemporal

    @property
    def test_split(self) -> Split:
        """Return the main test split: ``Split.test_spatiotemporal``."""
        return Split.test_spatiotemporal

    @property
    def some_required_ind_are_not_defined(self) -> bool:
        """Return True if the train row selector or the train column selector is None."""
        return self.index_train_ind is None or self.column_train_ind is None

    def specialized_loc_split(self, df: pd.DataFrame, split: Split):
        """Return the part of ``df`` that belongs to ``split``.

        :param df: DataFrame whose columns equal ``column_train_ind.index`` and
            whose index equals ``index_train_ind.index``.
        :param split: one of ``self.SPLITS``.
        :return: ``df.loc[rows, columns]`` for the row/column selectors of ``split``.
        :raises AssertionError: if ``df`` is not aligned with the selectors.
        :raises ValueError: if ``split`` is not handled by this slicer.
        """
        # The selectors are indexed by df's labels, so df must be aligned with
        # them for .loc to pick the intended rows/columns.
        # NOTE(review): these asserts disappear under `python -O`.
        assert pd.Index.equals(df.columns, self.column_train_ind.index)
        assert pd.Index.equals(df.index, self.index_train_ind.index)

        if split is Split.train_spatiotemporal:
            rows, columns = self.index_train_ind, self.column_train_ind
        elif split is Split.test_spatiotemporal:
            rows, columns = self.index_test_ind, self.column_test_ind
        elif split is Split.test_spatiotemporal_spatial:
            rows, columns = self.index_test_ind, self.column_train_ind
        elif split is Split.test_spatiotemporal_temporal:
            rows, columns = self.index_train_ind, self.column_test_ind
        else:
            # Previously fell through and returned None silently, hiding the
            # mistake from the caller; fail loudly instead.
            raise ValueError('Unsupported split for {}: {}'.format(type(self).__name__, split))
        return df.loc[rows, columns]