Commit 9341832a authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[R Model] transform gev2frech to a class method

parent 63ebb0e1
No related merge requests found
Showing with 8 additions and 6 deletions
+8 -6
...@@ -3,9 +3,10 @@ from extreme_estimator.R_model.utils import get_loaded_r ...@@ -3,9 +3,10 @@ from extreme_estimator.R_model.utils import get_loaded_r
class AbstractModel(object): class AbstractModel(object):
r = get_loaded_r()
def __init__(self, params_start_fit=None, params_sample=None): def __init__(self, params_start_fit=None, params_sample=None):
self.default_params_start_fit = None self.default_params_start_fit = None
self.default_params_sample = None self.default_params_sample = None
self.user_params_start_fit = params_start_fit self.user_params_start_fit = params_start_fit
self.user_params_sample = params_sample self.user_params_sample = params_sample
self.r = get_loaded_r() \ No newline at end of file
\ No newline at end of file
...@@ -49,13 +49,14 @@ class AbstractMarginModel(AbstractModel): ...@@ -49,13 +49,14 @@ class AbstractMarginModel(AbstractModel):
maxima_gev.append(x_gev) maxima_gev.append(x_gev)
return np.array(maxima_gev) return np.array(maxima_gev)
def gev2frech(self, maxima_gev: np.ndarray, df_gev_params: pd.DataFrame): @classmethod
def gev2frech(cls, maxima_gev: np.ndarray, df_gev_params: pd.DataFrame):
assert len(maxima_gev) == len(df_gev_params) assert len(maxima_gev) == len(df_gev_params)
maxima_frech = [] maxima_frech = []
for x_gev, (_, s_gev_params) in zip(maxima_gev, df_gev_params.iterrows()): for x_gev, (_, s_gev_params) in zip(maxima_gev, df_gev_params.iterrows()):
gev_params = dict(s_gev_params) gev_params = dict(s_gev_params)
gev2frech_param = {'emp': False} gev2frech_param = {'emp': False}
x_frech = self.r.gev2frech(x_gev, **gev_params, **gev2frech_param) x_frech = cls.r.gev2frech(x_gev, **gev_params, **gev2frech_param)
maxima_frech.append(x_frech) maxima_frech.append(x_frech)
return np.array(maxima_frech) return np.array(maxima_frech)
......
...@@ -23,8 +23,8 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator): ...@@ -23,8 +23,8 @@ class SmoothMarginalsThenUnitaryMsp(AbstractFullEstimator):
# Estimate the margin parameters # Estimate the margin parameters
self.margin_estimator.fit() self.margin_estimator.fit()
# Compute the maxima_frech # Compute the maxima_frech
maxima_frech = self.margin_estimator.margin_model.gev2frech(maxima_gev=self.dataset.maxima_gev, maxima_frech = AbstractMarginModel.gev2frech(maxima_gev=self.dataset.maxima_gev,
df_gev_params=self.margin_estimator.df_gev_params) df_gev_params=self.margin_estimator.df_gev_params)
# Update maxima frech field through the dataset object # Update maxima frech field through the dataset object
self.dataset.maxima_frech = maxima_frech self.dataset.maxima_frech = maxima_frech
# Estimate the max stable parameters # Estimate the max stable parameters
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment