Commit 011a1e15 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[COORDINATES] sort any coordinates columns according to the list COORDINATES_NAMES. add test.

parent 91aea614
No related merge requests found
Showing with 32 additions and 9 deletions
+32 -9
...@@ -35,7 +35,9 @@ class AbstractCoordinates(object): ...@@ -35,7 +35,9 @@ class AbstractCoordinates(object):
# Extract df_all_coordinates from df # Extract df_all_coordinates from df
coordinate_columns = [c for c in df.columns if c in self.COORDINATES_NAMES] coordinate_columns = [c for c in df.columns if c in self.COORDINATES_NAMES]
assert len(coordinate_columns) > 0 assert len(coordinate_columns) > 0
self.df_all_coordinates = df.loc[:, coordinate_columns].copy() # type: pd.DataFrame # Sort coordinates according to a specified order
sorted_coordinates_columns = [c for c in self.COORDINATES_NAMES if c in coordinate_columns]
self.df_all_coordinates = df.loc[:, sorted_coordinates_columns].copy() # type: pd.DataFrame
# Check the data type of the coordinate columns # Check the data type of the coordinate columns
accepted_dtypes = ['float64', 'int64'] accepted_dtypes = ['float64', 'int64']
assert len(self.df_all_coordinates.select_dtypes(include=accepted_dtypes).columns) == len(coordinate_columns), \ assert len(self.df_all_coordinates.select_dtypes(include=accepted_dtypes).columns) == len(coordinate_columns), \
......
import unittest import unittest
from collections import Counter import pandas as pd
from collections import Counter, OrderedDict
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.generated_spatio_temporal_coordinates import \ from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.generated_spatio_temporal_coordinates import \
...@@ -43,17 +44,37 @@ class SpatioTemporalCoordinates(unittest.TestCase): ...@@ -43,17 +44,37 @@ class SpatioTemporalCoordinates(unittest.TestCase):
nb_points = 4 nb_points = 4
nb_steps = 2 nb_steps = 2
def tearDown(self):
c = Counter([len(self.coordinates.df_coordinates(split)) for split in SpatioTemporalSlicer.SPLITS])
good_count = c == Counter([2, 2, 2, 2]) or c == Counter([0, 0, 4, 4])
self.assertTrue(good_count)
def test_temporal_circle(self): def test_temporal_circle(self):
self.coordinates = UniformSpatioTemporalCoordinates.from_nb_points_and_nb_steps(nb_points=self.nb_points, self.coordinates = UniformSpatioTemporalCoordinates.from_nb_points_and_nb_steps(nb_points=self.nb_points,
nb_steps=self.nb_steps, nb_steps=self.nb_steps,
train_split_ratio=0.5) train_split_ratio=0.5)
# def test_temporal_alps(self): c = Counter([len(self.coordinates.df_coordinates(split)) for split in SpatioTemporalSlicer.SPLITS])
# pass good_count = c == Counter([2, 2, 2, 2]) or c == Counter([0, 0, 4, 4])
self.assertTrue(good_count)
def test_ordered_coordinates(self):
# Order coordinates, to ensure that the first dimension/the second dimension and so on..
# Always are in the same order to a given type (e.g. spatio_temporal= of coordinates
# Check space coordinates
d = OrderedDict()
d[AbstractCoordinates.COORDINATE_Z] = [1]
d[AbstractCoordinates.COORDINATE_X] = [1]
d[AbstractCoordinates.COORDINATE_Y] = [1]
df = pd.DataFrame.from_dict(d)
for df2 in [df, df.loc[:, ::-1]][-1:]:
coordinates = AbstractCoordinates(df=df2, slicer_class=SpatioTemporalSlicer)
self.assertEqual(list(coordinates.df_all_coordinates.columns),
[AbstractCoordinates.COORDINATE_X, AbstractCoordinates.COORDINATE_Y,
AbstractCoordinates.COORDINATE_Z])
# Check space/time ordering
d = OrderedDict()
d[AbstractCoordinates.COORDINATE_T] = [1]
d[AbstractCoordinates.COORDINATE_X] = [1]
df = pd.DataFrame.from_dict(d)
for df2 in [df, df.loc[:, ::-1]][-1:]:
coordinates = AbstractCoordinates(df=df2, slicer_class=SpatioTemporalSlicer)
self.assertEqual(list(coordinates.df_all_coordinates.columns),
[AbstractCoordinates.COORDINATE_X, AbstractCoordinates.COORDINATE_T])
if __name__ == '__main__': if __name__ == '__main__':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment