diff --git a/extreme_estimator/R_model/abstract_model.py b/extreme_estimator/R_model/abstract_model.py index 00673d987a285040720d9d635cf7576c8ea30fc6..304ca70880064a47f74d24220c3e1e407bd240e1 100644 --- a/extreme_estimator/R_model/abstract_model.py +++ b/extreme_estimator/R_model/abstract_model.py @@ -3,9 +3,10 @@ from extreme_estimator.R_model.utils import get_loaded_r class AbstractModel(object): + r = get_loaded_r() + def __init__(self, params_start_fit=None, params_sample=None): self.default_params_start_fit = None self.default_params_sample = None self.user_params_start_fit = params_start_fit - self.user_params_sample = params_sample - self.r = get_loaded_r() \ No newline at end of file + self.user_params_sample = params_sample \ No newline at end of file diff --git a/extreme_estimator/R_model/margin_model/abstract_margin_model.py b/extreme_estimator/R_model/margin_model/abstract_margin_model.py index 644ef37d79c631c265a01b7227ac9d6fd801cc65..25e59524214fe7f7c7a1b284f19e573cc95a29ce 100644 --- a/extreme_estimator/R_model/margin_model/abstract_margin_model.py +++ b/extreme_estimator/R_model/margin_model/abstract_margin_model.py @@ -49,13 +49,14 @@ class AbstractMarginModel(AbstractModel): maxima_gev.append(x_gev) return np.array(maxima_gev) - def gev2frech(self, maxima_gev: np.ndarray, df_gev_params: pd.DataFrame): + @classmethod + def gev2frech(cls, maxima_gev: np.ndarray, df_gev_params: pd.DataFrame): assert len(maxima_gev) == len(df_gev_params) maxima_frech = [] for x_gev, (_, s_gev_params) in zip(maxima_gev, df_gev_params.iterrows()): gev_params = dict(s_gev_params) gev2frech_param = {'emp': False} - x_frech = self.r.gev2frech(x_gev, **gev_params, **gev2frech_param) + x_frech = cls.r.gev2frech(x_gev, **gev_params, **gev2frech_param) maxima_frech.append(x_frech) return np.array(maxima_frech) diff --git a/extreme_estimator/estimator/full_estimator.py b/extreme_estimator/estimator/full_estimator.py index cec7a3cee020736507f95dc2739acdf28e6f8081..3d31d97f58e8b9aa98d627790ae40b772be5a74e 100644 --- a/extreme_estimator/estimator/full_estimator.py +++ b/extreme_estimator/estimator/full_estimator.py @@ -23,8 +23,8 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator): # Estimate the margin parameters self.margin_estimator.fit() # Compute the maxima_frech - maxima_frech = self.margin_estimator.margin_model.gev2frech(maxima_gev=self.dataset.maxima_gev, - df_gev_params=self.margin_estimator.df_gev_params) + maxima_frech = AbstractMarginModel.gev2frech(maxima_gev=self.dataset.maxima_gev, + df_gev_params=self.margin_estimator.df_gev_params) # Update maxima frech field through the dataset object self.dataset.maxima_frech = maxima_frech # Estimate the max stable parameters