Commit e1995070 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[EXTREME ESTIMATOR][MARGIN MODEL] handle transformed starting point. add test.

parent 1c9927df
No related merge requests found
Showing with 116 additions and 7 deletions
+116 -7
......@@ -19,15 +19,22 @@ class IndependentMarginFunction(AbstractMarginFunction):
super().__init__(coordinates)
self.gev_param_name_to_param_function = None # type: Union[None, Dict[str, AbstractParamFunction]]
def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
def get_gev_params(self, coordinate: np.ndarray, already_transformed: bool = False) -> GevParams:
"""Each GEV parameter is computed independently through its corresponding param_function"""
assert self.gev_param_name_to_param_function is not None
assert len(self.gev_param_name_to_param_function) == 3
transformed_coordinate = self.coordinates.transform(coordinate)
if already_transformed:
transformed_coordinate = coordinate
else:
transformed_coordinate = self.transform(coordinate)
gev_params = {}
for gev_param_name in GevParams.PARAM_NAMES:
param_function = self.gev_param_name_to_param_function[gev_param_name]
gev_params[gev_param_name] = param_function.get_gev_param_value(transformed_coordinate)
return GevParams.from_dict(gev_params)
def transform(self, coordinate):
transformed_coordinate = self.coordinates.transform(coordinate)
return transformed_coordinate
......@@ -61,14 +61,15 @@ class ParametricMarginFunction(IndependentMarginFunction):
def load_specific_param_function(self, gev_param_name) -> AbstractParamFunction:
raise NotImplementedError
def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
def get_gev_params(self, coordinate: np.ndarray, already_transformed: bool = False) -> GevParams:
transformed_coordinate = self.transform(coordinate)
if self.transformed_starting_point is not None:
# Shift temporal coordinate to enable to model temporal trend with starting point
assert self.coordinates.has_temporal_coordinates
assert 0 <= self.coordinates.idx_temporal_coordinates < len(coordinate)
if coordinate[self.coordinates.idx_temporal_coordinates] < self.transformed_starting_point:
coordinate[self.coordinates.idx_temporal_coordinates] = self.transformed_starting_point
return super().get_gev_params(coordinate)
if transformed_coordinate[self.coordinates.idx_temporal_coordinates] < self.transformed_starting_point:
transformed_coordinate[self.coordinates.idx_temporal_coordinates] = self.transformed_starting_point
return super().get_gev_params(transformed_coordinate, already_transformed=True)
@classmethod
def from_coef_dict(cls, coordinates: AbstractCoordinates, gev_param_name_to_dims: Dict[str, List[int]],
......
......@@ -31,7 +31,7 @@ class LinearOneAxisParamFunction(AbstractParamFunction):
def get_gev_param_value(self, coordinate: np.ndarray) -> float:
t = coordinate[self.dim]
if self.OUT_OF_BOUNDS_ASSERT:
assert self.t_min <= t <= self.t_max, 'Out of bounds'
assert self.t_min <= t <= self.t_max, '{} is out of bounds ({}, {})'.format(t, self.t_min, self.t_max)
return t * self.coef
......
import random
import unittest
import numpy as np
import pandas as pd
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearNonStationaryLocationMarginModel, \
LinearStationaryMarginModel
from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \
NonStationaryStationModel
from extreme_estimator.extreme_models.utils import r, set_seed_r, set_seed_for_test
from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
from extreme_estimator.margin_fits.gev.ismev_gev_fit import IsmevGevFit
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.coordinates.spatio_temporal_coordinates.abstract_spatio_temporal_coordinates import \
AbstractSpatioTemporalCoordinates
from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_coordinates import \
AbstractTemporalCoordinates
from spatio_temporal_dataset.coordinates.transformed_coordinates.transformation.uniform_normalization import \
BetweenZeroAndOneNormalization
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
AbstractSpatioTemporalObservations
from test.test_utils import load_test_spatiotemporal_coordinates, load_smooth_margin_models
class TestMarginTemporal(unittest.TestCase):
def setUp(self) -> None:
set_seed_for_test(seed=42)
self.nb_points = 2
self.nb_steps = 50
self.nb_obs = 1
# Load some 2D spatial coordinates
self.coordinates = load_test_spatiotemporal_coordinates(nb_steps=self.nb_steps, nb_points=self.nb_points,
transformation_class=BetweenZeroAndOneNormalization)[1] # type: AbstractSpatioTemporalCoordinates
self.smooth_margin_model = LinearNonStationaryLocationMarginModel(coordinates=self.coordinates,
starting_point=2)
self.dataset = MarginDataset.from_sampling(nb_obs=self.nb_obs,
margin_model=self.smooth_margin_model,
coordinates=self.coordinates)
def test_margin_fit_stationary(self):
# Create estimator
margin_model = LinearStationaryMarginModel(self.coordinates)
estimator = LinearMarginEstimator(self.dataset, margin_model)
estimator.fit()
ref = {'loc': 1.1650543404552496, 'scale': 1.1097775613768615, 'shape': 0.6737277802240037}
for year in range(1, 3):
coordinate = np.array([0.0, 0.0, year])
mle_params_estimated = estimator.margin_function_fitted.get_gev_params(coordinate).to_dict()
for key in ref.keys():
self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
def test_margin_fit_nonstationary(self):
# Create estimator
margin_model = LinearNonStationaryLocationMarginModel(self.coordinates)
estimator = LinearMarginEstimator(self.dataset, margin_model)
estimator.fit()
self.assertNotEqual(estimator.margin_function_fitted.mu1_temporal_trend, 0.0)
# Checks that parameters returned are indeed different
coordinate1 = np.array([0.0, 0.0, 1])
mle_params_estimated_year1 = estimator.margin_function_fitted.get_gev_params(coordinate1).to_dict()
coordinate3 = np.array([0.0, 0.0, 3])
mle_params_estimated_year3 = estimator.margin_function_fitted.get_gev_params(coordinate3).to_dict()
self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
def test_margin_fit_nonstationary_with_start_point(self):
# Create estimator
estimator = self.fit_non_stationary_estimator(starting_point=2)
# By default, estimator find the good margin
self.assertNotEqual(estimator.margin_function_fitted.mu1_temporal_trend, 0.0)
# Checks that parameters returned are indeed different
coordinate1 = np.array([0.0, 0.0, 1])
mle_params_estimated_year1 = estimator.margin_function_fitted.get_gev_params(coordinate1).to_dict()
coordinate2 = np.array([0.0, 0.0, 2])
mle_params_estimated_year2 = estimator.margin_function_fitted.get_gev_params(coordinate2).to_dict()
self.assertEqual(mle_params_estimated_year1, mle_params_estimated_year2)
coordinate5 = np.array([0.0, 0.0, 5])
mle_params_estimated_year5 = estimator.margin_function_fitted.get_gev_params(coordinate5).to_dict()
self.assertNotEqual(mle_params_estimated_year5, mle_params_estimated_year2)
def fit_non_stationary_estimator(self, starting_point):
margin_model = LinearNonStationaryLocationMarginModel(self.coordinates, starting_point=starting_point)
estimator = LinearMarginEstimator(self.dataset, margin_model)
estimator.fit()
return estimator
def test_two_different_starting_points(self):
# Create two different estimators
estimator1 = self.fit_non_stationary_estimator(starting_point=3)
estimator2 = self.fit_non_stationary_estimator(starting_point=20)
mu1_estimator1 = estimator1.margin_function_fitted.mu1_temporal_trend
mu1_estimator2 = estimator2.margin_function_fitted.mu1_temporal_trend
self.assertNotEqual(mu1_estimator1, mu1_estimator2)
if __name__ == '__main__':
unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment