Commit 4a295776 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] remove use_start_value

parent 2b00370c
No related merge requests found
Showing with 21 additions and 24 deletions
+21 -24
class AbstractModel(object):
def __init__(self, use_start_value=False, params_start_fit=None, params_sample=None):
def __init__(self, params_start_fit=None, params_sample=None):
self.default_params = None
self.use_start_value = use_start_value
self.user_params_start_fit = params_start_fit
self.user_params_sample = params_sample
......
......@@ -19,10 +19,10 @@ class AbstractMarginModel(AbstractModel, ABC):
-margin_function_start_fit for starting to fit
"""
def __init__(self, coordinates: AbstractCoordinates, use_start_value=False,
def __init__(self, coordinates: AbstractCoordinates,
params_start_fit=None, params_sample=None,
params_class=GevParams):
super().__init__(use_start_value, params_start_fit, params_sample)
super().__init__(params_start_fit, params_sample)
assert isinstance(coordinates, AbstractCoordinates), type(coordinates)
self.coordinates = coordinates
self.margin_function_sample = None # type: AbstractMarginFunction
......
......@@ -19,14 +19,15 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class AbstractTemporalLinearMarginModel(LinearMarginModel):
"""Linearity only with respect to the temporal coordinates"""
def __init__(self, coordinates: AbstractCoordinates, use_start_value=False, params_start_fit=None,
def __init__(self, coordinates: AbstractCoordinates,
params_start_fit=None,
params_sample=None, starting_point=None,
fit_method=MarginFitMethod.is_mev_gev_fit,
nb_iterations_for_bayesian_fit=5000,
params_start_fit_bayesian=None,
type_for_MLE="GEV",
params_class=GevParams):
super().__init__(coordinates, use_start_value, params_start_fit, params_sample, starting_point,
super().__init__(coordinates, params_start_fit, params_sample, starting_point,
params_class)
self.type_for_mle = type_for_MLE
self.params_start_fit_bayesian = params_start_fit_bayesian
......
......@@ -70,10 +70,10 @@ class NonStationaryLocationAndScaleTemporalModel(AbstractTemporalLinearMarginMod
class GumbelTemporalModel(StationaryTemporalModel):
def __init__(self, coordinates: AbstractCoordinates, use_start_value=False, params_start_fit=None,
def __init__(self, coordinates: AbstractCoordinates, params_start_fit=None,
params_sample=None, starting_point=None, fit_method=MarginFitMethod.is_mev_gev_fit,
nb_iterations_for_bayesian_fit=5000, params_start_fit_bayesian=None):
super().__init__(coordinates, use_start_value, params_start_fit, params_sample, starting_point, fit_method,
super().__init__(coordinates, params_start_fit, params_sample, starting_point, fit_method,
nb_iterations_for_bayesian_fit, params_start_fit_bayesian, type_for_MLE="Gumbel")
......
......@@ -16,7 +16,7 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class ParametricMarginModel(AbstractMarginModel, ABC):
def __init__(self, coordinates: AbstractCoordinates, use_start_value=False, params_start_fit=None,
def __init__(self, coordinates: AbstractCoordinates, params_start_fit=None,
params_sample=None, starting_point=None, params_class=GevParams,
fit_method=MarginFitMethod.spatial_extremes_mle):
"""
......@@ -27,7 +27,7 @@ class ParametricMarginModel(AbstractMarginModel, ABC):
self.margin_function_sample = None # type: ParametricMarginFunction
self.margin_function_start_fit = None # type: ParametricMarginFunction
self.drop_duplicates = True
super().__init__(coordinates, use_start_value, params_start_fit, params_sample, params_class)
super().__init__(coordinates, params_start_fit, params_sample, params_class)
def fitmargin_from_maxima_gev(self, data: np.ndarray, df_coordinates_spat: pd.DataFrame,
df_coordinates_temp: pd.DataFrame) -> ResultFromSpatialExtreme:
......
......@@ -11,9 +11,9 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class SplineMarginModel(ParametricMarginModel):
def __init__(self, coordinates: AbstractCoordinates, use_start_value=False, params_start_fit=None,
def __init__(self, coordinates: AbstractCoordinates, params_start_fit=None,
params_sample=None):
super().__init__(coordinates, use_start_value, params_start_fit, params_sample)
super().__init__(coordinates, params_start_fit, params_sample)
def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None,
param_name_to_coef: Dict[str, AbstractCoef] = None,
......
......@@ -14,8 +14,8 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class AbstractMaxStableModel(AbstractModel):
def __init__(self, use_start_value=False, params_start_fit=None, params_sample=None):
super().__init__(use_start_value, params_start_fit, params_sample)
def __init__(self, params_start_fit=None, params_sample=None):
super().__init__(params_start_fit, params_sample)
self.cov_mod = None
@property
......@@ -105,9 +105,9 @@ class CovarianceFunction(Enum):
class AbstractMaxStableModelWithCovarianceFunction(AbstractMaxStableModel):
def __init__(self, use_start_value=False, params_start_fit=None, params_sample=None,
def __init__(self, params_start_fit=None, params_sample=None,
covariance_function: CovarianceFunction = None):
super().__init__(use_start_value, params_start_fit, params_sample)
super().__init__(params_start_fit, params_sample)
assert covariance_function is not None
self.covariance_function = covariance_function
self.default_params = {
......
......@@ -80,12 +80,12 @@ class SafeRunException(Exception):
def safe_run_r_estimator(function, data=None, start_dict=None, max_ratio_between_two_extremes_values=10, maxit=1000000,
**parameters) -> robjects.ListVector:
nb_tries_for_start_value=5, **parameters) -> robjects.ListVector:
try:
return _safe_run_r_estimator(function, data, max_ratio_between_two_extremes_values, maxit, **parameters)
except SafeRunException as e:
if start_dict is not None:
for _ in range(5):
for _ in range(nb_tries_for_start_value):
parameters['start'] = r.list(**start_dict)
try:
return _safe_run_r_estimator(function, data, max_ratio_between_two_extremes_values, maxit, **parameters)
......
......@@ -25,9 +25,6 @@ class TestFullEstimators(unittest.TestCase):
coordinates=coordinates,
max_stable_model=max_stable_model)
margin_model.use_start_value = True
# todo: understand why it is crashing without specifying that (when not using start value was passed by default this test started crashing)
for full_estimator in load_test_full_estimators(dataset, margin_model, max_stable_model):
full_estimator.fit()
if self.DISPLAY:
......
......@@ -27,7 +27,7 @@ class TestMaxStableFitWithConstantMargin(TestUnitaryAbstract):
@property
def python_output(self):
dataset = TestRMaxStabWithMarginConstant.python_code()
max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat, use_start_value=False)
max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat)
margin_model = ConstantMarginModel(dataset.coordinates)
full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model,
max_stable_model)
......@@ -54,7 +54,7 @@ class TestMaxStableFitWithLinearMargin(TestUnitaryAbstract):
@property
def python_output(self):
dataset = TestRMaxStabWithMarginConstant.python_code()
max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat, use_start_value=False)
max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat)
margin_model = LinearMarginModelExample(dataset.coordinates)
full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model,
max_stable_model)
......
......@@ -21,7 +21,7 @@ class TestMaxStableFitWithoutMargin(TestUnitaryAbstract):
def python_output(self):
coordinates, max_stable_model = TestRMaxStab.python_code()
dataset = MaxStableDataset.from_sampling(nb_obs=40, max_stable_model=max_stable_model, coordinates=coordinates)
max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat, use_start_value=False)
max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat)
max_stable_estimator = MaxStableEstimator(dataset, max_stable_model)
max_stable_estimator.fit()
return max_stable_estimator.max_stable_params_fitted
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment