From a277fc27698bdb09cf824de86dfea819d2b6cea1 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Mon, 25 Feb 2019 17:39:18 +0100 Subject: [PATCH] [MARGIN MODEL] refactor by using class variables to avoid future potential issues --- extreme_estimator/estimator/abstract_estimator.py | 3 ++- .../margin_model/param_function/linear_coef.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py index 71c687b7..3cd92206 100644 --- a/extreme_estimator/estimator/abstract_estimator.py +++ b/extreme_estimator/estimator/abstract_estimator.py @@ -3,6 +3,7 @@ import time from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ AbstractMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction +from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset @@ -49,7 +50,7 @@ class AbstractEstimator(object): return self._margin_function_fitted def extract_fitted_models_from_fitted_params(self, margin_function_to_fit, full_params_fitted): - coef_dict = {k: v for k, v in full_params_fitted.items() if 'Coeff' in k} + coef_dict = {k: v for k, v in full_params_fitted.items() if LinearCoef.COEFF_STR in k} self._margin_function_fitted = LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates, gev_param_name_to_linear_dims=margin_function_to_fit.gev_param_name_to_linear_dims, coef_dict=coef_dict) diff --git a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py index 8954e5e0..1e67c94c 100644 --- a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py +++ b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py @@ -12,6 +12,7 @@ class LinearCoef(object): dim = 3 correspond to the third coordinate... """ INTERCEPT_NAME = 'intercept' + COEFF_STR = 'Coeff' def __init__(self, gev_param_name: str, dim_to_coef: Dict[int, float] = None, default_value: float = 0.0): self.gev_param_name = gev_param_name @@ -40,9 +41,11 @@ class LinearCoef(object): """ assert coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATES_NAMES if coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES: - return gev_param_name + 'Coeff{}' + coef_template_str = gev_param_name + cls.COEFF_STR + '{}' else: - return 'tempCoeff' + gev_param_name.title() + '{}' + coef_template_str = 'temp' + cls.COEFF_STR + gev_param_name.title() + '{}' + assert cls.COEFF_STR in coef_template_str + return coef_template_str @staticmethod def has_dependence_in_spatial_coordinates(dim_to_coefficient_name): -- GitLab