from typing import Union, List

import pandas as pd

from spatio_temporal_dataset.slicer.split import Split


class AbstractSlicer(object):
    """Base class for slicers, i.e. objects that restrict a DataFrame to the part matching a Split.

    A slicer stores two optional "train" indicators: one reported as "Spatial" and one reported
    as "Temporal" by :meth:`summary`. An indicator set to None means that this slicer does not
    handle the corresponding dimension.

    Subclasses are expected to override ``train_split``, ``test_split``, ``splits``,
    ``some_required_ind_are_not_defined`` and ``specialized_loc_split``.
    """

    def __init__(self, coordinates_train_ind: Union[None, pd.Series], observations_train_ind: Union[None, pd.Series]):
        """Store the train indicators.

        :param coordinates_train_ind: indicator stored as ``index_train_ind``, or None.
            Presumably a boolean mask (it is negated with ``~`` and summed to count the train entries).
        :param observations_train_ind: indicator stored as ``column_train_ind``, or None (same assumption).
        """
        self.index_train_ind = coordinates_train_ind  # type: Union[None, pd.Series]
        self.column_train_ind = observations_train_ind  # type: Union[None, pd.Series]

    @property
    def train_split(self) -> Split:
        """Split used for training. Placeholder: returns None until overridden by a subclass."""
        return None

    @property
    def test_split(self) -> Split:
        """Split used for testing. Placeholder: returns None until overridden by a subclass."""
        return None

    @property
    def splits(self) -> List[Split]:
        """Splits accepted by :meth:`loc_split`. Placeholder: returns None until overridden by a subclass."""
        return None

    @staticmethod
    def _complement(train_ind: Union[None, pd.Series]) -> Union[None, pd.Series]:
        """Return the negation of ``train_ind``, or None when no indicator is defined."""
        # ``~None`` raises a TypeError, so an undefined indicator is propagated as None instead
        return None if train_ind is None else ~train_ind

    @property
    def index_test_ind(self) -> Union[None, pd.Series]:
        """Complement (``~``) of ``index_train_ind``, or None when ``index_train_ind`` is None."""
        return self._complement(self.index_train_ind)

    # todo: test should be the same as train when we don't care about that in the split
    @property
    def column_test_ind(self) -> Union[None, pd.Series]:
        """Complement (``~``) of ``column_train_ind``, or None when ``column_train_ind`` is None."""
        return self._complement(self.column_train_ind)

    @property
    def some_required_ind_are_not_defined(self):
        """Whether an indicator needed by this slicer is missing.

        Placeholder: returns None (falsy) until overridden by a subclass.
        """
        return None

    def summary(self):
        """Print, for each indicator, its total length and its number of train entries."""
        print('Slicer summary: \n')
        for s, global_name in [(self.index_train_ind, "Spatial"), (self.column_train_ind, "Temporal")]:
            print(global_name + ' split')
            if s is None:
                print('Not handled by this slicer')
            else:
                # len gives the total number of entries, sum counts the entries flagged as train
                for f, name in [(len, 'Total'), (sum, 'train')]:
                    print("{}: {}".format(name, f(s)))
                print('\n')

    def loc_split(self, df: pd.DataFrame, split: Split):
        """Return the part of ``df`` that corresponds to ``split``.

        :param df: DataFrame to slice.
        :param split: requested split; must be ``Split.all`` or one of ``self.splits``.
        :return: ``df`` itself for ``Split.all`` or when some required indicators are not defined,
            otherwise the result of :meth:`specialized_loc_split`.
        :raises AssertionError: if ``split`` is not a Split or is not accepted by this slicer.
        """
        # split should belong to the list of split accepted by the slicer
        assert isinstance(split, Split)

        if split is Split.all:
            return df

        assert split in self.splits, "split:{}, slicer_type:{}".format(split, type(self))

        # When the indicators required by this slicer are not defined, we return all the data
        # for every split instead of crashing. Otherwise the slicing is delegated to the child class.
        if self.some_required_ind_are_not_defined:
            return df
        return self.specialized_loc_split(df=df, split=split)

    def specialized_loc_split(self, df: pd.DataFrame, split: Split):
        """Slice ``df`` for ``split``; must be defined in the child class.

        :raises NotImplementedError: always, in this base class, so that a missing override
            fails loudly instead of silently yielding None in place of a DataFrame.
        """
        raise NotImplementedError(
            "{} must implement specialized_loc_split".format(type(self).__name__))


def slice(df: pd.DataFrame, split: Split = Split.all, slicer: AbstractSlicer = None) -> pd.DataFrame:
    """Return the part of ``df`` selected by ``split``, delegating to ``slicer`` when one is given.

    Without a slicer, only ``Split.all`` is accepted (checked with an assertion) and ``df`` is
    returned unchanged.
    """
    if slicer is not None:
        # The slicer decides which part of df belongs to the requested split
        return slicer.loc_split(df, split)
    # No slicer available: only the full data can be requested
    assert split is Split.all
    return df