"""Train/test split definitions and helpers that build split label Series."""
from enum import Enum
from typing import Union

import pandas as pd


class Split(Enum):
    """Identifiers for the subsets of a dataset that can be selected."""
    all = 0
    # SpatioTemporal splits
    train_spatiotemporal = 1
    test_spatiotemporal = 2
    test_spatiotemporal_spatial = 3
    test_spatiotemporal_temporal = 4
    # Spatial splits
    train_spatial = 5
    test_spatial = 6
    # Temporal splits
    train_temporal = 7
    test_temporal = 8


def split_to_display_kwargs(split: Split):
    """Return the plotting keyword arguments associated with a split.

    :param split: the split to display.
    :return: a dict with the keys 'marker', 'linewidth' and 'gridsize'.
    """
    marker = None
    gridsize = 1000
    if 'train' in split.name:
        linewidth = 0.5
    else:
        linewidth = 2
        if 'spatiotemporal' in split.name:
            gridsize = 20
            # NOTE(review): 'temporal' is a substring of 'spatiotemporal', so inside
            # this branch the elif below can never be selected. Confirm the intended
            # marker for each test split before changing this.
            if 'spatial' in split.name and 'temporal' in split.name:
                marker = '*'
            elif 'spatial' in split.name:
                marker = '^'
            else:
                marker = '>'
    return {'marker': marker, 'linewidth': linewidth, 'gridsize': gridsize}


ALL_SPLITS_EXCEPT_ALL = [split for split in Split if split is not Split.all]

SPLIT_NAME = 'split'
TRAIN_SPLIT_STR = 'train_split'
TEST_SPLIT_STR = 'test_split'


def ind_train_from_s_split(s_split):
    """Return a boolean Series that is True where s_split equals TRAIN_SPLIT_STR.

    :param s_split: a Series of split labels, or None.
    :return: None when s_split is None, otherwise the boolean mask.
    """
    if s_split is None:
        return None
    return s_split.isin([TRAIN_SPLIT_STR])


def small_s_split_from_ratio(index: pd.Index, train_split_ratio):
    """Randomly label a fraction of an index as train, and the rest as test.

    :param index: index of the Series to create.
    :param train_split_ratio: fraction of labels set to TRAIN_SPLIT_STR, strictly
        between 0 and 1. The number of train labels is int(len(index) * ratio).
    :return: a Series on index holding TRAIN_SPLIT_STR / TEST_SPLIT_STR labels.
    :raises AssertionError: if the ratio is not strictly between 0 and 1, or if it
        leads to an empty train set or an empty test set.
    """
    length = len(index)
    assert 0 < train_split_ratio < 1
    # Everything starts as test; the sampled positions are then switched to train
    s = pd.Series(TEST_SPLIT_STR, index=index)
    nb_points_train = int(length * train_split_ratio)
    assert 0 < nb_points_train < length
    train_ind = s.sample(n=nb_points_train).index
    assert 0 < len(train_ind) < length, "number of training points:{} length:{}".format(len(train_ind), length)
    s.loc[train_ind] = TRAIN_SPLIT_STR
    return s


def s_split_from_df(df: pd.DataFrame, column, split_column, train_split_ratio,
                    spatial_split) -> Union[None, pd.Series]:
    """Build a Series of split labels aligned on the index of df.

    One label is drawn for each distinct value of df[column] (keeping the first
    occurrence), then this small Series is expanded to the length of df.

    :param df: the DataFrame to split. It is not modified.
    :param column: name of the column whose distinct values receive a label.
    :param split_column: name of the column that would hold an existing split.
    :param train_split_ratio: ratio passed to small_s_split_from_ratio, or None.
    :param spatial_split: if truthy the small Series is concatenated with itself,
        otherwise each of its labels is repeated consecutively.
    :return: None if train_split_ratio is None or if column is not in df,
        otherwise the Series of labels indexed like df.
    :raises Exception: if split_column is already in df.
    :raises AssertionError: if len(df) is not a multiple of the number of distinct
        values of df[column].
    """
    if train_split_ratio is None:
        return None
    if column not in df:
        return None
    if split_column in df:
        raise Exception('A split has already been defined')

    serie = df.drop_duplicates(subset=[column], keep='first')[column]
    assert len(df) % len(serie) == 0
    multiplication_factor = len(df) // len(serie)
    small_s_split = small_s_split_from_ratio(serie.index, train_split_ratio)
    if spatial_split:
        # concatenation for the spatial split: the labels are tiled (a, b, a, b, ...)
        # NOTE(review): presumably the rows of df cycle through the distinct values
        # in the same order for every repetition - confirm against the callers.
        s_split = pd.concat([small_s_split] * multiplication_factor, ignore_index=True)
    else:
        # dilation for the temporal split: each label is repeated consecutively
        # (a, a, b, b, ...), i.e. position i receives the label i // multiplication_factor
        s_split = small_s_split.repeat(multiplication_factor)
    s_split.index = df.index
    return s_split