From 6c96e20c709fc3061a8ac87be603516d0b6779af Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 21 Mar 2019 14:44:43 +0100
Subject: [PATCH] [COORDINATES] add a 2D spatio temporal coordinates

---
 .../spatial_coordinates/coordinates_2D.py     |  7 +++++-
 .../abstract_spatio_temporal_coordinates.py   | 13 ++++++++++-
 .../generated_spatio_temporal_coordinates.py  | 22 +++++++++++--------
 .../dataset/simulation_dataset.py             |  1 +
 .../test_slicer.py                            |  1 +
 test/test_utils.py                            |  4 ++--
 6 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py b/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py
index 3c845c9e..4498f655 100644
--- a/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py
+++ b/spatio_temporal_dataset/coordinates/spatial_coordinates/coordinates_2D.py
@@ -15,7 +15,12 @@ class LinSpaceSpatial2DCoordinates(AbstractBiDimensionalSpatialCoordinates):
 
     @classmethod
     def from_nb_points(cls, nb_points, train_split_ratio: float = None, start=-1.0, end=1.0):
+        df = cls.df_spatial(nb_points, start, end)
+        return cls.from_df(df, train_split_ratio)
+
+    @classmethod
+    def df_spatial(cls, nb_points, start=-1.0, end=1.0):
         axis_coordinates = np.linspace(start, end, nb_points)
         df = pd.DataFrame.from_dict({cls.COORDINATE_X: axis_coordinates,
                                      cls.COORDINATE_Y: axis_coordinates})
-        return cls.from_df(df, train_split_ratio)
+        return df
diff --git a/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py b/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py
index 651eb349..47883682 100644
--- a/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/abstract_spatio_temporal_coordinates.py
@@ -14,4 +14,15 @@ class AbstractSpatioTemporalCoordinates(AbstractCoordinates):
         nb_points = len(set(df[cls.COORDINATE_X]))
         first_time_step_for_all_points = df.iloc[:nb_points][cls.COORDINATE_T]
         assert len(set(first_time_step_for_all_points)) == 1
-        return super().from_df_and_slicer(df, SpatioTemporalSlicer, train_split_ratio)
\ No newline at end of file
+        return super().from_df_and_slicer(df, SpatioTemporalSlicer, train_split_ratio)
+
+    @classmethod
+    def generate_df_spatio_temporal(cls, df_spatial, nb_steps):
+        # df_temporal = ConsecutiveTemporalCoordinates.df_temporal(nb_temporal_steps=nb_temporal_steps)
+        df_time_steps = []
+        for t in range(nb_steps):
+            df_time_step = df_spatial.copy()
+            df_time_step[cls.COORDINATE_T] = t
+            df_time_steps.append(df_time_step)
+        df_time_steps = pd.concat(df_time_steps, ignore_index=True)
+        return df_time_steps
\ No newline at end of file
diff --git a/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/generated_spatio_temporal_coordinates.py b/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/generated_spatio_temporal_coordinates.py
index deaccc8d..9d6ba771 100644
--- a/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/generated_spatio_temporal_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/spatio_temporal_coordinates/generated_spatio_temporal_coordinates.py
@@ -1,23 +1,27 @@
 import pandas as pd
 
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import UniformSpatialCoordinates
+from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
 from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.abstract_spatio_temporal_coordinates import \
     AbstractSpatioTemporalCoordinates
 
 
-class UniformSpatioTemporalCoordinates(AbstractSpatioTemporalCoordinates):
+class GeneratedSpatioTemporalCoordinates(AbstractSpatioTemporalCoordinates):
+    SPATIAL_COORDINATES_CLASS = None
 
     @classmethod
     def from_nb_points_and_nb_steps(cls, nb_points, nb_steps, train_split_ratio: float = None):
         assert isinstance(nb_steps, int) and nb_steps >= 1
-        df_spatial = UniformSpatialCoordinates.df_spatial(nb_points=nb_points)
-        # df_temporal = ConsecutiveTemporalCoordinates.df_temporal(nb_temporal_steps=nb_temporal_steps)
-        df_time_steps = []
-        for t in range(nb_steps):
-            df_time_step = df_spatial.copy()
-            df_time_step[cls.COORDINATE_T] = t
-            df_time_steps.append(df_time_step)
-        df_time_steps = pd.concat(df_time_steps, ignore_index=True)
+        assert cls.SPATIAL_COORDINATES_CLASS is not None
+        assert hasattr(cls.SPATIAL_COORDINATES_CLASS, 'df_spatial')
+        df_spatial = cls.SPATIAL_COORDINATES_CLASS.df_spatial(nb_points=nb_points)
+        df_time_steps = cls.generate_df_spatio_temporal(df_spatial, nb_steps)
         return cls.from_df(df=df_time_steps, train_split_ratio=train_split_ratio)
 
 
+class UniformSpatioTemporalCoordinates(GeneratedSpatioTemporalCoordinates):
+    SPATIAL_COORDINATES_CLASS = UniformSpatialCoordinates
+
+
+class LinSpaceSpatial2DSpatioTemporalCoordinates(GeneratedSpatioTemporalCoordinates):
+    SPATIAL_COORDINATES_CLASS = LinSpaceSpatial2DCoordinates
diff --git a/spatio_temporal_dataset/dataset/simulation_dataset.py b/spatio_temporal_dataset/dataset/simulation_dataset.py
index d475bfe2..a2cbea15 100644
--- a/spatio_temporal_dataset/dataset/simulation_dataset.py
+++ b/spatio_temporal_dataset/dataset/simulation_dataset.py
@@ -46,6 +46,7 @@ class FullSimulatedDataset(SimulatedDataset):
     def from_double_sampling(cls, nb_obs: int, max_stable_model: AbstractMaxStableModel,
                              coordinates: AbstractCoordinates,
                              margin_model: AbstractMarginModel):
+        assert coordinates.nb_coordinates <= 2, 'rmaxstable available only for 2D coordinates'
         observations = FullAnnualMaxima.from_double_sampling(nb_obs, max_stable_model,
                                                              coordinates, margin_model)
         return cls(observations=observations, coordinates=coordinates,
diff --git a/test/test_spatio_temporal_dataset/test_slicer.py b/test/test_spatio_temporal_dataset/test_slicer.py
index 5a99be0a..ed13dda6 100644
--- a/test/test_spatio_temporal_dataset/test_slicer.py
+++ b/test/test_spatio_temporal_dataset/test_slicer.py
@@ -102,6 +102,7 @@ class TestSlicerForSpatioTemporalDataset(TestSlicerForDataset):
         coordinates_list = load_test_spatiotemporal_coordinates(nb_points=self.nb_points,
                                                                 nb_steps=self.nb_steps,
                                                                 train_split_ratio=train_split_ratio)
+        coordinates_list = [coordinates for coordinates in coordinates_list if coordinates.nb_coordinates <= 2]
         dataset_list = [FullSimulatedDataset.from_double_sampling(nb_obs=self.nb_obs,
                                                                   margin_model=ConstantMarginModel(
                                                                       coordinates=coordinates),
diff --git a/test/test_utils.py b/test/test_utils.py
index ca1f800b..170988f3 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -20,7 +20,7 @@ from spatio_temporal_dataset.coordinates.spatial_coordinates.alps_station_3D_coo
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import \
     CircleSpatialCoordinates
 from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.generated_spatio_temporal_coordinates import \
-    UniformSpatioTemporalCoordinates
+    UniformSpatioTemporalCoordinates, LinSpaceSpatial2DSpatioTemporalCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import UniformSpatialCoordinates
 from spatio_temporal_dataset.coordinates.temporal_coordinates.generated_temporal_coordinates import \
     ConsecutiveTemporalCoordinates
@@ -35,7 +35,7 @@ TEST_MAX_STABLE_MODEL = [Smith, BrownResnick, Schlather, Geometric, ExtremalT, I
 TEST_1D_AND_2D_SPATIAL_COORDINATES = [UniformSpatialCoordinates, CircleSpatialCoordinates]
 TEST_3D_SPATIAL_COORDINATES = [AlpsStation3DCoordinatesWithAnisotropy]
 TEST_TEMPORAL_COORDINATES = [ConsecutiveTemporalCoordinates]
-TEST_SPATIO_TEMPORAL_COORDINATES = [UniformSpatioTemporalCoordinates]
+TEST_SPATIO_TEMPORAL_COORDINATES = [UniformSpatioTemporalCoordinates, LinSpaceSpatial2DSpatioTemporalCoordinates]
 TEST_MARGIN_TYPES = [ConstantMarginModel, LinearAllParametersAllDimsMarginModel][:]
 TEST_MAX_STABLE_ESTIMATOR = [MaxStableEstimator]
 TEST_FULL_ESTIMATORS = [SmoothMarginalsThenUnitaryMsp, FullEstimatorInASingleStepWithSmoothMargin][:]
-- 
GitLab