from typing import List, Union

import pandas as pd

from spatio_temporal_dataset.slicer.abstract_slicer import AbstractSlicer
from spatio_temporal_dataset.slicer.split import Split


class SpatialSlicer(AbstractSlicer):
    """Slicer exposing a spatial train split and a spatial test split.

    Row selection is driven by ``index_train_ind`` and ``index_test_ind``,
    attributes that are not defined in this class (presumably set up by
    ``AbstractSlicer`` from the constructor arguments -- confirm there).
    """

    # The splits this slicer is able to serve.
    SPLITS = [Split.train_spatial, Split.test_spatial]

    def __init__(self, coordinates_train_ind: Union[None, pd.Series],
                 observations_train_ind: Union[None, pd.Series]) -> None:
        """Initialise the slicer from the coordinates train indicator.

        :param coordinates_train_ind: indicator forwarded to the parent
            constructor as its first argument; may be None.
        :param observations_train_ind: accepted but not used by this slicer;
            None is passed to the parent constructor in its place.
        """
        super().__init__(coordinates_train_ind, None)

    @property
    def splits(self) -> List[Split]:
        """Return the list of splits supported by this slicer."""
        return self.SPLITS

    @property
    def train_split(self) -> Split:
        """Return the split used for training."""
        return Split.train_spatial

    @property
    def test_split(self) -> Split:
        """Return the split used for testing."""
        return Split.test_spatial

    @property
    def some_required_ind_are_not_defined(self) -> bool:
        """Return True when no train indicator is available."""
        return self.index_train_ind is None

    def specialized_loc_split(self, df: pd.DataFrame, split: Split) -> pd.DataFrame:
        """Return the rows of ``df`` that belong to ``split``.

        :param df: frame whose index must equal the index of
            ``self.index_train_ind``.
        :param split: ``Split.train_spatial`` or ``Split.test_spatial``.
        :return: ``df`` restricted with ``.loc`` to the train or test indicator.
        :raises AssertionError: if the index of ``df`` differs from the index
            of the train indicator.
        :raises ValueError: if ``split`` is not one of the two spatial splits.
        """
        # Guard: the indicator Series must carry exactly the same index as df
        # before it is used to select rows from it.
        assert df.index.equals(self.index_train_ind.index), \
            "df index differs from the index of the train indicator"
        if split is Split.train_spatial:
            return df.loc[self.index_train_ind, :]
        if split is Split.test_spatial:
            return df.loc[self.index_test_ind, :]
        # Fail loudly instead of silently returning None for an unknown split.
        raise ValueError("split:{}, slicer_type:{}".format(split, type(self)))