Commit 69f5b00f authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] remove default_param_start_fit and default_param_sample by a single default_param

parent 3f80990e
No related merge requests found
Showing with 12 additions and 24 deletions
+12 -24
class AbstractModel(object): class AbstractModel(object):
def __init__(self, use_start_value=False, params_start_fit=None, params_sample=None): def __init__(self, use_start_value=False, params_start_fit=None, params_sample=None):
self.default_params_start_fit = None self.default_params = None
self.default_params_sample = None
self.use_start_value = use_start_value self.use_start_value = use_start_value
self.user_params_start_fit = params_start_fit self.user_params_start_fit = params_start_fit
self.user_params_sample = params_sample self.user_params_sample = params_sample
@property @property
def params_start_fit(self) -> dict: def params_start_fit(self) -> dict:
return self.merge_params(default_params=self.default_params_start_fit, input_params=self.user_params_start_fit) return self.merge_params(default_params=self.default_params, input_params=self.user_params_start_fit)
@property @property
def params_sample(self) -> dict: def params_sample(self) -> dict:
return self.merge_params(default_params=self.default_params_sample, input_params=self.user_params_sample) return self.merge_params(default_params=self.default_params, input_params=self.user_params_sample)
@staticmethod @staticmethod
def merge_params(default_params, input_params): def merge_params(default_params, input_params):
......
...@@ -19,8 +19,7 @@ class LinearMarginModel(ParametricMarginModel): ...@@ -19,8 +19,7 @@ class LinearMarginModel(ParametricMarginModel):
'load_margin_functions needs to be implemented in child class' 'load_margin_functions needs to be implemented in child class'
# Load default params (with a dictionary format to enable quick replacement) # Load default params (with a dictionary format to enable quick replacement)
# IMPORTANT: Using a dictionary format enable using the default/user params methodology # IMPORTANT: Using a dictionary format enable using the default/user params methodology
self.default_params_sample = self.default_param_name_and_dim_to_coef self.default_params = self.default_param_name_and_dim_to_coef
self.default_params_start_fit = self.default_param_name_and_dim_to_coef
# Load sample coef # Load sample coef
coef_sample = self.param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample) coef_sample = self.param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample)
......
...@@ -110,7 +110,7 @@ class AbstractMaxStableModelWithCovarianceFunction(AbstractMaxStableModel): ...@@ -110,7 +110,7 @@ class AbstractMaxStableModelWithCovarianceFunction(AbstractMaxStableModel):
super().__init__(use_start_value, params_start_fit, params_sample) super().__init__(use_start_value, params_start_fit, params_sample)
assert covariance_function is not None assert covariance_function is not None
self.covariance_function = covariance_function self.covariance_function = covariance_function
self.default_params_sample = { self.default_params = {
'range': 3, 'range': 3,
'smooth': 0.5, 'smooth': 0.5,
'nugget': 0.5 'nugget': 0.5
......
...@@ -9,13 +9,12 @@ class Smith(AbstractMaxStableModel): ...@@ -9,13 +9,12 @@ class Smith(AbstractMaxStableModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cov_mod = 'gauss' self.cov_mod = 'gauss'
self.default_params_start_fit = { self.default_params = {
'var': 1, 'var': 1,
'cov11': 1, 'cov11': 1,
'cov12': 0, 'cov12': 0,
'cov22': 1 'cov22': 1
} }
self.default_params_sample = self.default_params_start_fit.copy()
def remove_unused_parameters(self, start_dict, fitmaxstab_with_one_dimensional_data): def remove_unused_parameters(self, start_dict, fitmaxstab_with_one_dimensional_data):
if fitmaxstab_with_one_dimensional_data: if fitmaxstab_with_one_dimensional_data:
...@@ -30,23 +29,17 @@ class BrownResnick(AbstractMaxStableModel): ...@@ -30,23 +29,17 @@ class BrownResnick(AbstractMaxStableModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cov_mod = 'brown' self.cov_mod = 'brown'
self.default_params_start_fit = { self.default_params = {
'range': 3, 'range': 3,
'smooth': 0.5, 'smooth': 0.5,
} }
self.default_params_sample = {
'range': 3,
'smooth': 0.5,
}
class Schlather(AbstractMaxStableModelWithCovarianceFunction): class Schlather(AbstractMaxStableModelWithCovarianceFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cov_mod = self.covariance_function.name self.cov_mod = self.covariance_function.name
self.default_params_sample.update({}) self.default_params.update({})
self.default_params_start_fit = self.default_params_sample.copy()
class Geometric(AbstractMaxStableModelWithCovarianceFunction): class Geometric(AbstractMaxStableModelWithCovarianceFunction):
...@@ -54,8 +47,7 @@ class Geometric(AbstractMaxStableModelWithCovarianceFunction): ...@@ -54,8 +47,7 @@ class Geometric(AbstractMaxStableModelWithCovarianceFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cov_mod = 'g' + self.covariance_function.name self.cov_mod = 'g' + self.covariance_function.name
self.default_params_sample.update({'sigma2': 0.5}) self.default_params.update({'sigma2': 0.5})
self.default_params_start_fit = self.default_params_sample.copy()
class ExtremalT(AbstractMaxStableModelWithCovarianceFunction): class ExtremalT(AbstractMaxStableModelWithCovarianceFunction):
...@@ -63,8 +55,7 @@ class ExtremalT(AbstractMaxStableModelWithCovarianceFunction): ...@@ -63,8 +55,7 @@ class ExtremalT(AbstractMaxStableModelWithCovarianceFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cov_mod = 't' + self.covariance_function.name self.cov_mod = 't' + self.covariance_function.name
self.default_params_sample.update({'DoF': 2}) self.default_params.update({'DoF': 2})
self.default_params_start_fit = self.default_params_sample.copy()
class ISchlather(AbstractMaxStableModelWithCovarianceFunction): class ISchlather(AbstractMaxStableModelWithCovarianceFunction):
...@@ -72,5 +63,4 @@ class ISchlather(AbstractMaxStableModelWithCovarianceFunction): ...@@ -72,5 +63,4 @@ class ISchlather(AbstractMaxStableModelWithCovarianceFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cov_mod = 'i' + self.covariance_function.name self.cov_mod = 'i' + self.covariance_function.name
self.default_params_sample.update({'alpha': 0.5}) self.default_params.update({'alpha': 0.5})
self.default_params_start_fit = self.default_params_sample.copy()
...@@ -87,7 +87,7 @@ class TestMaxStableFitWithLinearMargin(TestUnitaryAbstract): ...@@ -87,7 +87,7 @@ class TestMaxStableFitWithLinearMargin(TestUnitaryAbstract):
# @property # @property
# def python_output(self): # def python_output(self):
# dataset = TestRMaxStabWithMarginConstant.python_code() # dataset = TestRMaxStabWithMarginConstant.python_code()
# max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat, use_start_value=False) # max_stable_model = Schlather(covariance_function=CovarianceFunction.whitmat)
# margin_model = LinearMarginModelExample(dataset.coordinates) # margin_model = LinearMarginModelExample(dataset.coordinates)
# full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model, # full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model,
# max_stable_model) # max_stable_model)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment