Commit d30d12bd authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SPATIAL COORDINATES] refactor names

parent 30f2543f
No related merge requests found
Showing with 13 additions and 12 deletions
+13 -12
...@@ -9,9 +9,10 @@ from mpl_toolkits.mplot3d import Axes3D ...@@ -9,9 +9,10 @@ from mpl_toolkits.mplot3d import Axes3D
class AbstractSpatialCoordinates(object): class AbstractSpatialCoordinates(object):
# Columns # Columns
COORD_X = 'coord_x' COORDINATE_X = 'coord_x'
COORD_Y = 'coord_y' COORDINATE_Y = 'coord_y'
COORD_Z = 'coord_z' COORDINATE_Z = 'coord_z'
COORDINATE_NAMES = [COORDINATE_X, COORDINATE_Y, COORDINATE_Z]
COORD_SPLIT = 'coord_split' COORD_SPLIT = 'coord_split'
# Constants # Constants
TRAIN_SPLIT_STR = 'train_split' TRAIN_SPLIT_STR = 'train_split'
...@@ -24,7 +25,7 @@ class AbstractSpatialCoordinates(object): ...@@ -24,7 +25,7 @@ class AbstractSpatialCoordinates(object):
@classmethod @classmethod
def from_df(cls, df: pd.DataFrame): def from_df(cls, df: pd.DataFrame):
# X and Y coordinates must be defined # X and Y coordinates must be defined
assert cls.COORD_X in df.columns and cls.COORD_Y in df.columns assert cls.COORDINATE_X in df.columns and cls.COORDINATE_Y in df.columns
df_coordinates = df.loc[:, cls.coordinates_columns(df)] df_coordinates = df.loc[:, cls.coordinates_columns(df)]
# Potentially, a split column can be specified # Potentially, a split column can be specified
s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None
...@@ -33,9 +34,9 @@ class AbstractSpatialCoordinates(object): ...@@ -33,9 +34,9 @@ class AbstractSpatialCoordinates(object):
@classmethod @classmethod
def coordinates_columns(cls, df_coord: pd.DataFrame) -> List[str]: def coordinates_columns(cls, df_coord: pd.DataFrame) -> List[str]:
# If a Z coordinate is in the DataFrame, then # If a Z coordinate is in the DataFrame, then
coord_columns = [cls.COORD_X, cls.COORD_Y] coord_columns = [cls.COORDINATE_X, cls.COORDINATE_Y]
if cls.COORD_Z in df_coord.columns: if cls.COORDINATE_Z in df_coord.columns:
coord_columns.append(cls.COORD_Z) coord_columns.append(cls.COORDINATE_Z)
return coord_columns return coord_columns
@property @property
...@@ -80,11 +81,11 @@ class AbstractSpatialCoordinates(object): ...@@ -80,11 +81,11 @@ class AbstractSpatialCoordinates(object):
@property @property
def x_coordinates(self) -> np.ndarray: def x_coordinates(self) -> np.ndarray:
return self.df_coordinates.loc[:, self.COORD_X].values.copy() return self.df_coordinates.loc[:, self.COORDINATE_X].values.copy()
@property @property
def y_coordinates(self) -> np.ndarray: def y_coordinates(self) -> np.ndarray:
return self.df_coordinates.loc[:, self.COORD_Y].values.copy() return self.df_coordinates.loc[:, self.COORDINATE_Y].values.copy()
@property @property
def coordinates_train(self) -> np.ndarray: def coordinates_train(self) -> np.ndarray:
......
...@@ -10,7 +10,7 @@ class AlpsStation2DCoordinates(AlpsStation3DCoordinates): ...@@ -10,7 +10,7 @@ class AlpsStation2DCoordinates(AlpsStation3DCoordinates):
def from_csv(cls, csv_file='coord-lambert2'): def from_csv(cls, csv_file='coord-lambert2'):
# Remove the Z coordinates from df_coord # Remove the Z coordinates from df_coord
spatial_coordinates = super().from_csv(csv_file) # type: AlpsStation3DCoordinates spatial_coordinates = super().from_csv(csv_file) # type: AlpsStation3DCoordinates
spatial_coordinates.df_coordinates.drop(cls.COORD_Z, axis=1, inplace=True) spatial_coordinates.df_coordinates.drop(cls.COORDINATE_Z, axis=1, inplace=True)
return spatial_coordinates return spatial_coordinates
......
...@@ -31,7 +31,7 @@ class AlpsStation3DCoordinates(AbstractSpatialCoordinates): ...@@ -31,7 +31,7 @@ class AlpsStation3DCoordinates(AbstractSpatialCoordinates):
assert len(coordinates) == 3 assert len(coordinates) == 3
station_to_coordinates[station_name] = coordinates station_to_coordinates[station_name] = coordinates
df = pd.DataFrame.from_dict(data=station_to_coordinates, orient='index', df = pd.DataFrame.from_dict(data=station_to_coordinates, orient='index',
columns=[cls.COORD_X, cls.COORD_Y, cls.COORD_Z]) columns=[cls.COORDINATE_X, cls.COORDINATE_Y, cls.COORDINATE_Z])
print(df.head()) print(df.head())
filepath = op.join(cls.FULL_PATH, 'coord-lambert2.csv') filepath = op.join(cls.FULL_PATH, 'coord-lambert2.csv')
assert not op.exists(filepath) assert not op.exists(filepath)
......
...@@ -15,7 +15,7 @@ class CircleCoordinatesRadius1(AbstractSpatialCoordinates): ...@@ -15,7 +15,7 @@ class CircleCoordinatesRadius1(AbstractSpatialCoordinates):
r = get_loaded_r() r = get_loaded_r()
angles = np.array(r.runif(nb_points, max=2 * math.pi)) angles = np.array(r.runif(nb_points, max=2 * math.pi))
radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius))) radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius)))
df = pd.DataFrame.from_dict({cls.COORD_X: radius * np.cos(angles), cls.COORD_Y: radius * np.sin(angles)}) df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles), cls.COORDINATE_Y: radius * np.sin(angles)})
return cls.from_df(df) return cls.from_df(df)
def visualization_2D(self): def visualization_2D(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment