From ad558b5805c4c6681265701b6532322f0a1748ae Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Fri, 23 Nov 2018 19:19:26 +0100
Subject: [PATCH] [REFACTOR] refactor test folder. add utils.py

---
 .../abstract_max_stable_model.py              |  3 ++
 .../extreme_models/max_stable_model/utils.py  | 19 ---------
 spatio_temporal_dataset/coordinates/utils.py  | 10 -----
 .../dataset/simulation_dataset.py             |  2 +-
 .../test_estimator/test_full_estimators.py    |  7 ++--
 .../test_estimator/test_margin_estimators.py  | 10 +----
 .../test_max_stable_estimators.py             |  4 +-
 .../test_dataset.py                           | 15 ++++---
 test/test_utils.py                            | 40 +++++++++++++++++++
 9 files changed, 58 insertions(+), 52 deletions(-)
 delete mode 100644 extreme_estimator/extreme_models/max_stable_model/utils.py
 delete mode 100644 spatio_temporal_dataset/coordinates/utils.py
 create mode 100644 test/test_utils.py

diff --git a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
index 44643e1b..367b9b25 100644
--- a/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
+++ b/extreme_estimator/extreme_models/max_stable_model/abstract_max_stable_model.py
@@ -34,6 +34,9 @@ class AbstractMaxStableModel(AbstractModel):
         #  Prepare the fit params
         fit_params = self.cov_mod_param.copy()
         start_dict = self.params_start_fit
+        # Remove the 'var' parameter from the start_dict in the 2D case, otherwise fitmaxstab crashes
+        if len(df_coordinates.columns) == 2 and 'var' in start_dict.keys():
+                start_dict.pop('var')
         if fit_marge:
             start_dict.update(margin_start_dict)
             fit_params.update({k: robjects.Formula(v) for k, v in fit_marge_form_dict.items()})
diff --git a/extreme_estimator/extreme_models/max_stable_model/utils.py b/extreme_estimator/extreme_models/max_stable_model/utils.py
deleted file mode 100644
index 41333165..00000000
--- a/extreme_estimator/extreme_models/max_stable_model/utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import \
-    AbstractMaxStableModelWithCovarianceFunction, CovarianceFunction
-from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith, BrownResnick, Schlather, \
-    Geometric, ExtremalT, ISchlather
-
-MAX_STABLE_TYPES = [Smith, BrownResnick, Schlather, Geometric, ExtremalT, ISchlather]
-
-
-def load_max_stable_models():
-    # Load all max stable model
-    max_stable_models = []
-    for max_stable_class in MAX_STABLE_TYPES:
-        if issubclass(max_stable_class, AbstractMaxStableModelWithCovarianceFunction):
-            max_stable_models.extend([max_stable_class(covariance_function=covariance_function)
-                                      for covariance_function in CovarianceFunction])
-        else:
-            max_stable_models.append(max_stable_class())
-    return max_stable_models
-
diff --git a/spatio_temporal_dataset/coordinates/utils.py b/spatio_temporal_dataset/coordinates/utils.py
deleted file mode 100644
index 377fd82d..00000000
--- a/spatio_temporal_dataset/coordinates/utils.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from spatio_temporal_dataset.coordinates.spatial_coordinates.alps_station_3D_coordinates import \
-    AlpsStation3DCoordinatesWithAnisotropy
-from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
-from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import UniformCoordinates
-
-COORDINATES = [UniformCoordinates, CircleCoordinates, AlpsStation3DCoordinatesWithAnisotropy]
-
-
-def load_coordinates(nb_points):
-    return [coordinate_class.from_nb_points(nb_points=nb_points) for coordinate_class in COORDINATES]
diff --git a/spatio_temporal_dataset/dataset/simulation_dataset.py b/spatio_temporal_dataset/dataset/simulation_dataset.py
index 0b489d3a..4738c8c8 100644
--- a/spatio_temporal_dataset/dataset/simulation_dataset.py
+++ b/spatio_temporal_dataset/dataset/simulation_dataset.py
@@ -34,7 +34,7 @@ class MaxStableDataset(SimulatedDataset):
 class MarginDataset(SimulatedDataset):
 
     @classmethod
