From 38b8c43330e8f3f9853197a8c297b8a0377fc5b8 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 14 Jan 2019 19:08:02 +0100
Subject: [PATCH] [COORDINATES] refactoring. add tests.

---
 .../coordinates/abstract_coordinates.py       | 29 +++++++++----------
 .../transformed_coordinates.py                |  2 +-
 2 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index ac16e24c..940f097a 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -30,9 +30,18 @@ class AbstractCoordinates(object):
     # Coordinates columns
     COORDINATES_NAMES = COORDINATE_SPATIAL_NAMES + [COORDINATE_T]
 
-    def __init__(self, df_coord: pd.DataFrame, slicer_class: type, s_split_spatial: pd.Series = None,
+    def __init__(self, df: pd.DataFrame, slicer_class: type, s_split_spatial: pd.Series = None,
                  s_split_temporal: pd.Series = None):
-        self.df_all_coordinates = df_coord  # type: pd.DataFrame
+        # Extract df_all_coordinates from df
+        coordinate_columns = [c for c in df.columns if c in self.COORDINATES_NAMES]
+        assert len(coordinate_columns) > 0
+        self.df_all_coordinates = df.loc[:, coordinate_columns].copy()  # type: pd.DataFrame
+        # Check the data type of the coordinate columns
+        accepted_dtypes = ['float64', 'int64']
+        assert len(self.df_all_coordinates.select_dtypes(include=accepted_dtypes).columns) == len(coordinate_columns), \
+            'coordinates columns dtypes should belong to {}'.format(accepted_dtypes)
+
+        # Slicing attributes
         self.s_split_spatial = s_split_spatial  # type: pd.Series
         self.s_split_temporal = s_split_temporal  # type: pd.Series
         self.slicer = None  # type: AbstractSlicer
@@ -51,12 +60,7 @@ class AbstractCoordinates(object):
 
     @classmethod
     def from_df(cls, df: pd.DataFrame):
-        # Extract df_coordinate
-        coordinate_columns = [c for c in df.columns if c in cls.COORDINATES_NAMES]
-        df_coord = df.loc[:, coordinate_columns].copy()
-
-        # Extract the split
-        split_columns = [c for c in df.columns if c in [cls.SPATIAL_SPLIT, cls.TEMPORAL_SPLIT]]
+        # Extract the spatial and temporal splits if they are specified
         s_split_spatial = df[cls.SPATIAL_SPLIT].copy() if cls.SPATIAL_SPLIT in df.columns else None
         s_split_temporal = df[cls.TEMPORAL_SPLIT].copy() if cls.TEMPORAL_SPLIT in df.columns else None
 
@@ -70,11 +74,7 @@ class AbstractCoordinates(object):
         else:
             slicer_class = SpatioTemporalSlicer
 
-        # Remove all the columns used from df
-        columns_used = coordinate_columns + split_columns
-        df.drop(columns_used, axis=1, inplace=True)
-        return cls(df_coord=df_coord, slicer_class=slicer_class,
-                   s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal)
+        return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal)
 
     @classmethod
     def from_df_and_slicer(cls, df: pd.DataFrame, slicer_class: type, train_split_ratio: float = None):
@@ -86,8 +86,7 @@ class AbstractCoordinates(object):
         # Create a temporal split
         s_split_temporal = s_split_from_df(df, cls.COORDINATE_T, cls.TEMPORAL_SPLIT, train_split_ratio, False)
 
-        return cls(df_coord=df, slicer_class=slicer_class,
-                   s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal)
+        return cls(df=df, slicer_class=slicer_class, s_split_spatial=s_split_spatial, s_split_temporal=s_split_temporal)
 
     @classmethod
     def from_csv(cls, csv_path: str = None):
diff --git a/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py b/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py
index 28887ad6..3c4e1264 100644
--- a/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/transformed_coordinates/transformed_coordinates.py
@@ -10,7 +10,7 @@ class TransformedCoordinates(AbstractCoordinates):
                          transformation_function: AbstractTransformation):
         df_coordinates_transformed = coordinates.df_all_coordinates.copy()
         df_coordinates_transformed = transformation_function.transform(df_coord=df_coordinates_transformed)
-        return cls(df_coord=df_coordinates_transformed, slicer_class=type(coordinates.slicer),
+        return cls(df=df_coordinates_transformed, slicer_class=type(coordinates.slicer),
                    s_split_spatial=coordinates.s_split_spatial, s_split_temporal=coordinates.s_split_temporal)
 
 
-- 
GitLab