Commit a277fc27 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[MARGIN MODEL] refactor by using class variables to avoid future potential issues

parent a4f03d04
No related merge requests found
Showing with 7 additions and 3 deletions
+7 -3
......@@ -3,6 +3,7 @@ import time
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
AbstractMarginFunction
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
......@@ -49,7 +50,7 @@ class AbstractEstimator(object):
return self._margin_function_fitted
def extract_fitted_models_from_fitted_params(self, margin_function_to_fit, full_params_fitted):
coef_dict = {k: v for k, v in full_params_fitted.items() if 'Coeff' in k}
coef_dict = {k: v for k, v in full_params_fitted.items() if LinearCoef.COEFF_STR in k}
self._margin_function_fitted = LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates,
gev_param_name_to_linear_dims=margin_function_to_fit.gev_param_name_to_linear_dims,
coef_dict=coef_dict)
......
......@@ -12,6 +12,7 @@ class LinearCoef(object):
dim = 3 correspond to the third coordinate...
"""
INTERCEPT_NAME = 'intercept'
COEFF_STR = 'Coeff'
def __init__(self, gev_param_name: str, dim_to_coef: Dict[int, float] = None, default_value: float = 0.0):
self.gev_param_name = gev_param_name
......@@ -40,9 +41,11 @@ class LinearCoef(object):
"""
assert coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATES_NAMES
if coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES:
return gev_param_name + 'Coeff{}'
coef_template_str = gev_param_name + cls.COEFF_STR + '{}'
else:
return 'tempCoeff' + gev_param_name.title() + '{}'
coef_template_str = 'temp' + cls.COEFF_STR + gev_param_name.title() + '{}'
assert cls.COEFF_STR in coef_template_str
return coef_template_str
@staticmethod
def has_dependence_in_spatial_coordinates(dim_to_coefficient_name):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment