From a277fc27698bdb09cf824de86dfea819d2b6cea1 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 25 Feb 2019 17:39:18 +0100
Subject: [PATCH] [MARGIN MODEL] refactor by using class variables to avoid
 future potential issues

---
 extreme_estimator/estimator/abstract_estimator.py          | 3 ++-
 .../margin_model/param_function/linear_coef.py             | 7 +++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/extreme_estimator/estimator/abstract_estimator.py b/extreme_estimator/estimator/abstract_estimator.py
index 71c687b7..3cd92206 100644
--- a/extreme_estimator/estimator/abstract_estimator.py
+++ b/extreme_estimator/estimator/abstract_estimator.py
@@ -3,6 +3,7 @@ import time
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
+from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
@@ -49,7 +50,7 @@ class AbstractEstimator(object):
         return self._margin_function_fitted
 
     def extract_fitted_models_from_fitted_params(self, margin_function_to_fit, full_params_fitted):
-        coef_dict = {k: v for k, v in full_params_fitted.items() if 'Coeff' in k}
+        coef_dict = {k: v for k, v in full_params_fitted.items() if LinearCoef.COEFF_STR in k}
         self._margin_function_fitted = LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates,
                                                                            gev_param_name_to_linear_dims=margin_function_to_fit.gev_param_name_to_linear_dims,
                                                                            coef_dict=coef_dict)
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
index 8954e5e0..1e67c94c 100644
--- a/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
+++ b/extreme_estimator/extreme_models/margin_model/param_function/linear_coef.py
@@ -12,6 +12,7 @@ class LinearCoef(object):
         dim = 3 correspond to the third coordinate...
     """
     INTERCEPT_NAME = 'intercept'
+    COEFF_STR = 'Coeff'
 
     def __init__(self, gev_param_name: str, dim_to_coef: Dict[int, float] = None, default_value: float = 0.0):
         self.gev_param_name = gev_param_name
@@ -40,9 +41,11 @@ class LinearCoef(object):
         """
         assert coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATES_NAMES
         if coefficient_name == cls.INTERCEPT_NAME or coefficient_name in AbstractCoordinates.COORDINATE_SPATIAL_NAMES:
-            return gev_param_name + 'Coeff{}'
+            coef_template_str = gev_param_name + cls.COEFF_STR + '{}'
         else:
-            return 'tempCoeff' + gev_param_name.title() + '{}'
+            coef_template_str = 'temp' + cls.COEFF_STR + gev_param_name.title() + '{}'
+        assert cls.COEFF_STR in coef_template_str
+        return coef_template_str
 
     @staticmethod
     def has_dependence_in_spatial_coordinates(dim_to_coefficient_name):
-- 
GitLab