From c0d698e3456c87cc458316df0b8fb7ea5ab39ddb Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 19 Mar 2020 20:44:50 +0100
Subject: [PATCH] [refactor] improve test_margin_function.py. refactor
 abstract_coef

---
 .../function/param_function/abstract_coef.py  |  2 +-
 .../test_function/test_margin_function.py     | 44 +++++++++++++------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/extreme_fit/function/param_function/abstract_coef.py b/extreme_fit/function/param_function/abstract_coef.py
index 3ba91ecc..88a10d74 100644
--- a/extreme_fit/function/param_function/abstract_coef.py
+++ b/extreme_fit/function/param_function/abstract_coef.py
@@ -5,7 +5,7 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
 
 class AbstractCoef(object):
 
-    def __init__(self, gev_param_name: str, default_value: float = 0.0, idx_to_coef=None):
+    def __init__(self, gev_param_name: str = '', default_value: float = 0.0, idx_to_coef=None):
         self.gev_param_name = gev_param_name
         self.default_value = default_value
         self.idx_to_coef = idx_to_coef
diff --git a/test/test_extreme_fit/test_function/test_margin_function.py b/test/test_extreme_fit/test_function/test_margin_function.py
index 622d3a41..7f474fb9 100644
--- a/test/test_extreme_fit/test_function/test_margin_function.py
+++ b/test/test_extreme_fit/test_function/test_margin_function.py
@@ -6,8 +6,12 @@ from extreme_fit.model.margin_model.linear_margin_model.linear_margin_model impo
 from extreme_fit.function.margin_function.abstract_margin_function import \
     AbstractMarginFunction
 from extreme_fit.function.margin_function.linear_margin_function import LinearMarginFunction
+from extreme_fit.model.utils import set_seed_for_test
+from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import UniformSpatialCoordinates, \
+    LinSpaceSpatialCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
-from test.test_utils import load_test_spatiotemporal_coordinates
+from test.test_utils import load_test_spatiotemporal_coordinates, load_test_temporal_coordinates, \
+    load_test_spatial_coordinates
 
 
 class MarginFunction(unittest.TestCase):
@@ -15,23 +19,35 @@ class MarginFunction(unittest.TestCase):
     margin_function_class = LinearMarginFunction
     margin_model_class = LinearAllParametersAllDimsMarginModel
 
-    def test_grid_2D_orientation(self):
-        # Assert that the grid correspond to what we expect in a simple case
-        margin_model = self.margin_model_class(LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points))
-        AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
-        grid = margin_model.margin_function_sample.grid_2D()['loc']
-        true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
-        self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
-
-    def test_coef_dict(self):
+    def test_coef_dict_spatio_temporal_coordinates(self):
+        set_seed_for_test(seed=41)
         coordinates = load_test_spatiotemporal_coordinates(self.nb_points, self.nb_points)[0]
         margin_model = self.margin_model_class(coordinates)
         # Test to check loading of margin function from coef dict
-        coef_dict = {'locCoeff1': 0, 'locCoeff2': 1, 'scaleCoeff1': 0,
-                     'scaleCoeff2': 1, 'shapeCoeff1': 0,
-                     'shapeCoeff2': 1,
+        coef_dict = {'locCoeff1': 0, 'locCoeff2': 2, 'scaleCoeff1': 0,
+                     'scaleCoeff2': 2, 'shapeCoeff1': 0,
+                     'shapeCoeff2': 2,
                      'tempCoeffLoc1': 1, 'tempCoeffScale1': 1,
                      'tempCoeffShape1': 1}
-        self.margin_function_class.from_coef_dict(coordinates,
+        margin_function = self.margin_function_class.from_coef_dict(coordinates,
                                                   margin_model.margin_function_sample.gev_param_name_to_dims,
                                                   coef_dict)
+        gev_param = margin_function.get_gev_params(coordinate=np.array([0.5, 1.0]), is_transformed=False)
+        self.assertEqual({'loc': 2, 'scale': 2, 'shape': 2}, gev_param.to_dict())
+
+    def test_coef_dict_spatial_coordinates(self):
+        coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points+1, start=1, end=3)
+        margin_model = self.margin_model_class(coordinates)
+        # Test to check loading of margin function from coef dict
+        coef_dict = {
+            'locCoeff1': 2, 'locCoeff2': 1, 'scaleCoeff1': 0,
+            'scaleCoeff2': 1, 'shapeCoeff1': 0,
+            'shapeCoeff2': 1}
+        margin_function = self.margin_function_class.from_coef_dict(coordinates,
+                                                                    margin_model.margin_function_sample.gev_param_name_to_dims,
+                                                                    coef_dict)
+        gev_param = margin_function.get_gev_params(coordinate=np.array([1]), is_transformed=False)
+        self.assertEqual({'loc': 3, 'scale': 1, 'shape': 1}, gev_param.to_dict())
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
-- 
GitLab