Commit 14f3cd3f authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[contrasting] refactor tests

parent 49deb42f
No related merge requests found
Showing with 20 additions and 22 deletions
+20 -22
from extreme_fit.model.margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \ from extreme_fit.model.margin_model.polynomial_margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \
NonStationaryQuadraticScaleModel NonStationaryQuadraticScaleModel
from extreme_fit.model.margin_model.utils import \ from extreme_fit.model.margin_model.utils import \
MarginFitMethod MarginFitMethod
......
from extreme_data.eurocode_data.utils import EUROCODE_QUANTILE from extreme_data.eurocode_data.utils import EUROCODE_QUANTILE
from extreme_fit.model.margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \ from extreme_fit.model.margin_model.polynomial_margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \
NonStationaryQuadraticScaleModel NonStationaryQuadraticScaleModel
from extreme_trend.abstract_gev_trend_test import AbstractGevTrendTest from extreme_trend.abstract_gev_trend_test import AbstractGevTrendTest
from extreme_trend.trend_test_one_parameter.gev_trend_test_one_parameter import \ from extreme_trend.trend_test_one_parameter.gev_trend_test_one_parameter import \
......
from extreme_data.eurocode_data.utils import EUROCODE_QUANTILE from extreme_data.eurocode_data.utils import EUROCODE_QUANTILE
from extreme_fit.model.margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \ from extreme_fit.model.margin_model.polynomial_margin_model.polynomial_margin_model import \
NonStationaryQuadraticLocationGumbelModel, NonStationaryQuadraticScaleGumbelModel NonStationaryQuadraticLocationGumbelModel, NonStationaryQuadraticScaleGumbelModel
from extreme_trend.trend_test_two_parameters.gev_trend_test_two_parameters import \ from extreme_trend.trend_test_two_parameters.gev_trend_test_two_parameters import \
GevTrendTestTwoParameters GevTrendTestTwoParameters
......
...@@ -13,8 +13,7 @@ class TwoFoldFit(object): ...@@ -13,8 +13,7 @@ class TwoFoldFit(object):
def __init__(self, two_fold_datasets_generator: TwoFoldDatasetsGenerator, def __init__(self, two_fold_datasets_generator: TwoFoldDatasetsGenerator,
model_family_name_to_model_classes: Dict[str, List[type]], model_family_name_to_model_classes: Dict[str, List[type]],
fit_method=MarginFitMethod.extremes_fevd_mle, fit_method=MarginFitMethod.extremes_fevd_mle):
):
self.two_fold_datasets_generator = two_fold_datasets_generator self.two_fold_datasets_generator = two_fold_datasets_generator
self.fit_method = fit_method self.fit_method = fit_method
self.model_family_name_to_model_classes = model_family_name_to_model_classes self.model_family_name_to_model_classes = model_family_name_to_model_classes
......
...@@ -4,13 +4,11 @@ import numpy as np ...@@ -4,13 +4,11 @@ import numpy as np
import pandas as pd import pandas as pd
from extreme_fit.distribution.gev.gev_params import GevParams from extreme_fit.distribution.gev.gev_params import GevParams
from extreme_fit.model.margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \ from extreme_fit.model.margin_model.polynomial_margin_model.polynomial_margin_model import NonStationaryQuadraticLocationModel, \
NonStationaryQuadraticScaleModel, NonStationaryQuadraticLocationGumbelModel, NonStationaryQuadraticScaleGumbelModel NonStationaryQuadraticScaleModel, NonStationaryQuadraticLocationGumbelModel, NonStationaryQuadraticScaleGumbelModel
from extreme_trend.abstract_gev_trend_test import fitted_linear_margin_estimator from extreme_trend.abstract_gev_trend_test import fitted_linear_margin_estimator
from extreme_fit.model.margin_model.utils import \ from extreme_fit.model.margin_model.utils import \
MarginFitMethod MarginFitMethod
from extreme_fit.model.margin_model.linear_margin_model.temporal_linear_margin_models import StationaryTemporalModel, \
NonStationaryLocationTemporalModel, NonStationaryLocationAndScaleTemporalModel
from extreme_fit.model.utils import r, set_seed_r from extreme_fit.model.utils import r, set_seed_r
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_coordinates import \ from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_coordinates import \
......
...@@ -14,34 +14,35 @@ from projects.altitude_spatial_model.altitudes_fit.utils import Score ...@@ -14,34 +14,35 @@ from projects.altitude_spatial_model.altitudes_fit.utils import Score
from spatio_temporal_dataset.slicer.split import Split from spatio_temporal_dataset.slicer.split import Split
def load_two_fold_fit(fit_method, year_max):
altitudes = [900, 1200]
study_class = SafranSnowfall1Day
studies = AltitudesStudies(study_class, altitudes, year_max=year_max)
two_fold_datasets_generator = TwoFoldDatasetsGenerator(studies, nb_samples=1, massif_names=['Vercors'])
model_family_name_to_model_class = {'Stationary': [ConstantMarginModel]}
return TwoFoldFit(two_fold_datasets_generator=two_fold_datasets_generator,
model_family_name_to_model_classes=model_family_name_to_model_class,
fit_method=fit_method)
class TestTwoFoldFit(unittest.TestCase): class TestTwoFoldFit(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
set_seed_for_test() set_seed_for_test()
def load_two_fold_fit(self, fit_method, year_max):
self.altitudes = [900, 1200]
self.study_class = SafranSnowfall1Day
studies = AltitudesStudies(self.study_class, self.altitudes, year_max=year_max)
self.two_fold_datasets_generator = TwoFoldDatasetsGenerator(studies, nb_samples=1, massif_names=['Vercors'])
self.model_family_name_to_model_class = {'Stationary': [ConstantMarginModel]}
return TwoFoldFit(two_fold_datasets_generator=self.two_fold_datasets_generator,
model_family_name_to_model_classes=self.model_family_name_to_model_class,
fit_method=fit_method)
def test_determinism_dataset_generation(self): def test_determinism_dataset_generation(self):
two_fold_fit = self.load_two_fold_fit(fit_method=MarginFitMethod.spatial_extremes_mle, year_max=1963) two_fold_fit = load_two_fold_fit(fit_method=MarginFitMethod.spatial_extremes_mle, year_max=1963)
massif_fit = two_fold_fit.massif_name_to_massif_fit['Vercors'] massif_fit = two_fold_fit.massif_name_to_massif_fit['Vercors']
model_fit = massif_fit.sample_id_to_sample_fit[0].model_class_to_model_fit[ sample_fit = massif_fit.sample_id_to_sample_fit[0]
ConstantMarginModel] # type: TwoFoldModelFit model_fit = sample_fit.model_class_to_model_fit[ConstantMarginModel] # type: TwoFoldModelFit
dataset_fold1 = model_fit.estimator_fold_1.dataset dataset_fold1 = model_fit.estimator_fold_1.dataset
index_train = list(dataset_fold1.coordinates.coordinates_index(split=Split.train_temporal)) index_train = list(dataset_fold1.coordinates.coordinates_index(split=Split.train_temporal))
self.assertEqual([2, 3, 8, 9], index_train) self.assertEqual([2, 3, 8, 9], index_train)
self.assertEqual(110.52073192596436, np.sum(dataset_fold1.maxima_gev(split=Split.train_temporal))) self.assertEqual(110.52073192596436, np.sum(dataset_fold1.maxima_gev(split=Split.train_temporal)))
def test_determinism_fit_spatial_extreme(self): def test_determinism_fit_spatial_extreme(self):
two_fold_fit = self.load_two_fold_fit(fit_method=MarginFitMethod.spatial_extremes_mle, year_max=2019) two_fold_fit = load_two_fold_fit(fit_method=MarginFitMethod.spatial_extremes_mle, year_max=2019)
massif_fit = two_fold_fit.massif_name_to_massif_fit['Vercors'] massif_fit = two_fold_fit.massif_name_to_massif_fit['Vercors']
model_fit = massif_fit.sample_id_to_sample_fit[0].model_class_to_model_fit[ model_fit = massif_fit.sample_id_to_sample_fit[0].model_class_to_model_fit[
ConstantMarginModel] # type: TwoFoldModelFit ConstantMarginModel] # type: TwoFoldModelFit
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment