Commit 2b00370c authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] refactor main fitting function: by default, we fit without starting...

[refactor] refactor main fitting function: by default, we fit without starting value, then we try several time with the start value
parent 69f5b00f
No related merge requests found
Showing with 36 additions and 24 deletions
+36 -24
...@@ -8,7 +8,7 @@ class AbstractModel(object): ...@@ -8,7 +8,7 @@ class AbstractModel(object):
@property @property
def params_start_fit(self) -> dict: def params_start_fit(self) -> dict:
return self.merge_params(default_params=self.default_params, input_params=self.user_params_start_fit) return self.default_params.copy()
@property @property
def params_sample(self) -> dict: def params_sample(self) -> dict:
......
...@@ -57,7 +57,7 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel): ...@@ -57,7 +57,7 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel):
def ismev_gev_fit(self, x, df_coordinates_temp) -> ResultFromIsmev: def ismev_gev_fit(self, x, df_coordinates_temp) -> ResultFromIsmev:
y = df_coordinates_temp.values y = df_coordinates_temp.values
res = safe_run_r_estimator(function=r('gev.fit'), use_start=self.use_start_value, res = safe_run_r_estimator(function=r('gev.fit'),
xdat=x, y=y, mul=self.mul, xdat=x, y=y, mul=self.mul,
sigl=self.sigl, shl=self.shl) sigl=self.sigl, shl=self.shl)
return ResultFromIsmev(res, self.margin_function_start_fit.param_name_to_dims) return ResultFromIsmev(res, self.margin_function_start_fit.param_name_to_dims)
......
...@@ -6,11 +6,11 @@ import pandas as pd ...@@ -6,11 +6,11 @@ import pandas as pd
from extreme_fit.distribution.gev.gev_params import GevParams from extreme_fit.distribution.gev.gev_params import GevParams
from extreme_fit.function.margin_function.parametric_margin_function import \ from extreme_fit.function.margin_function.parametric_margin_function import \
ParametricMarginFunction ParametricMarginFunction
from extreme_fit.model.margin_model.abstract_margin_model import AbstractMarginModel
from extreme_fit.model.margin_model.utils import MarginFitMethod from extreme_fit.model.margin_model.utils import MarginFitMethod
from extreme_fit.model.result_from_model_fit.result_from_spatial_extreme import ResultFromSpatialExtreme from extreme_fit.model.result_from_model_fit.result_from_spatial_extreme import ResultFromSpatialExtreme
from extreme_fit.model.margin_model.abstract_margin_model import AbstractMarginModel from extreme_fit.model.utils import r, get_coord, \
from extreme_fit.model.utils import safe_run_r_estimator, r, get_coord, \ get_margin_formula_spatial_extreme, safe_run_r_estimator
get_margin_formula_spatial_extreme
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
...@@ -45,7 +45,8 @@ class ParametricMarginModel(AbstractMarginModel, ABC): ...@@ -45,7 +45,8 @@ class ParametricMarginModel(AbstractMarginModel, ABC):
fit_params['temp.cov'] = get_coord(df_coordinates=df_coordinates_temp) fit_params['temp.cov'] = get_coord(df_coordinates=df_coordinates_temp)
# Start parameters # Start parameters
coef_dict = self.margin_function_start_fit.coef_dict coef_dict = self.margin_function_start_fit.coef_dict
fit_params['start'] = r.list(**coef_dict) # fit_params['start'] = r.list(**coef_dict)
res = safe_run_r_estimator(function=r.fitspatgev, use_start=self.use_start_value, data=data, res = safe_run_r_estimator(function=r.fitspatgev, data=data,
start_dict=coef_dict,
covariables=covariables, **fit_params) covariables=covariables, **fit_params)
return ResultFromSpatialExtreme(res) return ResultFromSpatialExtreme(res)
...@@ -7,8 +7,8 @@ from rpy2.rinterface._rinterface import RRuntimeError ...@@ -7,8 +7,8 @@ from rpy2.rinterface._rinterface import RRuntimeError
from extreme_fit.model.abstract_model import AbstractModel from extreme_fit.model.abstract_model import AbstractModel
from extreme_fit.model.result_from_model_fit.result_from_spatial_extreme import ResultFromSpatialExtreme from extreme_fit.model.result_from_model_fit.result_from_spatial_extreme import ResultFromSpatialExtreme
from extreme_fit.model.utils import r, safe_run_r_estimator, get_coord, \ from extreme_fit.model.utils import r, get_coord, \
get_margin_formula_spatial_extreme, SafeRunException get_margin_formula_spatial_extreme, SafeRunException, safe_run_r_estimator
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
...@@ -58,7 +58,6 @@ class AbstractMaxStableModel(AbstractModel): ...@@ -58,7 +58,6 @@ class AbstractMaxStableModel(AbstractModel):
fit_params.update(margin_formulas) fit_params.update(margin_formulas)
if fitmaxstab_with_one_dimensional_data: if fitmaxstab_with_one_dimensional_data:
fit_params['iso'] = True fit_params['iso'] = True
fit_params['start'] = r.list(**start_dict)
fit_params['fit.marge'] = fit_marge fit_params['fit.marge'] = fit_marge
# Add some temporal covariates # Add some temporal covariates
...@@ -69,7 +68,8 @@ class AbstractMaxStableModel(AbstractModel): ...@@ -69,7 +68,8 @@ class AbstractMaxStableModel(AbstractModel):
fit_params['temp.cov'] = get_coord(df_coordinates_temp) fit_params['temp.cov'] = get_coord(df_coordinates_temp)
# Run the fitmaxstab in R # Run the fitmaxstab in R
res = safe_run_r_estimator(function=r.fitmaxstab, use_start=self.use_start_value, data=data, coord=coord, res = safe_run_r_estimator(function=r.fitmaxstab, data=data, coord=coord,
start_dict=start_dict,
**fit_params) **fit_params)
return ResultFromSpatialExtreme(res) return ResultFromSpatialExtreme(res)
......
...@@ -34,6 +34,7 @@ class BrownResnick(AbstractMaxStableModel): ...@@ -34,6 +34,7 @@ class BrownResnick(AbstractMaxStableModel):
'smooth': 0.5, 'smooth': 0.5,
} }
class Schlather(AbstractMaxStableModelWithCovarianceFunction): class Schlather(AbstractMaxStableModelWithCovarianceFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
import copy
import io import io
import os.path as op import os.path as op
...@@ -78,7 +79,23 @@ class SafeRunException(Exception): ...@@ -78,7 +79,23 @@ class SafeRunException(Exception):
pass pass
def safe_run_r_estimator(function, data=None, use_start=False, max_ratio_between_two_extremes_values=10, maxit=1000000, def safe_run_r_estimator(function, data=None, start_dict=None, max_ratio_between_two_extremes_values=10, maxit=1000000,
**parameters) -> robjects.ListVector:
try:
return _safe_run_r_estimator(function, data, max_ratio_between_two_extremes_values, maxit, **parameters)
except SafeRunException as e:
if start_dict is not None:
for _ in range(5):
parameters['start'] = r.list(**start_dict)
try:
return _safe_run_r_estimator(function, data, max_ratio_between_two_extremes_values, maxit, **parameters)
except Exception:
continue
else:
raise e
def _safe_run_r_estimator(function, data=None, max_ratio_between_two_extremes_values=10, maxit=1000000,
**parameters) -> robjects.ListVector: **parameters) -> robjects.ListVector:
if OptimizationConstants.USE_MAXIT: if OptimizationConstants.USE_MAXIT:
# Add optimization parameters # Add optimization parameters
...@@ -107,25 +124,18 @@ def safe_run_r_estimator(function, data=None, use_start=False, max_ratio_between ...@@ -107,25 +124,18 @@ def safe_run_r_estimator(function, data=None, use_start=False, max_ratio_between
warnings.warn(msg, WarningTooMuchZeroValues) warnings.warn(msg, WarningTooMuchZeroValues)
# Add data to the parameters # Add data to the parameters
parameters['data'] = data parameters['data'] = data
# First run without using start value
# Then if it crashes, use start value
run_successful = False run_successful = False
res = None res = None
f = io.StringIO() f = io.StringIO()
# Warning print will not work in this part # Warning print will not work in this part
with redirect_stdout(f): with redirect_stdout(f):
while not run_successful: while not run_successful:
current_parameter = parameters.copy()
if not use_start and 'start' in current_parameter:
current_parameter.pop('start')
try: try:
res = function(**current_parameter) # type: res = function(**parameters) # type:
run_successful = True run_successful = True
except (RRuntimeError, RRuntimeWarning) as e: except (RRuntimeError, RRuntimeWarning) as e:
if not use_start: if isinstance(e, RRuntimeError):
use_start = True
continue
elif isinstance(e, RRuntimeError):
raise SafeRunException('Some R exception have been launched at RunTime: \n {}'.format(e.__repr__())) raise SafeRunException('Some R exception have been launched at RunTime: \n {}'.format(e.__repr__()))
if isinstance(e, RRuntimeWarning): if isinstance(e, RRuntimeWarning):
warnings.warn(e.__repr__(), WarningWhileRunningR) warnings.warn(e.__repr__(), WarningWhileRunningR)
......
...@@ -67,8 +67,8 @@ class TestMaxStableEstimatorGaussFor3DCoordinates(TestMaxStableEstimators): ...@@ -67,8 +67,8 @@ class TestMaxStableEstimatorGaussFor3DCoordinates(TestMaxStableEstimators):
self.max_stable_models = load_test_max_stable_models()[:1] self.max_stable_models = load_test_max_stable_models()[:1]
def test_max_stable_estimators(self): def test_max_stable_estimators(self):
with self.assertRaises(SafeRunException): self.fit_max_stable_estimator_for_all_coordinates()
self.fit_max_stable_estimator_for_all_coordinates() self.assertTrue(True)
if __name__ == '__main__': if __name__ == '__main__':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment