Commit 8cc64c5b authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] add nan_if_undefined_wrapper. update tests.

parent a968f851
No related merge requests found
Showing with 38 additions and 28 deletions
+38 -28
......@@ -7,6 +7,7 @@ from mpmath import euler
from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams
from extreme_fit.distribution.abstract_params import AbstractParams
from extreme_fit.distribution.utils_extreme_params import nan_if_undefined_wrapper
from extreme_fit.model.utils import r
import numpy as np
from scipy.special import gamma
......@@ -26,37 +27,24 @@ class GevParams(AbstractExtremeParams):
self.has_undefined_parameters = False
self.param_name_to_confidence_interval = None
@nan_if_undefined_wrapper
def sample(self, n) -> np.ndarray:
if self.has_undefined_parameters:
return np.nan
else:
return np.array(r.rgev(n, self.location, self.scale, self.shape))
return np.array(r.rgev(n, self.location, self.scale, self.shape))
@nan_if_undefined_wrapper
def quantile(self, p) -> float:
if self.has_undefined_parameters:
return np.nan
else:
return r.qgev(p, self.location, self.scale, self.shape)[0]
return r.qgev(p, self.location, self.scale, self.shape)[0]
def return_level(self, return_period):
return self.quantile(1 - 1 / return_period)
@nan_if_undefined_wrapper
def density(self, x, log_scale=False):
if self.has_undefined_parameters:
return np.nan
else:
res = r.dgev(x, self.location, self.scale, self.shape, log_scale)
if isinstance(x, float):
return res[0]
else:
return np.array(res)
@property
def param_values(self):
if self.has_undefined_parameters:
return [np.nan for _ in range(3)]
res = r.dgev(x, self.location, self.scale, self.shape, log_scale)
if isinstance(x, float):
return res[0]
else:
return [self.location, self.scale, self.shape]
return np.array(res)
def time_derivative_of_return_level(self, p=0.99, mu1=0.0, sigma1=0.0):
"""
......@@ -77,6 +65,11 @@ class GevParams(AbstractExtremeParams):
quantile_annual_variation -= (sigma1 / self.shape) * (1 - power)
return quantile_annual_variation
@property
@nan_if_undefined_wrapper
def param_values(self):
return [self.location, self.scale, self.shape]
# Compute some indicators (such as the mean and the variance)
def g(self, k) -> float:
......@@ -85,10 +78,9 @@ class GevParams(AbstractExtremeParams):
return gamma(1 - k * self.shape)
@property
@nan_if_undefined_wrapper
def mean(self) -> float:
if self.has_undefined_parameters:
mean = np.nan
elif self.shape >= 1:
if self.shape >= 1:
mean = np.inf
elif self.shape == 0:
mean = self.location + self.scale * float(euler)
......@@ -98,10 +90,9 @@ class GevParams(AbstractExtremeParams):
return mean
@property
@nan_if_undefined_wrapper
def variance(self) -> float:
if self.has_undefined_parameters:
return np.nan
elif self.shape >= 0.5:
if self.shape >= 0.5:
return np.inf
elif self.shape == 0.0:
return (self.scale * np.pi) ** 2 / 6
......
import numpy as np
from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams
def nan_if_undefined_wrapper(func):
def wrapper(obj: AbstractExtremeParams, *args, **kwargs):
if obj.has_undefined_parameters:
return np.nan
return func(obj, *args, **kwargs)
return wrapper
......@@ -17,6 +17,13 @@ class TestGevParams(unittest.TestCase):
for quantile_name, p in gev_params.quantile_name_to_p.items():
self.assertAlmostEqual(- 1 / np.log(p), quantile_dict[quantile_name])
def test_wrapper(self):
gev_params = GevParams(loc=1.0, shape=1.0, scale=-1.0)
self.assertTrue(np.isnan(gev_params.quantile(p=0.5)))
self.assertTrue(np.isnan(gev_params.sample(n=10)))
self.assertTrue(np.isnan(gev_params.param_values))
self.assertTrue(np.isnan(gev_params.density(x=1.5)))
def test_time_derivative_return_level(self):
p = 0.99
for mu1 in [-1, 0, 1]:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment