From f6d13838f94fa9af4ec94a9bca273ccf8b981960 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Thu, 19 Mar 2020 18:21:33 +0100 Subject: [PATCH] [refactor] add warning when too much zero values. improve warning test. --- extreme_fit/estimator/abstract_estimator.py | 6 ++--- .../abstract_temporal_linear_margin_model.py | 5 +++-- extreme_fit/model/utils.py | 22 ++++++++++++++----- ...dy_visualizer_for_non_stationary_trends.py | 1 - .../test_model/test_safe_run_r_estimator.py | 22 +++++++++++-------- 5 files changed, 36 insertions(+), 20 deletions(-) diff --git a/extreme_fit/estimator/abstract_estimator.py b/extreme_fit/estimator/abstract_estimator.py index 39a695af..a2c8114d 100644 --- a/extreme_fit/estimator/abstract_estimator.py +++ b/extreme_fit/estimator/abstract_estimator.py @@ -15,10 +15,10 @@ class AbstractEstimator(object): self.dataset = dataset # type: AbstractDataset self._result_from_fit = None # type: Union[None, AbstractResultFromModelFit] - # Class constructor + # Class constructor (shortcut to initialize some subclasses) @classmethod def from_dataset(cls, dataset: AbstractDataset): - raise NotImplementedError + return cls(dataset) # Fit estimator @@ -28,7 +28,7 @@ class AbstractEstimator(object): def _fit(self) -> AbstractResultFromModelFit: raise NotImplementedError - # Results from model fit + # Fit results @property def result_from_model_fit(self) -> AbstractResultFromModelFit: diff --git a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py index be160d1e..e5e2ee45 100644 --- a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py +++ b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py @@ -39,8 +39,9 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel): def fitmargin_from_maxima_gev(self, data: np.ndarray, df_coordinates_spat: pd.DataFrame, df_coordinates_temp: pd.DataFrame) -> AbstractResultFromModelFit: - assert data.shape[1] == len(df_coordinates_temp.values) - x = ro.FloatVector(data[0]) + data = data[0] + assert len(data) == len(df_coordinates_temp.values) + x = ro.FloatVector(data) if self.fit_method == TemporalMarginFitMethod.is_mev_gev_fit: return self.ismev_gev_fit(x, df_coordinates_temp) if self.fit_method == TemporalMarginFitMethod.extremes_fevd_bayesian: diff --git a/extreme_fit/model/utils.py b/extreme_fit/model/utils.py index 1582cf33..582544e7 100644 --- a/extreme_fit/model/utils.py +++ b/extreme_fit/model/utils.py @@ -62,6 +62,10 @@ class WarningWhileRunningR(Warning): pass +class WarningTooMuchZeroValues(Warning): + pass + + class WarningMaximumAbsoluteValueTooHigh(Warning): pass @@ -74,7 +78,7 @@ class SafeRunException(Exception): pass -def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs_value=100, maxit=1000000, +def safe_run_r_estimator(function, data=None, use_start=False, max_ratio_between_two_extremes_values=10, maxit=1000000, **parameters) -> robjects.ListVector: if OptimizationConstants.USE_MAXIT: # Add optimization parameters @@ -83,14 +87,21 @@ def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs # Some checks for Spatial Extremes if data is not None: - # Raise warning if the maximum absolute value is above a threshold if isinstance(data, np.ndarray): - maximum_absolute_value = np.max(np.abs(data)) - if maximum_absolute_value > threshold_max_abs_value: + # Raise warning if the gap is too important between the two biggest values of data + sorted_data = sorted(data.flatten()) + print(data) + if sorted_data[-2] * max_ratio_between_two_extremes_values < sorted_data[-1]: msg = "maxmimum absolute value in data {} is too high, i.e. above the defined threshold {}" \ - .format(maximum_absolute_value, threshold_max_abs_value) + .format(sorted_data[-1], max_ratio_between_two_extremes_values) msg += '\nPotentially in that case, data should be re-normalized' warnings.warn(msg, WarningMaximumAbsoluteValueTooHigh) + # Raise warning if ratio of zeros in data is above some percentage (90% so far) + limit_percentage = 90 + if 100 * np.count_nonzero(data) / len(data) < limit_percentage: + msg = 'data contains more than {}% of zero values'.format(100 - limit_percentage) + warnings.warn(msg, WarningTooMuchZeroValues) + # Add data to the parameters parameters['data'] = data # First run without using start value # Then if it crashes, use start value @@ -131,6 +142,7 @@ def get_coord_df(df_coordinates: pd.DataFrame): coord = r('data.frame')(coord, stringsAsFactors=True) return coord + def get_null(): as_null = r['as.null'] return as_null(1.0) diff --git a/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py b/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py index b4c14f5e..dc2561dc 100644 --- a/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py +++ b/extreme_trend/visualizers/study_visualizer_for_non_stationary_trends.py @@ -137,7 +137,6 @@ class StudyVisualizerForNonStationaryTrends(StudyVisualizer): # In both cases, we remove any massif with psnow < 0.9 if self.fit_only_time_series_with_ninety_percent_of_non_null_values: d = {m: v for m, v in d.items() if self.massif_name_to_psnow[m] >= 0.9} - print(d.keys()) return d @property diff --git a/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py b/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py index 7bf1ef17..f65c6ab4 100644 --- a/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py +++ b/test/test_extreme_fit/test_model/test_safe_run_r_estimator.py @@ -1,22 +1,26 @@ import numpy as np import unittest -from extreme_fit.model.utils import safe_run_r_estimator, WarningMaximumAbsoluteValueTooHigh +from extreme_fit.model.utils import safe_run_r_estimator, WarningMaximumAbsoluteValueTooHigh, WarningTooMuchZeroValues -def function(data=None, control=None): +def empty_function(data=None, control=None): pass class TestSafeRunREstimator(unittest.TestCase): - def test_warning(self): - threshold = 10 - value_above_threhsold = 2 * threshold - datas = [np.array([value_above_threhsold]), np.ones([2, 2]) * value_above_threhsold] - for data in datas: - with self.assertWarns(WarningMaximumAbsoluteValueTooHigh): - safe_run_r_estimator(function=function, data=data, threshold_max_abs_value=threshold) + def test_warning_maximum_value(self): + ratio = 10 + data = np.array([ratio+1, 1]) + with self.assertWarns(WarningMaximumAbsoluteValueTooHigh): + safe_run_r_estimator(function=empty_function, data=data, max_ratio_between_two_extremes_values=ratio) + + def test_warning_too_much_zero(self): + n = 5 + data = np.concatenate([np.zeros(n), np.ones(n)]) + with self.assertWarns(WarningTooMuchZeroValues): + safe_run_r_estimator(function=empty_function, data=data) if __name__ == '__main__': -- GitLab