Commit ee91dd99 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[contrasting] add spatio temporal model with cross terms

parent 294a5b1c
No related merge requests found
Showing with 97 additions and 37 deletions
+97 -37
......@@ -3,46 +3,95 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
AbstractSpatioTemporalPolynomialModel
class StationaryAltitudinal(AbstractSpatioTemporalPolynomialModel):
class AbstractAltitudinalModel(AbstractSpatioTemporalPolynomialModel):
def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function({
return super().load_margin_function(self.param_name_to_list_dim_and_degree)
@property
def param_name_to_list_dim_and_degree(self):
raise NotImplementedError
class StationaryAltitudinal(AbstractAltitudinalModel):
@property
def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)]
})
}
class NonStationaryAltitudinalLocationLinear(AbstractSpatioTemporalPolynomialModel):
class NonStationaryAltitudinalLocationLinear(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function({
@property
def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)]
})
}
class NonStationaryAltitudinalLocationQuadratic(AbstractSpatioTemporalPolynomialModel):
class NonStationaryAltitudinalLocationQuadratic(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function({
@property
def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 2)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1)]
})
}
class NonStationaryAltitudinalLocationLinearScaleLinear(AbstractSpatioTemporalPolynomialModel):
class NonStationaryAltitudinalLocationLinearScaleLinear(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function({
@property
def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
})
}
class NonStationaryAltitudinalLocationQuadraticScaleLinear(AbstractSpatioTemporalPolynomialModel):
class NonStationaryAltitudinalLocationQuadraticScaleLinear(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function({
@property
def param_name_to_list_dim_and_degree(self):
return {
GevParams.LOC: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 2)],
GevParams.SCALE: [(self.coordinates.idx_x_coordinates, 1), (self.coordinates.idx_temporal_coordinates, 1)],
})
}
# Add cross terms
class AbstractAddCrossTermForLocation(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
d = self.param_name_to_list_dim_and_degree
d[GevParams.LOC] += ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1)
return super().load_margin_function(d)
class NonStationaryCrossTermForLocation(AbstractAddCrossTermForLocation, StationaryAltitudinal):
pass
class NonStationaryAltitudinalLocationLinearCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationLinear):
pass
class NonStationaryAltitudinalLocationQuadraticCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationQuadratic):
pass
class NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationLinearScaleLinear):
pass
class NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation(AbstractAddCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticScaleLinear):
pass
......@@ -35,6 +35,7 @@ class PolynomialMarginModel(AbstractTemporalLinearMarginModel):
# i.e. 1) spatial individual terms 2) combined terms 3) temporal individual terms
for param_name, list_dim_and_degree in param_name_to_list_dim_and_degree.items():
dims = [d for d, m in list_dim_and_degree]
assert all([isinstance(d, int) or isinstance(d, tuple) for d in dims])
if self.coordinates.has_spatial_coordinates and self.coordinates.idx_x_coordinates in dims:
assert dims.index(self.coordinates.idx_x_coordinates) == 0
if self.coordinates.has_temporal_coordinates and self.coordinates.idx_temporal_coordinates in dims:
......
from extreme_fit.model.margin_model.polynomial_margin_model.altitudinal_models import StationaryAltitudinal, \
NonStationaryAltitudinalLocationLinear, NonStationaryAltitudinalLocationQuadratic, \
NonStationaryAltitudinalLocationLinearScaleLinear, NonStationaryAltitudinalLocationQuadraticScaleLinear
NonStationaryAltitudinalLocationLinearScaleLinear, NonStationaryAltitudinalLocationQuadraticScaleLinear, \
NonStationaryCrossTermForLocation, NonStationaryAltitudinalLocationLinearCrossTermForLocation, \
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation, \
NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation, \
NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation
from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
NonStationaryLocationSpatioTemporalLinearityModel1, NonStationaryLocationSpatioTemporalLinearityModel2, \
NonStationaryLocationSpatioTemporalLinearityModel3, NonStationaryLocationSpatioTemporalLinearityModel4, \
......@@ -9,22 +13,27 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
NonStationaryLocationSpatioTemporalLinearityModelAssertError3, NonStationaryLocationSpatioTemporalLinearityModel6
ALTITUDINAL_MODELS = [
StationaryAltitudinal,
NonStationaryAltitudinalLocationLinear,
NonStationaryAltitudinalLocationQuadratic,
NonStationaryAltitudinalLocationLinearScaleLinear,
NonStationaryAltitudinalLocationQuadraticScaleLinear
][:]
StationaryAltitudinal,
NonStationaryAltitudinalLocationLinear,
NonStationaryAltitudinalLocationQuadratic,
NonStationaryAltitudinalLocationLinearScaleLinear,
NonStationaryAltitudinalLocationQuadraticScaleLinear,
NonStationaryCrossTermForLocation,
NonStationaryAltitudinalLocationLinearCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation,
NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation,
][:]
MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR = [NonStationaryLocationSpatioTemporalLinearityModelAssertError1,
NonStationaryLocationSpatioTemporalLinearityModelAssertError2,
NonStationaryLocationSpatioTemporalLinearityModelAssertError3]
NonStationaryLocationSpatioTemporalLinearityModelAssertError2,
NonStationaryLocationSpatioTemporalLinearityModelAssertError3]
VARIOUS_SPATIO_TEMPORAL_MODELS = [NonStationaryLocationSpatioTemporalLinearityModel1,
NonStationaryLocationSpatioTemporalLinearityModel2,
NonStationaryLocationSpatioTemporalLinearityModel3,
NonStationaryLocationSpatioTemporalLinearityModel4,
NonStationaryLocationSpatioTemporalLinearityModel5,
NonStationaryLocationSpatioTemporalLinearityModel6,
]
\ No newline at end of file
NonStationaryLocationSpatioTemporalLinearityModel2,
NonStationaryLocationSpatioTemporalLinearityModel3,
NonStationaryLocationSpatioTemporalLinearityModel4,
NonStationaryLocationSpatioTemporalLinearityModel5,
NonStationaryLocationSpatioTemporalLinearityModel6,
]
import unittest
from random import sample
from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
from extreme_fit.model.margin_model.polynomial_margin_model.utils import ALTITUDINAL_MODELS, \
......@@ -36,16 +37,16 @@ class TestGevTemporalQuadraticExtremesMle(unittest.TestCase):
self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic(split=estimator.train_split))
def test_assert_error(self):
for model_class in MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR:
for model_class in sample(MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR, 1):
with self.assertRaises(AssertionError):
self.common_test(model_class)
def test_location_spatio_temporal_models(self):
for model_class in VARIOUS_SPATIO_TEMPORAL_MODELS[:]:
for model_class in sample(VARIOUS_SPATIO_TEMPORAL_MODELS, 3):
self.common_test(model_class)
def test_altitudinal_models(self):
for model_class in ALTITUDINAL_MODELS:
for model_class in sample(ALTITUDINAL_MODELS, 3):
self.common_test(model_class)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment