From 8cc64c5b82f36241fbac9f29cb7a11bdaaf1be10 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Thu, 25 Feb 2021 10:43:43 +0100 Subject: [PATCH] [refactor] add nan_if_undefined_wrapper. update tests. --- extreme_fit/distribution/gev/gev_params.py | 47 ++++++++----------- .../distribution/utils_extreme_params.py | 12 +++++ .../__init__.py | 0 .../test_gev/test_gev_params.py | 7 +++ 4 files changed, 38 insertions(+), 28 deletions(-) create mode 100644 extreme_fit/distribution/utils_extreme_params.py delete mode 100644 projects/contrasting_trends_in_snow_loads/__init__.py diff --git a/extreme_fit/distribution/gev/gev_params.py b/extreme_fit/distribution/gev/gev_params.py index a233f53c..e83e787f 100644 --- a/extreme_fit/distribution/gev/gev_params.py +++ b/extreme_fit/distribution/gev/gev_params.py @@ -7,6 +7,7 @@ from mpmath import euler from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams from extreme_fit.distribution.abstract_params import AbstractParams +from extreme_fit.distribution.utils_extreme_params import nan_if_undefined_wrapper from extreme_fit.model.utils import r import numpy as np from scipy.special import gamma @@ -26,37 +27,24 @@ class GevParams(AbstractExtremeParams): self.has_undefined_parameters = False self.param_name_to_confidence_interval = None + @nan_if_undefined_wrapper def sample(self, n) -> np.ndarray: - if self.has_undefined_parameters: - return np.nan - else: - return np.array(r.rgev(n, self.location, self.scale, self.shape)) + return np.array(r.rgev(n, self.location, self.scale, self.shape)) + @nan_if_undefined_wrapper def quantile(self, p) -> float: - if self.has_undefined_parameters: - return np.nan - else: - return r.qgev(p, self.location, self.scale, self.shape)[0] + return r.qgev(p, self.location, self.scale, self.shape)[0] def return_level(self, return_period): return self.quantile(1 - 1 / return_period) + @nan_if_undefined_wrapper def density(self, x, log_scale=False): - if self.has_undefined_parameters: - return np.nan - else: - res = r.dgev(x, self.location, self.scale, self.shape, log_scale) - if isinstance(x, float): - return res[0] - else: - return np.array(res) - - @property - def param_values(self): - if self.has_undefined_parameters: - return [np.nan for _ in range(3)] + res = r.dgev(x, self.location, self.scale, self.shape, log_scale) + if isinstance(x, float): + return res[0] else: - return [self.location, self.scale, self.shape] + return np.array(res) def time_derivative_of_return_level(self, p=0.99, mu1=0.0, sigma1=0.0): """ @@ -77,6 +65,11 @@ class GevParams(AbstractExtremeParams): quantile_annual_variation -= (sigma1 / self.shape) * (1 - power) return quantile_annual_variation + @property + @nan_if_undefined_wrapper + def param_values(self): + return [self.location, self.scale, self.shape] + # Compute some indicators (such as the mean and the variance) def g(self, k) -> float: @@ -85,10 +78,9 @@ class GevParams(AbstractExtremeParams): return gamma(1 - k * self.shape) @property + @nan_if_undefined_wrapper def mean(self) -> float: - if self.has_undefined_parameters: - mean = np.nan - elif self.shape >= 1: + if self.shape >= 1: mean = np.inf elif self.shape == 0: mean = self.location + self.scale * float(euler) @@ -98,10 +90,9 @@ class GevParams(AbstractExtremeParams): return mean @property + @nan_if_undefined_wrapper def variance(self) -> float: - if self.has_undefined_parameters: - return np.nan - elif self.shape >= 0.5: + if self.shape >= 0.5: return np.inf elif self.shape == 0.0: return (self.scale * np.pi) ** 2 / 6 diff --git a/extreme_fit/distribution/utils_extreme_params.py b/extreme_fit/distribution/utils_extreme_params.py new file mode 100644 index 00000000..ce61e72e --- /dev/null +++ b/extreme_fit/distribution/utils_extreme_params.py @@ -0,0 +1,12 @@ +import numpy as np + +from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams + + +def nan_if_undefined_wrapper(func): + def wrapper(obj: AbstractExtremeParams, *args, **kwargs): + if obj.has_undefined_parameters: + return np.nan + return func(obj, *args, **kwargs) + + return wrapper diff --git a/projects/contrasting_trends_in_snow_loads/__init__.py b/projects/contrasting_trends_in_snow_loads/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py b/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py index 52a515f7..607e47a1 100644 --- a/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py +++ b/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py @@ -17,6 +17,13 @@ class TestGevParams(unittest.TestCase): for quantile_name, p in gev_params.quantile_name_to_p.items(): self.assertAlmostEqual(- 1 / np.log(p), quantile_dict[quantile_name]) + def test_wrapper(self): + gev_params = GevParams(loc=1.0, shape=1.0, scale=-1.0) + self.assertTrue(np.isnan(gev_params.quantile(p=0.5))) + self.assertTrue(np.isnan(gev_params.sample(n=10))) + self.assertTrue(np.isnan(gev_params.param_values)) + self.assertTrue(np.isnan(gev_params.density(x=1.5))) + def test_time_derivative_return_level(self): p = 0.99 for mu1 in [-1, 0, 1]: -- GitLab