diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py index 71c687b72e338fdc61eca81a2b66141c80d9a03b..3cd9220683c5c2c5d6a69ac3fa1d5ca9bd5eeb19 100644 --- a/extreme_estimator/estimator/abstract_estimator.py +++ b/extreme_estimator/estimator/abstract_estimator.py @@ -3,6 +3,7 @@ import time from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ AbstractMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction +from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset @@ -49,7 +50,7 @@ class AbstractEstimator(object): return self._margin_function_fitted def extract_fitted_models_from_fitted_params(self, margin_function_to_fit, full_params_fitted): - coef_dict = {k: v for k, v in full_params_fitted.items() if 'Coeff' in k} + coef_dict = {k: v for k, v in full_params_fitted.items() if LinearCoef.COEFF_STR in k} self._margin_function_fitted = LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates, gev_param_name_to_linear_dims=margin_function_to_fit.gev_param_name_to_linear_dims, coef_dict=coef_dict) diff --git a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py index 8954e5e0833e3d3fc4145627bd89f7c78ea71213..1e67c94c17d2ab5a35cd6f946bfe222297c44d67 100644 --- a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py +++ b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py @@ -12,6 +12,7 @@ class LinearCoef(object): dim = 3 correspond to the third coordinate... """ INTERCEPT_NAME = 'intercept' + COEFF_STR = 'Coeff' def __init__(self, gev_param_name: str, dim_to_coef: Dict[int, float] = None, default_value: float = 0.0): self.gev_param_name = gev_param_name @@ -40,9 +41,11 @@ class LinearCoef(object): """ assert coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATES_NAMES if coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES: - return gev_param_name + 'Coeff{}' + coef_template_str = gev_param_name + cls.COEFF_STR + '{}' else: - return 'tempCoeff' + gev_param_name.title() + '{}' + coef_template_str = 'temp' + cls.COEFF_STR + gev_param_name.title() + '{}' + assert cls.COEFF_STR in coef_template_str + return coef_template_str @staticmethod def has_dependence_in_spatial_coordinates(dim_to_coefficient_name):