-    def from_sampling(cls, nb_obs: int, margin_model: AbstractMarginModel,coordinates: AbstractCoordinates):
+    def from_sampling(cls, nb_obs: int, margin_model: AbstractMarginModel, coordinates: AbstractCoordinates):
         temporal_obs = MarginAnnualMaxima.from_sampling(nb_obs, coordinates, margin_model)
         return cls(temporal_observations=temporal_obs, coordinates=coordinates, margin_model=margin_model)
 
diff --git a/test/test_extreme_estimator/test_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
index db48460b..8c520aca 100644
--- a/test/test_extreme_estimator/test_estimator/test_full_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
@@ -3,11 +3,11 @@ from itertools import product
 
 from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp, \
     FullEstimatorInASingleStepWithSmoothMargin
-from extreme_estimator.extreme_models.max_stable_model.utils import load_max_stable_models
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
 from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestSmoothMarginEstimator
 from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators
+from test.test_utils import load_test_max_stable_models, load_smooth_margin_models
 
 
 class TestFullEstimators(unittest.TestCase):
@@ -17,9 +17,8 @@ class TestFullEstimators(unittest.TestCase):
     def setUp(self):
         super().setUp()
         self.spatial_coordinates = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1)
-        self.max_stable_models = load_max_stable_models()
-        self.smooth_margin_models = TestSmoothMarginEstimator.load_smooth_margin_models(
-            coordinates=self.spatial_coordinates)
+        self.max_stable_models = load_test_max_stable_models()
+        self.smooth_margin_models = load_smooth_margin_models(coordinates=self.spatial_coordinates)
 
     def test_full_estimators(self):
         for margin_model, max_stable_model in product(self.smooth_margin_models, self.max_stable_models):
diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
index 3def42b0..2e8969bb 100644
--- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
@@ -7,23 +7,17 @@ from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
 from extreme_estimator.return_level_plot.spatial_2D_plot import Spatial2DPlot
 from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
+from test.test_utils import load_smooth_margin_models
 
 
 class TestSmoothMarginEstimator(unittest.TestCase):
     DISPLAY = False
-    MARGIN_TYPES = [ConstantMarginModel, LinearShapeAxis0MarginModel,
-                    LinearShapeAxis0and1MarginModel, LinearAllParametersAxis0MarginModel,
-                    LinearAllParametersAxis0And1MarginModel][:]
     SMOOTH_MARGIN_ESTIMATORS = [SmoothMarginEstimator]
 
     def setUp(self):
         super().setUp()
         self.spatial_coordinates = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1)
-        self.smooth_margin_models = self.load_smooth_margin_models(coordinates=self.spatial_coordinates)
-
-    @classmethod
-    def load_smooth_margin_models(cls, coordinates):
-        return [margin_class(coordinates=coordinates) for margin_class in cls.MARGIN_TYPES]
+        self.smooth_margin_models = load_smooth_margin_models(coordinates=self.spatial_coordinates)
 
     def test_dependency_estimators(self):
         for margin_model in self.smooth_margin_models:
diff --git a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
index 01ae2c3a..a4f1b8c7 100644
--- a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py
@@ -3,9 +3,9 @@ import unittest
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import \
     AbstractMaxStableModelWithCovarianceFunction, CovarianceFunction
 from extreme_estimator.estimator.max_stable_estimator import MaxStableEstimator
-from extreme_estimator.extreme_models.max_stable_model.utils import load_max_stable_models
 from spatio_temporal_dataset.dataset.simulation_dataset import MaxStableDataset
 from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
+from test.test_utils import load_test_max_stable_models
 
 
 class TestMaxStableEstimators(unittest.TestCase):
@@ -16,7 +16,7 @@ class TestMaxStableEstimators(unittest.TestCase):
     def setUp(self):
         super().setUp()
         self.spatial_coord = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1)
-        self.max_stable_models = load_max_stable_models()
+        self.max_stable_models = load_test_max_stable_models()
 
     def test_max_stable_estimators(self):
         for max_stable_model in self.max_stable_models:
diff --git a/test/test_spatio_temporal_dataset/test_dataset.py b/test/test_spatio_temporal_dataset/test_dataset.py
index f19f40b6..c182eb0f 100644
--- a/test/test_spatio_temporal_dataset/test_dataset.py
+++ b/test/test_spatio_temporal_dataset/test_dataset.py
@@ -2,9 +2,8 @@ from rpy2.rinterface import RRuntimeError
 import unittest
 from itertools import product
 
