Commit c0d698e3 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] improve test_margin_function.py. refactor abstract_coef

parent 36ea1de2
No related merge requests found
Showing with 31 additions and 15 deletions
+31 -15
...@@ -5,7 +5,7 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo ...@@ -5,7 +5,7 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class AbstractCoef(object): class AbstractCoef(object):
def __init__(self, gev_param_name: str, default_value: float = 0.0, idx_to_coef=None): def __init__(self, gev_param_name: str = '', default_value: float = 0.0, idx_to_coef=None):
self.gev_param_name = gev_param_name self.gev_param_name = gev_param_name
self.default_value = default_value self.default_value = default_value
self.idx_to_coef = idx_to_coef self.idx_to_coef = idx_to_coef
......
...@@ -6,8 +6,12 @@ from extreme_fit.model.margin_model.linear_margin_model.linear_margin_model impo ...@@ -6,8 +6,12 @@ from extreme_fit.model.margin_model.linear_margin_model.linear_margin_model impo
from extreme_fit.function.margin_function.abstract_margin_function import \ from extreme_fit.function.margin_function.abstract_margin_function import \
AbstractMarginFunction AbstractMarginFunction
from extreme_fit.function.margin_function.linear_margin_function import LinearMarginFunction from extreme_fit.function.margin_function.linear_margin_function import LinearMarginFunction
from extreme_fit.model.utils import set_seed_for_test
from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import UniformSpatialCoordinates, \
LinSpaceSpatialCoordinates
from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
from test.test_utils import load_test_spatiotemporal_coordinates from test.test_utils import load_test_spatiotemporal_coordinates, load_test_temporal_coordinates, \
load_test_spatial_coordinates
class MarginFunction(unittest.TestCase): class MarginFunction(unittest.TestCase):
...@@ -15,23 +19,35 @@ class MarginFunction(unittest.TestCase): ...@@ -15,23 +19,35 @@ class MarginFunction(unittest.TestCase):
margin_function_class = LinearMarginFunction margin_function_class = LinearMarginFunction
margin_model_class = LinearAllParametersAllDimsMarginModel margin_model_class = LinearAllParametersAllDimsMarginModel
def test_grid_2D_orientation(self): def test_coef_dict_spatio_temporal_coordinates(self):
# Assert that the grid correspond to what we expect in a simple case set_seed_for_test(seed=41)
margin_model = self.margin_model_class(LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points))
AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
grid = margin_model.margin_function_sample.grid_2D()['loc']
true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
def test_coef_dict(self):
coordinates = load_test_spatiotemporal_coordinates(self.nb_points, self.nb_points)[0] coordinates = load_test_spatiotemporal_coordinates(self.nb_points, self.nb_points)[0]
margin_model = self.margin_model_class(coordinates) margin_model = self.margin_model_class(coordinates)
# Test to check loading of margin function from coef dict # Test to check loading of margin function from coef dict
coef_dict = {'locCoeff1': 0, 'locCoeff2': 1, 'scaleCoeff1': 0, coef_dict = {'locCoeff1': 0, 'locCoeff2': 2, 'scaleCoeff1': 0,
'scaleCoeff2': 1, 'shapeCoeff1': 0, 'scaleCoeff2': 2, 'shapeCoeff1': 0,
'shapeCoeff2': 1, 'shapeCoeff2': 2,
'tempCoeffLoc1': 1, 'tempCoeffScale1': 1, 'tempCoeffLoc1': 1, 'tempCoeffScale1': 1,
'tempCoeffShape1': 1} 'tempCoeffShape1': 1}
self.margin_function_class.from_coef_dict(coordinates, margin_function = self.margin_function_class.from_coef_dict(coordinates,
margin_model.margin_function_sample.gev_param_name_to_dims, margin_model.margin_function_sample.gev_param_name_to_dims,
coef_dict) coef_dict)
gev_param = margin_function.get_gev_params(coordinate=np.array([0.5, 1.0]), is_transformed=False)
self.assertEqual({'loc': 2, 'scale': 2, 'shape': 2}, gev_param.to_dict())
def test_coef_dict_spatial_coordinates(self):
coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points+1, start=1, end=3)
margin_model = self.margin_model_class(coordinates)
# Test to check loading of margin function from coef dict
coef_dict = {
'locCoeff1': 2, 'locCoeff2': 1, 'scaleCoeff1': 0,
'scaleCoeff2': 1, 'shapeCoeff1': 0,
'shapeCoeff2': 1}
margin_function = self.margin_function_class.from_coef_dict(coordinates,
margin_model.margin_function_sample.gev_param_name_to_dims,
coef_dict)
gev_param = margin_function.get_gev_params(coordinate=np.array([1]), is_transformed=False)
self.assertEqual({'loc': 3, 'scale': 1, 'shape': 1}, gev_param.to_dict())
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment