From 294a5b1c0a4f606d2b57f9c877b09105539fc3b6 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 25 Jun 2020 11:35:28 +0200
Subject: [PATCH] [contrasting] add spatio temporal model with cross terms

---
 .../gev/main_fevd_mle_two_covariates.R        |  3 +-
 .../margin_function/linear_margin_function.py | 16 ++++-
 .../parametric_margin_function.py             | 11 ++-
 .../function/param_function/linear_coef.py    | 12 +++-
 .../function/param_function/param_function.py |  8 ++-
 .../param_function/polynomial_coef.py         |  5 +-
 .../polynomial_margin_model.py                | 19 +++--
 .../spatio_temporal_polynomial_model.py       | 70 ++++++++++++++++++-
 .../polynomial_margin_model/utils.py          | 19 +++++
 .../model/result_from_model_fit/utils.py      | 10 +--
 .../test_gev_spatio_temporal_extremes_mle.py  | 16 +++--
 11 files changed, 162 insertions(+), 27 deletions(-)

diff --git a/extreme_fit/distribution/gev/main_fevd_mle_two_covariates.R b/extreme_fit/distribution/gev/main_fevd_mle_two_covariates.R
index 92317a48..fb450f26 100644
--- a/extreme_fit/distribution/gev/main_fevd_mle_two_covariates.R
+++ b/extreme_fit/distribution/gev/main_fevd_mle_two_covariates.R
@@ -27,7 +27,8 @@ coord = data.frame(coord, stringsAsFactors = TRUE)
 # res = fevd_fixed(x_gev, data=coord, method='MLE', verbose=TRUE, use.phi=FALSE)
 # res = fevd_fixed(x_gev, data=coord, location.fun= ~T, scale.fun= ~T, method='MLE', type="GEV", verbose=FALSE, use.phi=FALSE)
 # res = fevd_fixed(x_gev, data=coord, location.fun= ~sin(X) + cos(T), method='MLE', type="GEV", verbose=FALSE, use.phi=FALSE)
-res = fevd_fixed(x_gev, data=coord, location.fun= ~poly(X, 1, raw = TRUE) + poly(T, 2, raw = TRUE) , method='MLE', type="GEV", verbose=FALSE, use.phi=FALSE)
+res = fevd_fixed(x_gev, data=coord, location.fun= ~poly(X * T, 1, raw = TRUE),  method='MLE', type="GEV", verbose=FALSE, use.phi=FALSE)
+# res = fevd_fixed(x_gev, data=coord, location.fun= ~poly(X, 1, raw = TRUE) + poly(T, 2, raw = TRUE) , method='MLE', type="GEV", verbose=FALSE, use.phi=FALSE)
 print(res)
 
 # Some display for the results
diff --git a/extreme_fit/function/margin_function/linear_margin_function.py b/extreme_fit/function/margin_function/linear_margin_function.py
index d2882b1b..b989c1c6 100644
--- a/extreme_fit/function/margin_function/linear_margin_function.py
+++ b/extreme_fit/function/margin_function/linear_margin_function.py
@@ -8,6 +8,7 @@ from extreme_fit.function.param_function.linear_coef import LinearCoef
 from extreme_fit.function.param_function.param_function import AbstractParamFunction, \
     LinearParamFunction
 from extreme_fit.distribution.gev.gev_params import GevParams
+from extreme_fit.function.param_function.polynomial_coef import PolynomialAllCoef
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
@@ -78,9 +79,21 @@ class LinearMarginFunction(ParametricMarginFunction):
                                  if self.coordinate_name_to_dim[name] in linear_dims]
                 spatial_dims = [self.coordinate_name_to_dim[name] for name in spatial_names]
                 spatial_form = self.param_name_to_coef[param_name].spatial_form_dict(spatial_names, spatial_dims)
+                # Load cross term combining several coordinates (necessarily including spatial coordinates)
+                tuple_dims = [e for e in linear_dims if isinstance(e, tuple)]
+                if len(tuple_dims) > 0:
+                    key, value = spatial_form.popitem()
+                    for tuple_dim in tuple_dims:
+                        coef = self.param_name_to_coef[param_name]
+                        assert isinstance(coef, PolynomialAllCoef)
+                        name = ' * '.join([self.coordinates.dim_to_coordinate[dim] for dim in tuple_dim])
+                        form = coef.form_dict([name], [tuple_dim])
+                        _, additional_value = form.popitem()
+                        additional_value = additional_value.split('~')[-1]
+                        value += ' + ' + additional_value
+                    spatial_form[key] = value
                 form_dict.update(spatial_form)
             # Load temporal form dict (only if we have some temporal coordinates)
-
             if self.coordinates.has_temporal_coordinates:
                 temporal_names = [name for name in self.coordinates.temporal_coordinates_names
                                   if self.coordinate_name_to_dim[name] in linear_dims]
@@ -89,6 +102,7 @@ class LinearMarginFunction(ParametricMarginFunction):
                 # Specifying a formula '~ 1' creates a bug in fitspatgev of SpatialExtreme R package
                 assert not any(['~ 1' in formula for formula in temporal_form.values()])
                 form_dict.update(temporal_form)
+
         return form_dict
 
     # Properties for the location parameter
diff --git a/extreme_fit/function/margin_function/parametric_margin_function.py b/extreme_fit/function/margin_function/parametric_margin_function.py
index 17e5dcfd..d08609d8 100644
--- a/extreme_fit/function/margin_function/parametric_margin_function.py
+++ b/extreme_fit/function/margin_function/parametric_margin_function.py
@@ -41,8 +41,15 @@ class ParametricMarginFunction(IndependentMarginFunction):
         # Check the dimension are well-defined with respect to the coordinates
         for dims in self.param_name_to_dims.values():
             for dim in dims:
-                assert 0 <= dim < coordinates.nb_coordinates, \
-                    "dim={}, nb_columns={}".format(dim, coordinates.nb_coordinates)
+                if isinstance(dim, int):
+                    assert 0 <= dim < coordinates.nb_coordinates, \
+                        "dim={}, nb_columns={}".format(dim, coordinates.nb_coordinates)
+                elif isinstance(dim, tuple):
+                    for d in dim:
+                        assert 0 <= d < coordinates.nb_coordinates, \
+                            "dim={}, nb_columns={}".format(d, coordinates.nb_coordinates)
+                else:
+                    raise TypeError(type(dim))
 
         self.param_name_to_coef = param_name_to_coef  # type: Dict[str, AbstractCoef]
 
diff --git a/extreme_fit/function/param_function/linear_coef.py b/extreme_fit/function/param_function/linear_coef.py
index f139dfe4..0d5e9f22 100644
--- a/extreme_fit/function/param_function/linear_coef.py
+++ b/extreme_fit/function/param_function/linear_coef.py
@@ -29,7 +29,8 @@ class LinearCoef(AbstractCoef):
         :param coefficient_name:
         :return:
         """
-        assert coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATES_NAMES
+        assert coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATES_NAMES \
+               or any([coordinate_name in coefficient_name for coordinate_name in AbstractCoordinates.COORDINATES_NAMES])
         if coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES:
             coef_template_str = param_name + cls.COEFF_STR + '{}'
         else:
@@ -37,6 +38,15 @@ class LinearCoef(AbstractCoef):
         assert cls.COEFF_STR in coef_template_str
         return coef_template_str
 
+    @classmethod
+    def coefficient_name(cls, dim, dim_to_coordinate_name):
+        if isinstance(dim, int):
+            return dim_to_coordinate_name[dim]
+        elif isinstance(dim, tuple):
+            return ' * '.join([dim_to_coordinate_name[d] for d in dim])
+        else:
+            raise NotImplementedError
+
     @staticmethod
     def has_dependence_in_spatial_coordinates(dim_to_coefficient_name):
         return any([coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES
diff --git a/extreme_fit/function/param_function/param_function.py b/extreme_fit/function/param_function/param_function.py
index f59c925d..e6c65b06 100644
--- a/extreme_fit/function/param_function/param_function.py
+++ b/extreme_fit/function/param_function/param_function.py
@@ -1,3 +1,5 @@
+import operator
+from functools import reduce
 from typing import List
 import numpy as np
 from extreme_fit.function.param_function.linear_coef import LinearCoef
@@ -60,13 +62,17 @@ class PolynomialParamFunction(AbstractParamFunction):
     def get_param_value(self, coordinate: np.ndarray) -> float:
         gev_param_value = 0
         for i, (dim, max_degree) in enumerate(self.dim_and_degree):
+            if isinstance(dim, int):
+                coordinate_value_for_dim = coordinate[dim]
+            else:
+                coordinate_value_for_dim = reduce(operator.mul, [coordinate[d] for d in dim])
             # Add intercept only once
             add_intercept = i == 0
             first_degree = 0 if add_intercept else 1
             for degree in range(first_degree, max_degree+1):
                 polynomial_coef = self.coef.dim_to_polynomial_coef[dim]  # type: PolynomialCoef
                 polynomial_coef_value = polynomial_coef.idx_to_coef[degree]
-                gev_param_value += polynomial_coef_value * np.power(coordinate[dim], degree)
+                gev_param_value += polynomial_coef_value * np.power(coordinate_value_for_dim, degree)
         return gev_param_value
 
 
diff --git a/extreme_fit/function/param_function/polynomial_coef.py b/extreme_fit/function/param_function/polynomial_coef.py
index 5a1da889..9460556d 100644
--- a/extreme_fit/function/param_function/polynomial_coef.py
+++ b/extreme_fit/function/param_function/polynomial_coef.py
@@ -59,7 +59,7 @@ class PolynomialAllCoef(LinearCoef):
             intercept = None
             dim_to_polynomial_coef = {}
             for dim, max_degree in list_dim_and_max_degree:
-                coefficient_name = coordinates.coordinates_names[dim]
+                coefficient_name = cls.coefficient_name(dim, coordinates.dim_to_coordinate)
                 j = 1 if coefficient_name == AbstractCoordinates.COORDINATE_T else 2
                 degree_to_coef = {0: degree0}
                 for degree in range(1, max_degree + 1):
@@ -70,9 +70,6 @@ class PolynomialAllCoef(LinearCoef):
         return cls(param_name=param_name, dim_to_polynomial_coef=dim_to_polynomial_coef, intercept=intercept)
 
     def form_dict(self, coordinates_names: List[str], dims) -> Dict[str, str]:
-        if len(coordinates_names) >= 2:
-            raise NotImplementedError(
-                'Check how do we sum two polynomails without having two times an intercept parameter')
         formula_list = []
         if len(coordinates_names) == 0:
             formula_str = '1'
diff --git a/extreme_fit/model/margin_model/polynomial_margin_model/polynomial_margin_model.py b/extreme_fit/model/margin_model/polynomial_margin_model/polynomial_margin_model.py
index 838043c0..4fde1188 100644
--- a/extreme_fit/model/margin_model/polynomial_margin_model/polynomial_margin_model.py
+++ b/extreme_fit/model/margin_model/polynomial_margin_model/polynomial_margin_model.py
@@ -1,3 +1,5 @@
+import itertools
+
 from cached_property import cached_property
 
 from extreme_fit.distribution.gev.gev_params import GevParams
@@ -29,6 +31,15 @@ class PolynomialMarginModel(AbstractTemporalLinearMarginModel):
         return super().margin_function
 
     def load_margin_function(self, param_name_to_list_dim_and_degree=None):
+        # Assert the order of list of dim and degree, to match the order of the form dict,
+        # i.e. 1) spatial individual terms 2) combined terms 3) temporal individual terms
+        for param_name, list_dim_and_degree in param_name_to_list_dim_and_degree.items():
+            dims = [d for d, m in list_dim_and_degree]
+            if self.coordinates.has_spatial_coordinates and self.coordinates.idx_x_coordinates in dims:
+                assert dims.index(self.coordinates.idx_x_coordinates) == 0
+            if self.coordinates.has_temporal_coordinates and self.coordinates.idx_temporal_coordinates in dims:
+                assert dims.index(self.coordinates.idx_temporal_coordinates) == len(dims) - 1
+        # Load param_name_to_polynomial_all_coef
         param_name_to_polynomial_all_coef = self.param_name_to_polynomial_all_coef(
             param_name_to_list_dim_and_degree=param_name_to_list_dim_and_degree,
             param_name_and_dim_and_degree_to_default_coef=self.default_params)
@@ -43,7 +54,10 @@ class PolynomialMarginModel(AbstractTemporalLinearMarginModel):
         default_slope = 0.01
         param_name_and_dim_and_degree_to_coef = {}
         for param_name in self.params_class.PARAM_NAMES:
-            for dim in self.coordinates.coordinates_dims:
+            all_individual_dims = self.coordinates.coordinates_dims
+            combinations_of_two_dims = list(itertools.combinations(all_individual_dims, 2))
+            dims = all_individual_dims + combinations_of_two_dims
+            for dim in dims:
                 for degree in range(self.max_degree + 1):
                     param_name_and_dim_and_degree_to_coef[(param_name, dim, degree)] = default_slope
         return param_name_and_dim_and_degree_to_coef
@@ -59,9 +73,6 @@ class PolynomialMarginModel(AbstractTemporalLinearMarginModel):
                 for (param_name_loop, dim_loop, degree), coef in param_name_and_dim_and_degree_to_default_coef.items():
                     if param_name == param_name_loop and dim == dim_loop and degree <= max_degree:
                         degree_to_coef[degree] = coef
-                # print(degree_to_coef, param_name)
-                # if len(degree_to_coef) == 0:
-                #     degree_to_coef = {0: param_name_and_dim_and_degree_to_default_coef[(param_name, dim, 0)]}
                 polynomial_coef = PolynomialCoef(param_name, degree_to_coef=degree_to_coef)
                 dim_to_polynomial_coef[dim] = polynomial_coef
             if len(dim_to_polynomial_coef) == 0:
diff --git a/extreme_fit/model/margin_model/polynomial_margin_model/spatio_temporal_polynomial_model.py b/extreme_fit/model/margin_model/polynomial_margin_model/spatio_temporal_polynomial_model.py
index b862e2ab..fefeb13b 100644
--- a/extreme_fit/model/margin_model/polynomial_margin_model/spatio_temporal_polynomial_model.py
+++ b/extreme_fit/model/margin_model/polynomial_margin_model/spatio_temporal_polynomial_model.py
@@ -14,12 +14,12 @@ class AbstractSpatioTemporalPolynomialModel(PolynomialMarginModel):
         self.drop_duplicates = False
 
 
-class NonStationaryLocationSpatioTemporalLinearityModel(AbstractSpatioTemporalPolynomialModel):
+class NonStationaryLocationSpatioTemporalLinearityModel1(AbstractSpatioTemporalPolynomialModel):
 
     def load_margin_function(self, param_name_to_dims=None):
         return super().load_margin_function({GevParams.LOC: [
-            (self.coordinates.idx_temporal_coordinates, 1),
             (self.coordinates.idx_x_coordinates, 1),
+            (self.coordinates.idx_temporal_coordinates, 1),
         ]})
 
 
@@ -27,6 +27,72 @@ class NonStationaryLocationSpatioTemporalLinearityModel2(AbstractSpatioTemporalP
 
     def load_margin_function(self, param_name_to_dims=None):
         return super().load_margin_function({GevParams.LOC: [
+            (self.coordinates.idx_x_coordinates, 1),
             (self.coordinates.idx_temporal_coordinates, 2),
+        ]})
+
+
+class NonStationaryLocationSpatioTemporalLinearityModel3(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1),
+        ]})
+
+
+class NonStationaryLocationSpatioTemporalLinearityModel4(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            (self.coordinates.idx_x_coordinates, 1),
+            ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1),
+        ]})
+
+
+class NonStationaryLocationSpatioTemporalLinearityModel5(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1),
+            (self.coordinates.idx_temporal_coordinates, 1),
+        ]})
+
+
+class NonStationaryLocationSpatioTemporalLinearityModel6(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            (self.coordinates.idx_x_coordinates, 1),
+            ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1),
+            (self.coordinates.idx_temporal_coordinates, 1),
+        ]})
+
+
+# Models that are supposed to raise errors
+
+class NonStationaryLocationSpatioTemporalLinearityModelAssertError1(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            (self.coordinates.idx_temporal_coordinates, 1),
             (self.coordinates.idx_x_coordinates, 1),
         ]})
+
+
+class NonStationaryLocationSpatioTemporalLinearityModelAssertError2(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1),
+            (self.coordinates.idx_x_coordinates, 1),
+        ]})
+
+
+class NonStationaryLocationSpatioTemporalLinearityModelAssertError3(AbstractSpatioTemporalPolynomialModel):
+
+    def load_margin_function(self, param_name_to_dims=None):
+        return super().load_margin_function({GevParams.LOC: [
+            (self.coordinates.idx_temporal_coordinates, 1),
+            ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1),
+
+        ]})
diff --git a/extreme_fit/model/margin_model/polynomial_margin_model/utils.py b/extreme_fit/model/margin_model/polynomial_margin_model/utils.py
index acd15190..db52b4fb 100644
--- a/extreme_fit/model/margin_model/polynomial_margin_model/utils.py
+++ b/extreme_fit/model/margin_model/polynomial_margin_model/utils.py
@@ -1,6 +1,12 @@
 from extreme_fit.model.margin_model.polynomial_margin_model.altitudinal_models import StationaryAltitudinal, \
     NonStationaryAltitudinalLocationLinear, NonStationaryAltitudinalLocationQuadratic, \
     NonStationaryAltitudinalLocationLinearScaleLinear, NonStationaryAltitudinalLocationQuadraticScaleLinear
+from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
+    NonStationaryLocationSpatioTemporalLinearityModel1, NonStationaryLocationSpatioTemporalLinearityModel2, \
+    NonStationaryLocationSpatioTemporalLinearityModel3, NonStationaryLocationSpatioTemporalLinearityModel4, \
+    NonStationaryLocationSpatioTemporalLinearityModel5, NonStationaryLocationSpatioTemporalLinearityModelAssertError1, \
+    NonStationaryLocationSpatioTemporalLinearityModelAssertError2, \
+    NonStationaryLocationSpatioTemporalLinearityModelAssertError3, NonStationaryLocationSpatioTemporalLinearityModel6
 
 ALTITUDINAL_MODELS = [
     StationaryAltitudinal,
@@ -9,3 +15,16 @@ ALTITUDINAL_MODELS = [
     NonStationaryAltitudinalLocationLinearScaleLinear,
     NonStationaryAltitudinalLocationQuadraticScaleLinear
 ][:]
+
+
+MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR = [NonStationaryLocationSpatioTemporalLinearityModelAssertError1,
+                       NonStationaryLocationSpatioTemporalLinearityModelAssertError2,
+                       NonStationaryLocationSpatioTemporalLinearityModelAssertError3]
+
+VARIOUS_SPATIO_TEMPORAL_MODELS = [NonStationaryLocationSpatioTemporalLinearityModel1,
+                    NonStationaryLocationSpatioTemporalLinearityModel2,
+                    NonStationaryLocationSpatioTemporalLinearityModel3,
+                    NonStationaryLocationSpatioTemporalLinearityModel4,
+                    NonStationaryLocationSpatioTemporalLinearityModel5,
+                    NonStationaryLocationSpatioTemporalLinearityModel6,
+                    ]
\ No newline at end of file
diff --git a/extreme_fit/model/result_from_model_fit/utils.py b/extreme_fit/model/result_from_model_fit/utils.py
index 35c0dada..fd1d9eb2 100644
--- a/extreme_fit/model/result_from_model_fit/utils.py
+++ b/extreme_fit/model/result_from_model_fit/utils.py
@@ -35,12 +35,12 @@ def get_margin_coef_ordered_dict(param_name_to_dims, mle_values, type_for_mle="G
             else:
                 # We found (thanks to the test) that time was the first parameter when len(param_name_to_dims) == 1
                 # otherwise time is the second parameter in the order of the mle parameters
-                inverted_dims = dims[::-1] if len(param_name_to_dims) == 1 else dims
-                for dim, max_degree in inverted_dims:
-                    coordinate_name = dim_to_coordinate_name[dim]
-                    coef_template = LinearCoef.coef_template_str(param_name, coordinate_name)
+                # inverted_dims = dims[::-1] if len(param_name_to_dims) == 1 else dims
+                for dim, max_degree in dims:
+                    coefficient_name = LinearCoef.coefficient_name(dim, dim_to_coordinate_name)
+                    coef_template = LinearCoef.coef_template_str(param_name, coefficient_name)
                     for j in range(1, max_degree + 1):
-                        k = j if coordinate_name == AbstractCoordinates.COORDINATE_T else j + 1
+                        k = j if coefficient_name == AbstractCoordinates.COORDINATE_T else j + 1
                         coef_name = coef_template.format(k)
                         coef_dict[coef_name] = mle_values[i]
                         i += 1
diff --git a/test/test_extreme_fit/test_estimator/test_gev_spatio_temporal_extremes_mle.py b/test/test_extreme_fit/test_estimator/test_gev_spatio_temporal_extremes_mle.py
index 9af269f3..68eec4e2 100644
--- a/test/test_extreme_fit/test_estimator/test_gev_spatio_temporal_extremes_mle.py
+++ b/test/test_extreme_fit/test_estimator/test_gev_spatio_temporal_extremes_mle.py
@@ -1,13 +1,13 @@
 import unittest
 
 from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
-from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
-    NonStationaryLocationSpatioTemporalLinearityModel, NonStationaryLocationSpatioTemporalLinearityModel2
-from extreme_fit.model.margin_model.polynomial_margin_model.utils import ALTITUDINAL_MODELS
+from extreme_fit.model.margin_model.polynomial_margin_model.utils import ALTITUDINAL_MODELS, \
+    MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR, VARIOUS_SPATIO_TEMPORAL_MODELS
 from extreme_fit.model.margin_model.utils import \
     MarginFitMethod
 from projects.altitude_spatial_model.altitudes_fit.altitudes_studies import AltitudesStudies
-from projects.altitude_spatial_model.altitudes_fit.two_fold_analysis.two_fold_datasets_generator import TwoFoldDatasetsGenerator
+from projects.altitude_spatial_model.altitudes_fit.two_fold_analysis.two_fold_datasets_generator import \
+    TwoFoldDatasetsGenerator
 from projects.altitude_spatial_model.altitudes_fit.two_fold_analysis.two_fold_fit import TwoFoldFit
 
 
@@ -35,9 +35,13 @@ class TestGevTemporalQuadraticExtremesMle(unittest.TestCase):
         self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic(split=estimator.train_split))
         self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic(split=estimator.train_split))
 
+    def test_assert_error(self):
+        for model_class in MODELS_THAT_SHOULD_RAISE_AN_ASSERTION_ERROR:
+            with self.assertRaises(AssertionError):
+                self.common_test(model_class)
+
     def test_location_spatio_temporal_models(self):
-        for model_class in [NonStationaryLocationSpatioTemporalLinearityModel,
-                            NonStationaryLocationSpatioTemporalLinearityModel2]:
+        for model_class in VARIOUS_SPATIO_TEMPORAL_MODELS[:]:
             self.common_test(model_class)
 
     def test_altitudinal_models(self):
-- 
GitLab