From 0494d2bc38bd9feead7bfff9e8f6722039c6c036 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 4 Mar 2019 16:30:23 +0100
Subject: [PATCH] [EXTREME ESTIMATOR][EXTREME MODEL] add warning when maximum
 absolute value of data is too high in safe_run_r_estimator. add test.

---
 extreme_estimator/extreme_models/utils.py     | 24 ++++++++++++++++++++++--
 .../test_safe_run_r_estimator.py              | 24 ++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 2 deletions(-)
 create mode 100644 test/test_extreme_estimator/test_extreme_models/test_safe_run_r_estimator.py

diff --git a/extreme_estimator/extreme_models/utils.py b/extreme_estimator/extreme_models/utils.py
index ec8defa7..d404a994 100644
--- a/extreme_estimator/extreme_models/utils.py
+++ b/extreme_estimator/extreme_models/utils.py
@@ -1,7 +1,11 @@
 import os.path as op
+import warnings
+
+import numpy as np
 import random
 import sys
-from typing import Dict
+from types import TracebackType
+from typing import Dict, Optional
 
 import pandas as pd
 import rpy2.robjects as ro
@@ -33,7 +37,23 @@ def get_associated_r_file(python_filepath: str) -> str:
     return r_filepath
 
 
-def safe_run_r_estimator(function, use_start=False, **parameters):
+class WarningMaximumAbsoluteValueTooHigh(Warning):
+    """Emitted when the maximum absolute value of the data exceeds the threshold."""
+
+
+def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_value=100, **parameters):
+    # Validate with an explicit raise: a bare assert is stripped under `python -O`
+    if not isinstance(data, np.ndarray):
+        raise TypeError('data must be a numpy.ndarray, got {}'.format(type(data).__name__))
+    # Warn (without failing) if the maximum absolute value is above the threshold
+    maximum_absolute_value = np.max(np.abs(data))
+    if maximum_absolute_value > threshold_max_abs_value:
+        msg = "maximum absolute value in data {} is too high, i.e. above the defined threshold {}"\
+            .format(maximum_absolute_value, threshold_max_abs_value)
+        msg += '\nPotentially in that case, data should be re-normalized'
+        # stacklevel=2 attributes the warning to the caller that supplied the data
+        warnings.warn(msg, WarningMaximumAbsoluteValueTooHigh, stacklevel=2)
+    parameters['data'] = data
     # First run without using start value
     # Then if it crashes, use start value
     run_successful = False
diff --git a/test/test_extreme_estimator/test_extreme_models/test_safe_run_r_estimator.py b/test/test_extreme_estimator/test_extreme_models/test_safe_run_r_estimator.py
new file mode 100644
index 00000000..d69cdbce
--- /dev/null
+++ b/test/test_extreme_estimator/test_extreme_models/test_safe_run_r_estimator.py
@@ -0,0 +1,24 @@
+import numpy as np
+import unittest
+
+from extreme_estimator.extreme_models.utils import safe_run_r_estimator, WarningMaximumAbsoluteValueTooHigh
+
+
+def function(data):
+    """Stub standing in for an estimator function; it ignores its input."""
+
+
+class TestSafeRunREstimator(unittest.TestCase):
+
+    def test_warning(self):
+        """Data above the threshold must trigger the warning, for 1-D and 2-D arrays."""
+        threshold = 10
+        value_above_threshold = 2 * threshold
+        datas = [np.array([value_above_threshold]), np.ones([2, 2]) * value_above_threshold]
+        for data in datas:
+            with self.assertWarns(WarningMaximumAbsoluteValueTooHigh):
+                safe_run_r_estimator(function=function, data=data, threshold_max_abs_value=threshold)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab