From f6d13838f94fa9af4ec94a9bca273ccf8b981960 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 19 Mar 2020 18:21:33 +0100
Subject: [PATCH] [refactor] add warning when too much zero values. improve
 warning test.

---
 extreme_fit/estimator/abstract_estimator.py   |  6 ++---
 .../abstract_temporal_linear_margin_model.py  |  5 +++--
 extreme_fit/model/utils.py                    | 22 ++++++++++++++-----
 ...dy_visualizer_for_non_stationary_trends.py |  1 -
 .../test_model/test_safe_run_r_estimator.py   | 22 +++++++++++--------
 5 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/extreme_fit/estimator/abstract_estimator.py b/extreme_fit/estimator/abstract_estimator.py
index 39a695af..a2c8114d 100644
--- a/extreme_fit/estimator/abstract_estimator.py
+++ b/extreme_fit/estimator/abstract_estimator.py
@@ -15,10 +15,10 @@ class AbstractEstimator(object):
         self.dataset = dataset  # type: AbstractDataset
         self._result_from_fit = None  # type: Union[None, AbstractResultFromModelFit]
 
-    # Class constructor
+    # Class constructor (shortcut to initialize some subclasses)
     @classmethod
     def from_dataset(cls, dataset: AbstractDataset):
-        raise NotImplementedError
+        return cls(dataset)
 
     # Fit estimator
 
@@ -28,7 +28,7 @@ class AbstractEstimator(object):
     def _fit(self) -> AbstractResultFromModelFit:
         raise NotImplementedError
 
-    # Results from model fit
+    # Fit results
 
     @property
     def result_from_model_fit(self) -> AbstractResultFromModelFit:
diff --git a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
index be160d1e..e5e2ee45 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
@@ -39,8 +39,9 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel):
 
     def fitmargin_from_maxima_gev(self, data: np.ndarray, df_coordinates_spat: pd.DataFrame,
                                   df_coordinates_temp: pd.DataFrame) -> AbstractResultFromModelFit:
-        assert data.shape[1] == len(df_coordinates_temp.values)
-        x = ro.FloatVector(data[0])
+        data = data[0]
+        assert len(data) == len(df_coordinates_temp.values)
+        x = ro.FloatVector(data)
         if self.fit_method == TemporalMarginFitMethod.is_mev_gev_fit:
             return self.ismev_gev_fit(x, df_coordinates_temp)
         if self.fit_method == TemporalMarginFitMethod.extremes_fevd_bayesian:
diff --git a/extreme_fit/model/utils.py b/extreme_fit/model/utils.py
index 1582cf33..582544e7 100644
--- a/extreme_fit/model/utils.py
+++ b/extreme_fit/model/utils.py
@@ -62,6 +62,10 @@ class WarningWhileRunningR(Warning):
     pass
 
 
+class WarningTooMuchZeroValues(Warning):
+    pass
+
+
 class WarningMaximumAbsoluteValueTooHigh(Warning):
     pass
 
@@ -74,7 +78,7 @@ class SafeRunException(Exception):
     pass
 
 
-def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs_value=100, maxit=1000000,
+def safe_run_r_estimator(function, data=None, use_start=False, max_ratio_between_two_extremes_values=10, maxit=1000000,
                          **parameters) -> robjects.ListVector:
     if OptimizationConstants.USE_MAXIT:
         # Add optimization parameters
@@ -83,14 +87,21 @@ def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs
 
     # Some checks for Spatial Extremes
     if data is not None:
-        # Raise warning if the maximum absolute value is above a threshold
         if isinstance(data, np.ndarray):
-            maximum_absolute_value = np.max(np.abs(data))
-            if maximum_absolute_value > threshold_max_abs_value:
+            # Raise warning if the gap is too important between the two biggest values of data
+            sorted_data = sorted(data.flatten())
+            print(data)
+            if sorted_data[-2] * max_ratio_between_two_extremes_values < sorted_data[-1]:
                 msg = "maxmimum absolute value in data {} is too high, i.e. above the defined threshold {}" \
-                    .format(maximum_absolute_value, threshold_max_abs_value)
+                    .format(sorted_data[-1], max_ratio_between_two_extremes_values)
                 msg += '\nPotentially in that case, data should be re-normalized'
                 warnings.warn(msg, WarningMaximumAbsoluteValueTooHigh)
+            # Raise warning if ratio of zeros in data is above some percentage (90% so far)
+            limit_percentage = 90
+            if 100 * np.count_nonzero(data) / len(data) < limit_percentage:
+                msg = 'data contains more than {}% of zero values'.format(100 - limit_percentage)
+                warnings.warn(msg, WarningTooMuchZeroValues)
+        # Add data to the parameters
         parameters['data'] = data
     # First run without using start value
     # Then if it crashes, use start value
@@ -131,6 +142,7 @@ def get_coord_df(df_coordinates: pd.DataFrame):
     coord = r('data.frame')(coord, stringsAsFactors=True)
     return coord
 
+
 def get_null():
     as_null = r['as.null']
     return as_null(1.0)
diff --git a/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py b/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py
index b4c14f5e..dc2561dc 100644
--- a/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py
+++ b/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py
@@ -137,7 +137,6 @@ class StudyVisualizerForNonStationaryTrends(StudyVisualizer):
         # In both cases, we remove any massif with psnow < 0.9
         if self.fit_only_time_series_with_ninety_percent_of_non_null_values:
             d = {m: v for m, v in d.items() if self.massif_name_to_psnow[m] >= 0.9}
-        print(d.keys())
         return d
 
     @property
diff --git a/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py b/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py
index 7bf1ef17..f65c6ab4 100644
--- a/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py
+++ b/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py
@@ -1,22 +1,26 @@
 import numpy as np
 import unittest
 
-from extreme_fit.model.utils import safe_run_r_estimator, WarningMaximumAbsoluteValueTooHigh
+from extreme_fit.model.utils import safe_run_r_estimator, WarningMaximumAbsoluteValueTooHigh, WarningTooMuchZeroValues
 
 
-def function(data=None, control=None):
+def empty_function(data=None, control=None):
     pass
 
 
 class TestSafeRunREstimator(unittest.TestCase):
 
-    def test_warning(self):
-        threshold = 10
-        value_above_threhsold = 2 * threshold
-        datas = [np.array([value_above_threhsold]), np.ones([2, 2]) * value_above_threhsold]
-        for data in datas:
-            with self.assertWarns(WarningMaximumAbsoluteValueTooHigh):
-                safe_run_r_estimator(function=function, data=data, threshold_max_abs_value=threshold)
+    def test_warning_maximum_value(self):
+        ratio = 10
+        data = np.array([ratio+1, 1])
+        with self.assertWarns(WarningMaximumAbsoluteValueTooHigh):
+            safe_run_r_estimator(function=empty_function, data=data, max_ratio_between_two_extremes_values=ratio)
+
+    def test_warning_too_much_zero(self):
+        n = 5
+        data = np.concatenate([np.zeros(n), np.ones(n)])
+        with self.assertWarns(WarningTooMuchZeroValues):
+            safe_run_r_estimator(function=empty_function, data=data)
 
 
 if __name__ == '__main__':
-- 
GitLab