Commit a3e040da authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SPATIAL COORDINATES] add nb_points class constructor to any abstract_coordinates object.

parent 4a24783d
No related merge requests found
Showing with 36 additions and 26 deletions
+36 -26
...@@ -26,12 +26,29 @@ class AbstractSpatialCoordinates(object): ...@@ -26,12 +26,29 @@ class AbstractSpatialCoordinates(object):
s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None s_split = df[cls.COORD_SPLIT] if cls.COORD_SPLIT in df.columns else None
return cls(df_coord=df_coord, s_split=s_split) return cls(df_coord=df_coord, s_split=s_split)
@property
def df(self):
return self.df_coord if self.s_split is None else self.df_coord.join(self.s_split)
@classmethod @classmethod
def from_csv(cls, csv_path): def from_csv(cls, csv_path=None):
assert csv_path is not None
assert op.exists(csv_path) assert op.exists(csv_path)
df = pd.read_csv(csv_path) df = pd.read_csv(csv_path)
return cls.from_df(df) return cls.from_df(df)
@classmethod
def from_nb_points(cls, nb_points, **kwargs):
# Call the default class method from csv
coordinates = cls.from_csv() # type: AbstractSpatialCoordinates
# Sample randomly nb_points coordinates
nb_coordinates = len(coordinates)
if nb_points > nb_coordinates:
raise Exception('Nb coordinates in csv: {} < Nb points desired: {}'.format(nb_coordinates, nb_points))
else:
df_sample = pd.DataFrame.sample(coordinates.df, n=nb_points)
return cls.from_df(df=df_sample)
def coord_x_y_values(self, df_coord: pd.DataFrame) -> np.ndarray: def coord_x_y_values(self, df_coord: pd.DataFrame) -> np.ndarray:
return df_coord.loc[:, [self.COORD_X, self.COORD_Y]].values return df_coord.loc[:, [self.COORD_X, self.COORD_Y]].values
...@@ -56,8 +73,7 @@ class AbstractSpatialCoordinates(object): ...@@ -56,8 +73,7 @@ class AbstractSpatialCoordinates(object):
def index(self): def index(self):
return self.df_coord.index return self.df_coord.index
@property def __len__(self):
def nb_points(self):
return len(self.df_coord) return len(self.df_coord)
def visualization(self): def visualization(self):
......
...@@ -2,10 +2,12 @@ import pandas as pd ...@@ -2,10 +2,12 @@ import pandas as pd
import os.path as op import os.path as op
from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates
from spatio_temporal_dataset.spatial_coordinates.normalized_coordinates import BetweenZeroAndOneNormalization, \
NormalizedCoordinates
from utils import get_full_path from utils import get_full_path
class AlpsStationCoordinate(AbstractSpatialCoordinates): class AlpsStationCoordinates(AbstractSpatialCoordinates):
RELATIVE_PATH = r'local/spatio_temporal_datasets/Gilles - precipitations' RELATIVE_PATH = r'local/spatio_temporal_datasets/Gilles - precipitations'
FULL_PATH = get_full_path(relative_path=RELATIVE_PATH) FULL_PATH = get_full_path(relative_path=RELATIVE_PATH)
...@@ -31,7 +33,19 @@ class AlpsStationCoordinate(AbstractSpatialCoordinates): ...@@ -31,7 +33,19 @@ class AlpsStationCoordinate(AbstractSpatialCoordinates):
print(df.index) print(df.index)
class AlpsStationCoordinatesBetweenZeroAndOne(AlpsStationCoordinates):
@classmethod
def from_csv(cls, csv_file='coord-lambert2'):
coord = super().from_csv(csv_file)
return NormalizedCoordinates.from_coordinates(spatial_coordinates=coord,
normalizing_function=BetweenZeroAndOneNormalization())
if __name__ == '__main__': if __name__ == '__main__':
# AlpsStationCoordinate.transform_txt_into_csv() # AlpsStationCoordinate.transform_txt_into_csv()
coord = AlpsStationCoordinate.from_csv() # coord = AlpsStationCoordinates.from_csv()
# coord = AlpsStationCoordinates.from_nb_points(nb_points=60)
# coord = AlpsStationCoordinatesBetweenZeroAndOne.from_csv()
coord = AlpsStationCoordinatesBetweenZeroAndOne.from_nb_points(nb_points=60)
coord.visualization() coord.visualization()
...@@ -7,20 +7,7 @@ from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import Abs ...@@ -7,20 +7,7 @@ from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import Abs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class SimulatedCoordinates(AbstractSpatialCoordinates): class CircleCoordinates(AbstractSpatialCoordinates):
"""
Common manipulation on generated coordinates
"""
def __init__(self, df_coord, s_split=None):
super().__init__(df_coord, s_split)
@classmethod
def from_nb_points(cls, nb_points, **kwargs):
pass
class CircleCoordinates(SimulatedCoordinates):
@classmethod @classmethod
def from_nb_points(cls, nb_points, **kwargs): def from_nb_points(cls, nb_points, **kwargs):
......
import pandas as pd import pandas as pd
from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates
from spatio_temporal_dataset.spatial_coordinates.alps_station_coordinates import AlpsStationCoordinate
class AbstractNormalizingFunction(object): class AbstractNormalizingFunction(object):
...@@ -59,8 +57,3 @@ class BetweenZeroAndOneNormalization(UniformNormalization): ...@@ -59,8 +57,3 @@ class BetweenZeroAndOneNormalization(UniformNormalization):
return s_coord_scaled return s_coord_scaled
if __name__ == '__main__':
coord = AlpsStationCoordinate.from_csv()
normalized_coord = NormalizedCoordinates.from_coordinates(spatial_coordinates=coord,
normalizing_function=BetweenZeroAndOneNormalization())
normalized_coord.visualization()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment