From 8cc64c5b82f36241fbac9f29cb7a11bdaaf1be10 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 25 Feb 2021 10:43:43 +0100
Subject: [PATCH] [refactor] add nan_if_undefined_wrapper. update tests.

---
 extreme_fit/distribution/gev/gev_params.py    | 47 ++++++++-----------
 .../distribution/utils_extreme_params.py      | 12 +++++
 .../__init__.py                               |  0
 .../test_gev/test_gev_params.py               |  7 +++
 4 files changed, 38 insertions(+), 28 deletions(-)
 create mode 100644 extreme_fit/distribution/utils_extreme_params.py
 delete mode 100644 projects/contrasting_trends_in_snow_loads/__init__.py

diff --git a/extreme_fit/distribution/gev/gev_params.py b/extreme_fit/distribution/gev/gev_params.py
index a233f53c..e83e787f 100644
--- a/extreme_fit/distribution/gev/gev_params.py
+++ b/extreme_fit/distribution/gev/gev_params.py
@@ -7,6 +7,7 @@ from mpmath import euler
 
 from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams
 from extreme_fit.distribution.abstract_params import AbstractParams
+from extreme_fit.distribution.utils_extreme_params import nan_if_undefined_wrapper
 from extreme_fit.model.utils import r
 import numpy as np
 from scipy.special import gamma
@@ -26,37 +27,24 @@ class GevParams(AbstractExtremeParams):
             self.has_undefined_parameters = False
         self.param_name_to_confidence_interval = None
 
+    @nan_if_undefined_wrapper
     def sample(self, n) -> np.ndarray:
-        if self.has_undefined_parameters:
-            return np.nan
-        else:
-            return np.array(r.rgev(n, self.location, self.scale, self.shape))
+        return np.array(r.rgev(n, self.location, self.scale, self.shape))
 
+    @nan_if_undefined_wrapper
     def quantile(self, p) -> float:
-        if self.has_undefined_parameters:
-            return np.nan
-        else:
-            return r.qgev(p, self.location, self.scale, self.shape)[0]
+        return r.qgev(p, self.location, self.scale, self.shape)[0]
 
     def return_level(self, return_period):
         return self.quantile(1 - 1 / return_period)
 
+    @nan_if_undefined_wrapper
     def density(self, x, log_scale=False):
-        if self.has_undefined_parameters:
-            return np.nan
-        else:
-            res = r.dgev(x, self.location, self.scale, self.shape, log_scale)
-            if isinstance(x, float):
-                return res[0]
-            else:
-                return np.array(res)
-
-    @property
-    def param_values(self):
-        if self.has_undefined_parameters:
-            return [np.nan for _ in range(3)]
+        res = r.dgev(x, self.location, self.scale, self.shape, log_scale)
+        if isinstance(x, float):
+            return res[0]
         else:
-            return [self.location, self.scale, self.shape]
+            return np.array(res)
 
     def time_derivative_of_return_level(self, p=0.99, mu1=0.0, sigma1=0.0):
         """
@@ -77,6 +65,11 @@ class GevParams(AbstractExtremeParams):
                 quantile_annual_variation -= (sigma1 / self.shape) * (1 - power)
         return quantile_annual_variation
 
+    @property
+    @nan_if_undefined_wrapper
+    def param_values(self):
+        return [self.location, self.scale, self.shape]
+
     # Compute some indicators (such as the mean and the variance)
 
     def g(self, k) -> float:
@@ -85,10 +78,9 @@ class GevParams(AbstractExtremeParams):
         return gamma(1 - k * self.shape)
 
     @property
+    @nan_if_undefined_wrapper
     def mean(self) -> float:
-        if self.has_undefined_parameters:
-            mean = np.nan
-        elif self.shape >= 1:
+        if self.shape >= 1:
             mean = np.inf
         elif self.shape == 0:
             mean = self.location + self.scale * float(euler)
@@ -98,10 +90,9 @@ class GevParams(AbstractExtremeParams):
         return mean
 
     @property
+    @nan_if_undefined_wrapper
     def variance(self) -> float:
-        if self.has_undefined_parameters:
-            return np.nan
-        elif self.shape >= 0.5:
+        if self.shape >= 0.5:
             return np.inf
         elif self.shape == 0.0:
             return (self.scale * np.pi) ** 2 / 6
diff --git a/extreme_fit/distribution/utils_extreme_params.py b/extreme_fit/distribution/utils_extreme_params.py
new file mode 100644
index 00000000..ce61e72e
--- /dev/null
+++ b/extreme_fit/distribution/utils_extreme_params.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams
+
+
+def nan_if_undefined_wrapper(func):
+    def wrapper(obj: AbstractExtremeParams, *args, **kwargs):
+        if obj.has_undefined_parameters:
+            return np.nan
+        return func(obj, *args, **kwargs)
+
+    return wrapper
diff --git a/projects/contrasting_trends_in_snow_loads/__init__.py b/projects/contrasting_trends_in_snow_loads/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py b/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py
index 52a515f7..607e47a1 100644
--- a/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py
+++ b/test/test_extreme_fit/test_distribution/test_gev/test_gev_params.py
@@ -17,6 +17,13 @@ class TestGevParams(unittest.TestCase):
         for quantile_name, p in gev_params.quantile_name_to_p.items():
             self.assertAlmostEqual(- 1 / np.log(p), quantile_dict[quantile_name])
 
+    def test_wrapper(self):
+        gev_params = GevParams(loc=1.0, shape=1.0, scale=-1.0)
+        self.assertTrue(np.isnan(gev_params.quantile(p=0.5)))
+        self.assertTrue(np.isnan(gev_params.sample(n=10)))
+        self.assertTrue(np.isnan(gev_params.param_values))
+        self.assertTrue(np.isnan(gev_params.density(x=1.5)))
+
     def test_time_derivative_return_level(self):
         p = 0.99
         for mu1 in [-1, 0, 1]:
-- 
GitLab