Commit 0494d2bc authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[EXTREME ESTIMATOR][EXTREME MODEL] add warning when maximum absolute value of...

[EXTREME ESTIMATOR][EXTREME MODEL] add warning when maximum absolute value of data is too high in safe_run_r_estimator. add test.
parent 528c2840
No related merge requests found
Showing with 42 additions and 2 deletions
+42 -2
import os.path as op import os.path as op
import warnings
import numpy as np
import random import random
import sys import sys
from typing import Dict from types import TracebackType
from typing import Dict, Optional
import pandas as pd import pandas as pd
import rpy2.robjects as ro import rpy2.robjects as ro
...@@ -33,7 +37,20 @@ def get_associated_r_file(python_filepath: str) -> str: ...@@ -33,7 +37,20 @@ def get_associated_r_file(python_filepath: str) -> str:
return r_filepath return r_filepath
def safe_run_r_estimator(function, use_start=False, **parameters): class WarningMaximumAbsoluteValueTooHigh(Warning):
pass
def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_value=100, **parameters):
# Raise warning if the maximum absolute value is above a threshold
assert isinstance(data, np.ndarray)
maximum_absolute_value = np.max(np.abs(data))
if maximum_absolute_value > threshold_max_abs_value:
msg = "maxmimum absolute value in data {} is too high, i.e. above the defined threshold {}"\
.format(maximum_absolute_value, threshold_max_abs_value)
msg += '\nPotentially in that case, data should be re-normalized'
warnings.warn(msg, WarningMaximumAbsoluteValueTooHigh)
parameters['data'] = data
# First run without using start value # First run without using start value
# Then if it crashes, use start value # Then if it crashes, use start value
run_successful = False run_successful = False
......
import numpy as np
import unittest
from extreme_estimator.extreme_models.utils import safe_run_r_estimator, WarningMaximumAbsoluteValueTooHigh
def function(data):
pass
class TestSafeRunREstimator(unittest.TestCase):
def test_warning(self):
threshold = 10
value_above_threhsold = 2 * threshold
datas = [np.array([value_above_threhsold]), np.ones([2, 2]) * value_above_threhsold]
for data in datas:
with self.assertWarns(WarningMaximumAbsoluteValueTooHigh):
safe_run_r_estimator(function=function, data=data, threshold_max_abs_value=threshold)
if __name__ == '__main__':
unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment