Commit 4df8a93d authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[contrasting] fix quadratic models with cross terms

parent f5c1e693
No related merge requests found
Showing with 19 additions and 48 deletions
+19 -48
......@@ -33,8 +33,12 @@ class LinearCoef(AbstractCoef):
or any([coordinate_name in coefficient_name for coordinate_name in AbstractCoordinates.COORDINATES_NAMES])
if coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES:
coef_template_str = param_name + cls.COEFF_STR + '{}'
else:
elif coefficient_name == AbstractCoordinates.COORDINATE_T:
coef_template_str = 'temp' + cls.COEFF_STR + param_name.title() + '{}'
elif len([c for c in AbstractCoordinates.COORDINATES_NAMES if c in coefficient_name]) >= 2:
coef_template_str = 'cross' + cls.COEFF_STR + param_name.title() + '{}'
else:
raise NotImplementedError
assert cls.COEFF_STR in coef_template_str
return coef_template_str
......@@ -47,6 +51,10 @@ class LinearCoef(AbstractCoef):
else:
raise NotImplementedError
@classmethod
def offset_from_coefficient_name(cls, coefficient_name):
return 1 if coefficient_name == AbstractCoordinates.COORDINATE_X else 0
@staticmethod
def has_dependence_in_spatial_coordinates(dim_to_coefficient_name):
return any([coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES
......
......@@ -60,12 +60,12 @@ class PolynomialAllCoef(LinearCoef):
dim_to_polynomial_coef = {}
for dim, max_degree in list_dim_and_max_degree:
coefficient_name = cls.coefficient_name(dim, coordinates.dim_to_coordinate)
j = 1 if coefficient_name == AbstractCoordinates.COORDINATE_T else 2
offset = cls.offset_from_coefficient_name(coefficient_name)
degree_to_coef = {0: degree0}
for degree in range(1, max_degree + 1):
coef_value = coef_dict[cls.coef_template_str(param_name, coefficient_name).format(j)]
coef_key = cls.coef_template_str(param_name, coefficient_name).format(offset + degree)
coef_value = coef_dict[coef_key]
degree_to_coef[degree] = coef_value
j += 1
dim_to_polynomial_coef[dim] = PolynomialCoef(param_name=param_name, degree_to_coef=degree_to_coef)
return cls(param_name=param_name, dim_to_polynomial_coef=dim_to_polynomial_coef, intercept=intercept)
......
......@@ -23,7 +23,7 @@ ALTITUDINAL_MODELS = [
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation,
NonStationaryAltitudinalLocationLinearScaleLinearCrossTermForLocation,
NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation,
][:7]
][:]
......
......@@ -36,12 +36,12 @@ def get_margin_coef_ordered_dict(param_name_to_dims, mle_values, type_for_mle="G
# We found (thanks to the test) that time was the first parameter when len(param_name_to_dims) == 1
# otherwise time is the second parameter in the order of the mle parameters
# inverted_dims = dims[::-1] if len(param_name_to_dims) == 1 else dims
for dim, max_degree in dims:
for dim, max_degree in dims[:]:
coefficient_name = LinearCoef.coefficient_name(dim, dim_to_coordinate_name)
coef_template = LinearCoef.coef_template_str(param_name, coefficient_name)
for j in range(1, max_degree + 1):
k = j if coefficient_name == AbstractCoordinates.COORDINATE_T else j + 1
coef_name = coef_template.format(k)
offset = LinearCoef.offset_from_coefficient_name(coefficient_name)
for k in range(1, max_degree + 1):
coef_name = coef_template.format(k + offset)
coef_dict[coef_name] = mle_values[i]
i += 1
return coef_dict
......@@ -4,7 +4,8 @@ from random import sample
from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day, SafranPrecipitation1Day
from extreme_fit.model.margin_model.polynomial_margin_model.altitudinal_models import \
NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation, \
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation
NonStationaryAltitudinalLocationQuadraticCrossTermForLocation, NonStationaryAltitudinalLocationLinear, \
NonStationaryAltitudinalLocationLinearCrossTermForLocation
from extreme_fit.model.margin_model.polynomial_margin_model.utils import ALTITUDINAL_MODELS, \
MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR, VARIOUS_SPATIO_TEMPORAL_MODELS
from extreme_fit.model.margin_model.utils import \
......@@ -53,46 +54,8 @@ class TestGevTemporalQuadraticExtremesMle(unittest.TestCase):
def test_altitudinal_models(self):
for model_class in ALTITUDINAL_MODELS:
# print(model_class)
self.common_test(model_class)
# class MyTest(unittest.TestCase):
#
# def setUp(self) -> None:
# self.study_class = SafranPrecipitation1Day
# self.altitudes = [900, 1200, 1500, 1800, 2100, 2400, 2700, 3000]
# self.massif_name = 'Aravis'
#
# def get_estimator_fitted(self, model_class):
# studies = AltitudesStudies(self.study_class, self.altitudes, year_max=2019)
# two_fold_datasets_generator = TwoFoldDatasetsGenerator(studies, nb_samples=1, massif_names=[self.massif_name])
# model_family_name_to_model_class = {'Non stationary': [model_class]}
# two_fold_fit = TwoFoldFit(two_fold_datasets_generator=two_fold_datasets_generator,
# model_family_name_to_model_classes=model_family_name_to_model_class,
# fit_method=MarginFitMethod.extremes_fevd_mle)
# massif_fit = two_fold_fit.massif_name_to_massif_fit[self.massif_name]
# sample_fit = massif_fit.sample_id_to_sample_fit[0]
# model_fit = sample_fit.model_class_to_model_fit[model_class] # type: TwoFoldModelFit
# estimator = model_fit.estimator_fold_1
# return estimator
#
# def common_test(self, model_class):
# estimator = self.get_estimator_fitted(model_class)
# # Assert that indicators are correctly computed
# self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh(split=estimator.train_split))
# self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic(split=estimator.train_split))
# self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic(split=estimator.train_split))
#
# #
# # # def test_altitudinal_models(self):
# # # for model_class in ALTITUDINAL_MODELS:
# # # self.common_test(model_class)
# #
# def test_wrong(self):
# self.common_test(NonStationaryAltitudinalLocationQuadraticCrossTermForLocation)
# # self.common_test(NonStationaryAltitudinalLocationQuadraticScaleLinearCrossTermForLocation)
if __name__ == '__main__':
unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment