Commit ee91dd99 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[contrasting] add spatio temporal model with cross terms

parent 294a5b1c
No related merge requests found
Showing with 97 additions and 37 deletions
+97 -37
...@@ -3,46 +3,95 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly ...@@ -3,46 +3,95 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
AbstractSpatioTemporalPolynomialModel AbstractSpatioTemporalPolynomialModel
class StationaryAltitudinal(AbstractSpatioTemporalPolynomialModel): class AbstractAltitudinalModel(AbstractSpatioTemporalPolynomialModel):
def load_margin_function(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function({ return super().load_margin_function(self.param_name_to_list_dim_and_degree)
@property
def param_name_to_list_dim_and_degree(self):
raise NotImplementedError
class StationaryAltitudinal(AbstractAltitudinalModel):
@property
def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1)], GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)] GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)]
}) }
class NonStationaryAltitudinalLocationLinear(AbstractSpatioTemporalPolynomialModel): class NonStationaryAltitudinalLocationLinear(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None): @property
return super().load_margin_function({ def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)], GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)] GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)]
}) }
class NonStationaryAltitudinalLocationQuadratic(AbstractSpatioTemporalPolynomialModel): class NonStationaryAltitudinalLocationQuadratic(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None): @property
return super().load_margin_function({ def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 2)], GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 2)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)] GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)]
}) }
class NonStationaryAltitudinalLocationLinearScaleLinear(AbstractSpatioTemporalPolynomialModel): class NonStationaryAltitudinalLocationLinearScaleLinear(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None): @property
return super().load_margin_function({ def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)], GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)], GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
}) }
class NonStationaryAltitudinalLocationQuadraticScaleLinear(AbstractSpatioTemporalPolynomialModel): class NonStationaryAltitudinalLocationQuadraticScaleLinear(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None): @property
return super().load_margin_function({ def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 2)], GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 2)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)], GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
}) }
# Add cross terms
class AbstractAddCrossTermForLocation(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
d = self.param_name_to_list_dim_and_degree
d[GevParams.LOC] += ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1)
return super().load_margin_function(d)
class NonStationaryCrossTermForLocation(AbstractAddCrossTermForLocation, StationaryAltitudinal):
pass
class NonStationaryAltitudinalLocationLinearCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationLinear):
pass
class NonStationaryAltitudinalLocationQuadraticCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationQuadratic):
pass
class NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationLinearScaleLinear):
pass
class NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticScaleLinear):
pass
...@@ -35,6 +35,7 @@ class PolynomialMarginModel(AbstractTemporalLinearMarginModel): ...@@ -35,6 +35,7 @@ class PolynomialMarginModel(AbstractTemporalLinearMarginModel):
# i.e. 1) spatial individual terms 2) combined terms 3) temporal individual terms # i.e. 1) spatial individual terms 2) combined terms 3) temporal individual terms
for param_name, list_dim_and_degree in param_name_to_list_dim_and_degree.items(): for param_name, list_dim_and_degree in param_name_to_list_dim_and_degree.items():
dims = [d for d, m in list_dim_and_degree] dims = [d for d, m in list_dim_and_degree]
assert all([isinstance(d, int) or isinstance(d, tuple) for d in dims])
if self.coordinates.has_spatial_coordinates and self.coordinates.idx_x_coordinates in dims: if self.coordinates.has_spatial_coordinates and self.coordinates.idx_x_coordinates in dims:
assert dims.index(self.coordinates.idx_x_coordinates) == 0 assert dims.index(self.coordinates.idx_x_coordinates) == 0
if self.coordinates.has_temporal_coordinates and self.coordinates.idx_temporal_coordinates in dims: if self.coordinates.has_temporal_coordinates and self.coordinates.idx_temporal_coordinates in dims:
......
from extreme_fit.model.margin_model.polynomial_margin_model.altitudinal_models import StationaryAltitudinal, \ from extreme_fit.model.margin_model.polynomial_margin_model.altitudinal_models import StationaryAltitudinal, \
NonStationaryAltitudinalLocationLinear, NonStationaryAltitudinalLocationQuadratic, \ NonStationaryAltitudinalLocationLinear, NonStationaryAltitudinalLocationQuadratic, \
NonStationaryAltitudinalLocationLinearScaleLinear, NonStationaryAltitudinalLocationQuadraticScaleLinear NonStationaryAltitudinalLocationLinearScaleLinear, NonStationaryAltitudinalLocationQuadraticScaleLinear, \
NonStationaryCrossTermForLocation, NonStationaryAltitudinalLocationLinearCrossTermForLocation, \
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation, \
NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation, \
NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation
from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
NonStationaryLocationSpatioTemporalLinearityModel1, NonStationaryLocationSpatioTemporalLinearityModel2, \ NonStationaryLocationSpatioTemporalLinearityModel1, NonStationaryLocationSpatioTemporalLinearityModel2, \
NonStationaryLocationSpatioTemporalLinearityModel3, NonStationaryLocationSpatioTemporalLinearityModel4, \ NonStationaryLocationSpatioTemporalLinearityModel3, NonStationaryLocationSpatioTemporalLinearityModel4, \
...@@ -9,22 +13,27 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly ...@@ -9,22 +13,27 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
NonStationaryLocationSpatioTemporalLinearityModelAssertError3, NonStationaryLocationSpatioTemporalLinearityModel6 NonStationaryLocationSpatioTemporalLinearityModelAssertError3, NonStationaryLocationSpatioTemporalLinearityModel6
ALTITUDINAL_MODELS = [ ALTITUDINAL_MODELS = [
StationaryAltitudinal, StationaryAltitudinal,
NonStationaryAltitudinalLocationLinear, NonStationaryAltitudinalLocationLinear,
NonStationaryAltitudinalLocationQuadratic, NonStationaryAltitudinalLocationQuadratic,
NonStationaryAltitudinalLocationLinearScaleLinear, NonStationaryAltitudinalLocationLinearScaleLinear,
NonStationaryAltitudinalLocationQuadraticScaleLinear NonStationaryAltitudinalLocationQuadraticScaleLinear,
][:]
NonStationaryCrossTermForLocation,
NonStationaryAltitudinalLocationLinearCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation,
NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation,
][:]
MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR = [NonStationaryLocationSpatioTemporalLinearityModelAssertError1, MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR = [NonStationaryLocationSpatioTemporalLinearityModelAssertError1,
NonStationaryLocationSpatioTemporalLinearityModelAssertError2, NonStationaryLocationSpatioTemporalLinearityModelAssertError2,
NonStationaryLocationSpatioTemporalLinearityModelAssertError3] NonStationaryLocationSpatioTemporalLinearityModelAssertError3]
VARIOUS_SPATIO_TEMPORAL_MODELS = [NonStationaryLocationSpatioTemporalLinearityModel1, VARIOUS_SPATIO_TEMPORAL_MODELS = [NonStationaryLocationSpatioTemporalLinearityModel1,
NonStationaryLocationSpatioTemporalLinearityModel2, NonStationaryLocationSpatioTemporalLinearityModel2,
NonStationaryLocationSpatioTemporalLinearityModel3, NonStationaryLocationSpatioTemporalLinearityModel3,
NonStationaryLocationSpatioTemporalLinearityModel4, NonStationaryLocationSpatioTemporalLinearityModel4,
NonStationaryLocationSpatioTemporalLinearityModel5, NonStationaryLocationSpatioTemporalLinearityModel5,
NonStationaryLocationSpatioTemporalLinearityModel6, NonStationaryLocationSpatioTemporalLinearityModel6,
] ]
\ No newline at end of file
import unittest import unittest
from random import sample
from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
from extreme_fit.model.margin_model.polynomial_margin_model.utils import ALTITUDINAL_MODELS, \ from extreme_fit.model.margin_model.polynomial_margin_model.utils import ALTITUDINAL_MODELS, \
...@@ -36,16 +37,16 @@ class TestGevTemporalQuadraticExtremesMle(unittest.TestCase): ...@@ -36,16 +37,16 @@ class TestGevTemporalQuadraticExtremesMle(unittest.TestCase):
self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic(split=estimator.train_split)) self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic(split=estimator.train_split))
def test_assert_error(self): def test_assert_error(self):
for model_class in MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR: for model_class in sample(MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR, 1):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.common_test(model_class) self.common_test(model_class)
def test_location_spatio_temporal_models(self): def test_location_spatio_temporal_models(self):
for model_class in VARIOUS_SPATIO_TEMPORAL_MODELS[:]: for model_class in sample(VARIOUS_SPATIO_TEMPORAL_MODELS, 3):
self.common_test(model_class) self.common_test(model_class)
def test_altitudinal_models(self): def test_altitudinal_models(self):
for model_class in ALTITUDINAL_MODELS: for model_class in sample(ALTITUDINAL_MODELS, 3):
self.common_test(model_class) self.common_test(model_class)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment