-
Le Roux Erwan authorede5bdf83b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import Union, List
import pandas as pd
from spatio_temporal_dataset.slicer.split import Split
class AbstractSlicer(object):
def __init__(self, ind_train_spatial: Union[None, pd.Series], ind_train_temporal: Union[None, pd.Series]):
self.ind_train_spatial = ind_train_spatial # type: Union[None, pd.Series]
self.ind_train_temporal = ind_train_temporal # type: Union[None, pd.Series]
@property
def ind_test_spatial(self) -> pd.Series:
return ~self.ind_train_spatial
@property
def ind_test_temporal(self) -> pd.Series:
return ~self.ind_train_temporal
def loc_split(self, df: pd.DataFrame, split: Split) -> pd.DataFrame:
# split should belong to the list of split accepted by the slicer
assert isinstance(split, Split)
if split is Split.all:
return df
assert split in self.splits, "Split and slicer_type do not correspond:\nsplit:{}, slicer_type:{}".format(split, type(self))
# By default, some required splits are not defined
# instead of crashing, we return all the data for all the split
# This is the default behavior, when the required splits has been defined
if self.some_required_ind_are_not_defined:
return df
else:
return self.specialized_loc_split(df=df, split=split)
def summary(self, show=True):
msg = ''
for s, global_name in [(self.ind_train_spatial, "Spatial"), (self.ind_train_temporal, "Temporal")]:
msg += global_name + ': '
if s is None:
msg += 'Not handled by this slicer'
else:
for f, name in [(len, 'Total'), (sum, 'train')]:
msg += "{}: {} ".format(name, f(s))
msg += ' / '
if show:
print(msg)
return msg
# Methods that need to be defined in the child class
def specialized_loc_split(self, df: pd.DataFrame, split: Split) -> pd.DataFrame:
raise NotImplementedError
@property
def some_required_ind_are_not_defined(self) -> bool:
raise NotImplementedError
@property
def train_split(self) -> Split:
raise NotImplementedError
@property
def test_split(self) -> Split:
raise NotImplementedError
@property
def splits(self) -> List[Split]:
raise NotImplementedError
def df_sliced(df: pd.DataFrame, split: Split = Split.all, slicer: AbstractSlicer = None) -> pd.DataFrame:
if slicer is None:
assert split is Split.all
return df
else:
return slicer.loc_split(df, split)