From d30d12bd682f60f45a0c11578906bc1fa797514e Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 21 Nov 2018 12:58:47 +0100
Subject: [PATCH] [SPATIAL COORDINATES] refactor names

---
 .../abstract_spatial_coordinates.py           | 19 ++++++++++---------
 .../alps_station_2D_coordinates.py            |  2 +-
 .../alps_station_3D_coordinates.py            |  2 +-
 .../generated_coordinates.py                  |  2 +-
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py b/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py
index fe54ce00..05e55f27 100644
--- a/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py
+++ b/spatio_temporal_dataset/spatial_coordinates/abstract_spatial_coordinates.py
@@ -9,9 +9,10 @@ from mpl_toolkits.mplot3d import Axes3D
 
 class AbstractSpatialCoordinates(object):
     # Columns
-    COORD_X = 'coord_x'
-    COORD_Y = 'coord_y'
-    COORD_Z = 'coord_z'
+    COORDINATE_X = 'coord_x'
+    COORDINATE_Y = 'coord_y'
+    COORDINATE_Z = 'coord_z'
+    COORDINATE_NAMES = [COORDINATE_X, COORDINATE_Y, COORDINATE_Z]
     COORD_SPLIT = 'coord_split'
     # Constants
     TRAIN_SPLIT_STR = 'train_split'
@@ -24,7 +25,7 @@ class AbstractSpatialCoordinates(object):
     @classmethod
     def from_df(cls, df: pd.DataFrame):
         #  X and Y coordinates must be defined
-        assert cls.COORD_X in df.columns and cls.COORD_Y in df.columns
+        assert cls.COORDINATE_X in df.columns and cls.COORDINATE_Y in df.columns
         df_coordinates = df.loc[:, cls.coordinates_columns(df)]
         # Potentially, a split column can be specified
         s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None
@@ -33,9 +34,9 @@ class AbstractSpatialCoordinates(object):
     @classmethod
     def coordinates_columns(cls, df_coord: pd.DataFrame) -> List[str]:
         # If a Z coordinate is in the DataFrame, then
-        coord_columns = [cls.COORD_X, cls.COORD_Y]
-        if cls.COORD_Z in df_coord.columns:
-            coord_columns.append(cls.COORD_Z)
+        coord_columns = [cls.COORDINATE_X, cls.COORDINATE_Y]
+        if cls.COORDINATE_Z in df_coord.columns:
+            coord_columns.append(cls.COORDINATE_Z)
         return coord_columns
 
     @property
@@ -80,11 +81,11 @@ class AbstractSpatialCoordinates(object):
 
     @property
     def x_coordinates(self) -> np.ndarray:
-        return self.df_coordinates.loc[:, self.COORD_X].values.copy()
+        return self.df_coordinates.loc[:, self.COORDINATE_X].values.copy()
 
     @property
     def y_coordinates(self) -> np.ndarray:
-        return self.df_coordinates.loc[:, self.COORD_Y].values.copy()
+        return self.df_coordinates.loc[:, self.COORDINATE_Y].values.copy()
 
     @property
     def coordinates_train(self) -> np.ndarray:
diff --git a/spatio_temporal_dataset/spatial_coordinates/alps_station_2D_coordinates.py b/spatio_temporal_dataset/spatial_coordinates/alps_station_2D_coordinates.py
index a85efb22..ae2586ea 100644
--- a/spatio_temporal_dataset/spatial_coordinates/alps_station_2D_coordinates.py
+++ b/spatio_temporal_dataset/spatial_coordinates/alps_station_2D_coordinates.py
@@ -10,7 +10,7 @@ class AlpsStation2DCoordinates(AlpsStation3DCoordinates):
     def from_csv(cls, csv_file='coord-lambert2'):
         # Remove the Z coordinates from df_coord
         spatial_coordinates = super().from_csv(csv_file)  # type: AlpsStation3DCoordinates
-        spatial_coordinates.df_coordinates.drop(cls.COORD_Z, axis=1, inplace=True)
+        spatial_coordinates.df_coordinates.drop(cls.COORDINATE_Z, axis=1, inplace=True)
         return spatial_coordinates
 
 
diff --git a/spatio_temporal_dataset/spatial_coordinates/alps_station_3D_coordinates.py b/spatio_temporal_dataset/spatial_coordinates/alps_station_3D_coordinates.py
index f45c94b6..67b4b7ff 100644
--- a/spatio_temporal_dataset/spatial_coordinates/alps_station_3D_coordinates.py
+++ b/spatio_temporal_dataset/spatial_coordinates/alps_station_3D_coordinates.py
@@ -31,7 +31,7 @@ class AlpsStation3DCoordinates(AbstractSpatialCoordinates):
                 assert len(coordinates) == 3
                 station_to_coordinates[station_name] = coordinates
         df = pd.DataFrame.from_dict(data=station_to_coordinates, orient='index',
-                                    columns=[cls.COORD_X, cls.COORD_Y, cls.COORD_Z])
+                                    columns=[cls.COORDINATE_X, cls.COORDINATE_Y, cls.COORDINATE_Z])
         print(df.head())
         filepath = op.join(cls.FULL_PATH, 'coord-lambert2.csv')
         assert not op.exists(filepath)
diff --git a/spatio_temporal_dataset/spatial_coordinates/generated_coordinates.py b/spatio_temporal_dataset/spatial_coordinates/generated_coordinates.py
index 52f53595..8f884427 100644
--- a/spatio_temporal_dataset/spatial_coordinates/generated_coordinates.py
+++ b/spatio_temporal_dataset/spatial_coordinates/generated_coordinates.py
@@ -15,7 +15,7 @@ class CircleCoordinatesRadius1(AbstractSpatialCoordinates):
         r = get_loaded_r()
         angles = np.array(r.runif(nb_points, max=2 * math.pi))
         radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius)))
-        df = pd.DataFrame.from_dict({cls.COORD_X: radius * np.cos(angles), cls.COORD_Y: radius * np.sin(angles)})
+        df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles), cls.COORDINATE_Y: radius * np.sin(angles)})
         return cls.from_df(df)
 
     def visualization_2D(self):
-- 
GitLab