-from extreme_estimator.extreme_models.max_stable_model.utils import load_max_stable_models
-from spatio_temporal_dataset.coordinates.utils import load_coordinates
 from spatio_temporal_dataset.dataset.simulation_dataset import MaxStableDataset
+from test.test_utils import load_test_max_stable_models, load_test_coordinates
 
 
 class TestDataset(unittest.TestCase):
@@ -12,8 +11,8 @@ class TestDataset(unittest.TestCase):
     nb_points = 10
 
     def test_max_stable_dataset_R1_and_R2(self):
-        max_stable_models = load_max_stable_models()[:]
-        coordinatess = load_coordinates(self.nb_points)[:-1]
+        max_stable_models = load_test_max_stable_models()[:]
+        coordinatess = load_test_coordinates(self.nb_points)[:-1]
         for coordinates, max_stable_model in product(coordinatess, max_stable_models):
             MaxStableDataset.from_sampling(nb_obs=self.nb_obs,
                                            max_stable_model=max_stable_model,
@@ -21,13 +20,13 @@ class TestDataset(unittest.TestCase):
         self.assertTrue(True)
 
     def test_max_stable_dataset_crash_R3(self):
-        # test to warn me when spatialExtremes handles R3
+        """Test to warn me when spatialExtremes handles R3"""
         with self.assertRaises(RRuntimeError):
-            smith_process = load_max_stable_models()[0]
-            coordinates_R3 = load_coordinates(self.nb_points)[-1]
+            smith_process = load_test_max_stable_models()[0]
+            coordinates = load_test_coordinates(self.nb_points)[-1]
             MaxStableDataset.from_sampling(nb_obs=self.nb_obs,
                                            max_stable_model=smith_process,
-                                           coordinates=coordinates_R3)
+                                           coordinates=coordinates)
 
 
 if __name__ == '__main__':
diff --git a/test/test_utils.py b/test/test_utils.py
new file mode 100644
index 00000000..2170af02
--- /dev/null
+++ b/test/test_utils.py
@@ -0,0 +1,40 @@
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAxis0And1MarginModel, \
+    ConstantMarginModel
+from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import \
+    AbstractMaxStableModelWithCovarianceFunction, CovarianceFunction
+from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith, BrownResnick, Schlather, \
+    Geometric, ExtremalT, ISchlather
+from spatio_temporal_dataset.coordinates.spatial_coordinates.alps_station_3D_coordinates import \
+    AlpsStation3DCoordinatesWithAnisotropy
+from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import CircleCoordinates
+from spatio_temporal_dataset.coordinates.unidimensional_coordinates.coordinates_1D import UniformCoordinates
+
+"""
+Common objects to load for the test.
+Sometimes it doesn't cover all the class (e.g margin_model, coordinates...)
+In this case, unit test (at least on the constructor) must be ensured in the test relative to the class 
+"""
+
+TEST_MAX_STABLE_MODEL = [Smith, BrownResnick, Schlather, Geometric, ExtremalT, ISchlather]
+TEST_COORDINATES = [UniformCoordinates, CircleCoordinates, AlpsStation3DCoordinatesWithAnisotropy]
+MARGIN_TYPES = [ConstantMarginModel, LinearAllParametersAxis0And1MarginModel][:]
+
+
+def load_smooth_margin_models(coordinates):
+    return [margin_class(coordinates=coordinates) for margin_class in MARGIN_TYPES]
+
+
+def load_test_max_stable_models():
+    # Load all max stable model
+    max_stable_models = []
+    for max_stable_class in TEST_MAX_STABLE_MODEL:
+        if issubclass(max_stable_class, AbstractMaxStableModelWithCovarianceFunction):
+            max_stable_models.extend([max_stable_class(covariance_function=covariance_function)
+                                      for covariance_function in CovarianceFunction])
+        else:
+            max_stable_models.append(max_stable_class())
+    return max_stable_models
+
+
+def load_test_coordinates(nb_points):
+    return [coordinate_class.from_nb_points(nb_points=nb_points) for coordinate_class in TEST_COORDINATES]
-- 
GitLab