From 011a1e15a46a8d0553320d7426228c95c047ca0a Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Fri, 22 Feb 2019 10:23:22 +0100
Subject: [PATCH] [COORDINATES] sort any coordinates columns according to the
 list COORDINATES_NAMES. add test.

---
 .../coordinates/abstract_coordinates.py       |  4 +-
 .../test_coordinates.py                       | 37 +++++++++++++++----
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 940f097a..0c35f5c2 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -35,7 +35,9 @@ class AbstractCoordinates(object):
         # Extract df_all_coordinates from df
         coordinate_columns = [c for c in df.columns if c in self.COORDINATES_NAMES]
         assert len(coordinate_columns) > 0
-        self.df_all_coordinates = df.loc[:, coordinate_columns].copy()  # type: pd.DataFrame
+        # Sort coordinates according to a specified order
+        sorted_coordinates_columns = [c for c in self.COORDINATES_NAMES if c in coordinate_columns]
+        self.df_all_coordinates = df.loc[:, sorted_coordinates_columns].copy()  # type: pd.DataFrame
         # Check the data type of the coordinate columns
         accepted_dtypes = ['float64', 'int64']
         assert len(self.df_all_coordinates.select_dtypes(include=accepted_dtypes).columns) == len(coordinate_columns), \
diff --git a/test/test_spatio_temporal_dataset/test_coordinates.py b/test/test_spatio_temporal_dataset/test_coordinates.py
index a95abdad..baf51fc9 100644
--- a/test/test_spatio_temporal_dataset/test_coordinates.py
+++ b/test/test_spatio_temporal_dataset/test_coordinates.py
@@ -1,5 +1,6 @@
 import unittest
-from collections import Counter
+import pandas as pd
+from collections import Counter, OrderedDict
 
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.generated_spatio_temporal_coordinates import \
@@ -43,17 +44,37 @@ class SpatioTemporalCoordinates(unittest.TestCase):
     nb_points = 4
     nb_steps = 2
 
-    def tearDown(self):
-        c = Counter([len(self.coordinates.df_coordinates(split)) for split in SpatioTemporalSlicer.SPLITS])
-        good_count = c == Counter([2, 2, 2, 2]) or c == Counter([0, 0, 4, 4])
-        self.assertTrue(good_count)
-
     def test_temporal_circle(self):
         self.coordinates = UniformSpatioTemporalCoordinates.from_nb_points_and_nb_steps(nb_points=self.nb_points,
                                                                                         nb_steps=self.nb_steps,
                                                                                         train_split_ratio=0.5)
-    # def test_temporal_alps(self):
-    #     pass
+        c = Counter([len(self.coordinates.df_coordinates(split)) for split in SpatioTemporalSlicer.SPLITS])
+        good_count = c == Counter([2, 2, 2, 2]) or c == Counter([0, 0, 4, 4])
+        self.assertTrue(good_count)
+
+    def test_ordered_coordinates(self):
+        # Order coordinates, to ensure that the first dimension/the second dimension and so on..
+        # Always are in the same order to a given type (e.g. spatio_temporal= of coordinates
+        # Check space coordinates
+        d = OrderedDict()
+        d[AbstractCoordinates.COORDINATE_Z] = [1]
+        d[AbstractCoordinates.COORDINATE_X] = [1]
+        d[AbstractCoordinates.COORDINATE_Y] = [1]
+        df = pd.DataFrame.from_dict(d)
+        for df2 in [df, df.loc[:, ::-1]][-1:]:
+            coordinates = AbstractCoordinates(df=df2, slicer_class=SpatioTemporalSlicer)
+            self.assertEqual(list(coordinates.df_all_coordinates.columns),
+                             [AbstractCoordinates.COORDINATE_X, AbstractCoordinates.COORDINATE_Y,
+                              AbstractCoordinates.COORDINATE_Z])
+        # Check space/time ordering
+        d = OrderedDict()
+        d[AbstractCoordinates.COORDINATE_T] = [1]
+        d[AbstractCoordinates.COORDINATE_X] = [1]
+        df = pd.DataFrame.from_dict(d)
+        for df2 in [df, df.loc[:, ::-1]][-1:]:
+            coordinates = AbstractCoordinates(df=df2, slicer_class=SpatioTemporalSlicer)
+            self.assertEqual(list(coordinates.df_all_coordinates.columns),
+                             [AbstractCoordinates.COORDINATE_X, AbstractCoordinates.COORDINATE_T])
 
 
 if __name__ == '__main__':
-- 
GitLab