Commit d30d12bd authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SPATIAL COORDINATES] refactor names

parent 30f2543f
No related merge requests found
Showing with 13 additions and 12 deletions
+13 -12
......@@ -9,9 +9,10 @@ from mpl_toolkits.mplot3d import Axes3D
class AbstractSpatialCoordinates(object):
# Columns
COORD_X = 'coord_x'
COORD_Y = 'coord_y'
COORD_Z = 'coord_z'
COORDINATE_X = 'coord_x'
COORDINATE_Y = 'coord_y'
COORDINATE_Z = 'coord_z'
COORDINATE_NAMES = [COORDINATE_X, COORDINATE_Y, COORDINATE_Z]
COORD_SPLIT = 'coord_split'
# Constants
TRAIN_SPLIT_STR = 'train_split'
......@@ -24,7 +25,7 @@ class AbstractSpatialCoordinates(object):
@classmethod
def from_df(cls, df: pd.DataFrame):
# X and Y coordinates must be defined
assert cls.COORD_X in df.columns and cls.COORD_Y in df.columns
assert cls.COORDINATE_X in df.columns and cls.COORDINATE_Y in df.columns
df_coordinates = df.loc[:, cls.coordinates_columns(df)]
# Potentially, a split column can be specified
s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None
......@@ -33,9 +34,9 @@ class AbstractSpatialCoordinates(object):
@classmethod
def coordinates_columns(cls, df_coord: pd.DataFrame) -> List[str]:
# If a Z coordinate is in the DataFrame, then
coord_columns = [cls.COORD_X, cls.COORD_Y]
if cls.COORD_Z in df_coord.columns:
coord_columns.append(cls.COORD_Z)
coord_columns = [cls.COORDINATE_X, cls.COORDINATE_Y]
if cls.COORDINATE_Z in df_coord.columns:
coord_columns.append(cls.COORDINATE_Z)
return coord_columns
@property
......@@ -80,11 +81,11 @@ class AbstractSpatialCoordinates(object):
@property
def x_coordinates(self) -> np.ndarray:
return self.df_coordinates.loc[:, self.COORD_X].values.copy()
return self.df_coordinates.loc[:, self.COORDINATE_X].values.copy()
@property
def y_coordinates(self) -> np.ndarray:
return self.df_coordinates.loc[:, self.COORD_Y].values.copy()
return self.df_coordinates.loc[:, self.COORDINATE_Y].values.copy()
@property
def coordinates_train(self) -> np.ndarray:
......
......@@ -10,7 +10,7 @@ class AlpsStation2DCoordinates(AlpsStation3DCoordinates):
def from_csv(cls, csv_file='coord-lambert2'):
# Remove the Z coordinates from df_coord
spatial_coordinates = super().from_csv(csv_file) # type: AlpsStation3DCoordinates
spatial_coordinates.df_coordinates.drop(cls.COORD_Z, axis=1, inplace=True)
spatial_coordinates.df_coordinates.drop(cls.COORDINATE_Z, axis=1, inplace=True)
return spatial_coordinates
......
......@@ -31,7 +31,7 @@ class AlpsStation3DCoordinates(AbstractSpatialCoordinates):
assert len(coordinates) == 3
station_to_coordinates[station_name] = coordinates
df = pd.DataFrame.from_dict(data=station_to_coordinates, orient='index',
columns=[cls.COORD_X, cls.COORD_Y, cls.COORD_Z])
columns=[cls.COORDINATE_X, cls.COORDINATE_Y, cls.COORDINATE_Z])
print(df.head())
filepath = op.join(cls.FULL_PATH, 'coord-lambert2.csv')
assert not op.exists(filepath)
......
......@@ -15,7 +15,7 @@ class CircleCoordinatesRadius1(AbstractSpatialCoordinates):
r = get_loaded_r()
angles = np.array(r.runif(nb_points, max=2 * math.pi))
radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius)))
df = pd.DataFrame.from_dict({cls.COORD_X: radius * np.cos(angles), cls.COORD_Y: radius * np.sin(angles)})
df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles), cls.COORDINATE_Y: radius * np.sin(angles)})
return cls.from_df(df)
def visualization_2D(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment