diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py index ac16e24cbd3a322afacdd4a39896cf2e2bb534e8..940f097a143c9c33b108e219acdba01b5843360b 100644 --- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py +++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py @@ -30,9 +30,18 @@ class AbstractCoordinates(object): # Coordinates columns COORDINATES_NAMES = COORDINATE_SPATIAL_NAMES + [COORDINATE_T] - def __init__(self, df_coord: pd.DataFrame, slicer_class: type, s_split_spatial: pd.Series = None, + def __init__(self, df: pd.DataFrame, slicer_class: type, s_split_spatial: pd.Series = None, s_split_temporal: pd.Series = None): - self.df_all_coordinates = df_coord # type: pd.DataFrame + # Extract df_all_coordinates from df + coordinate_columns = [c for c in df.columns if c in self.COORDINATES_NAMES] + assert len(coordinate_columns) > 0 + self.df_all_coordinates = df.loc[:, coordinate_columns].copy() # type: pd.DataFrame + # Check the data type of the coordinate columns + accepted_dtypes = ['float64', 'int64'] + assert len(self.df_all_coordinates.select_dtypes(include=accepted_dtypes).columns) == len(coordinate_columns), \ + 'coordinates columns dtypes should belong to {}'.format(accepted_dtypes) + + # Slicing attributes self.s_split_spatial = s_split_spatial # type: pd.Series self.s_split_temporal = s_split_temporal # type: pd.Series self.slicer = None # type: AbstractSlicer @@ -51,12 +60,7 @@ class AbstractCoordinates(object): @classmethod def from_df(cls, df: pd.DataFrame): - # Extract df_coordinate - coordinate_columns = [c for c in df.columns if c in cls.COORDINATES_NAMES] - df_coord = df.loc[:, coordinate_columns].copy() - - # Extract the split - split_columns = [c for c in df.columns if c in [cls.SPATIAL_SPLIT, cls.TEMPORAL_SPLIT]] + # Extract the split if they are specified s_split_spatial = df[cls.SPATIAL_SPLIT].copy() if cls.SPATIAL_SPLIT in df.columns else None s_split_temporal = df[cls.TEMPORAL_SPLIT].copy() if cls.TEMPORAL_SPLIT in df.columns else None @@ -70,11 +74,7 @@ class AbstractCoordinates(object): else: slicer_class = SpatioTemporalSlicer - # Remove all the columns used from df - columns_used = coordinate_columns + split_columns - df.drop(columns_used, axis=1, inplace=True) - return cls(df_coord=df_coord, slicer_class=slicer_class, - s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) + return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) @classmethod def from_df_and_slicer(cls, df: pd.DataFrame, slicer_class: type, train_split_ratio: float = None): @@ -86,8 +86,7 @@ class AbstractCoordinates(object): # Create a temporal split s_split_temporal = s_split_from_df(df, cls.COORDINATE_T, cls.TEMPORAL_SPLIT, train_split_ratio, False) - return cls(df_coord=df, slicer_class=slicer_class, - s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) + return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal) @classmethod def from_csv(cls, csv_path: str = None): diff --git a/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py b/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py index 28887ad6c81ecc1a2b988d648ae2d4b012ddedf8..3c4e1264d1e112002ec58e8bcd703d954b73bf0e 100644 --- a/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py +++ b/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py @@ -10,7 +10,7 @@ class TransformedCoordinates(AbstractCoordinates): transformation_function: AbstractTransformation): df_coordinates_transformed = coordinates.df_all_coordinates.copy() df_coordinates_transformed = transformation_function.transform(df_coord=df_coordinates_transformed) - return cls(df_coord=df_coordinates_transformed, slicer_class=type(coordinates.slicer), + return cls(df=df_coordinates_transformed, slicer_class=type(coordinates.slicer), s_split_spatial=coordinates.s_split_spatial, s_split_temporal=coordinates.s_split_temporal)