Commit 8d081557 authored by Le Roux Erwan
Browse files

[EXTREME FIT] After any Spatial Extreme fit, create an object ResultFromFit,...

[EXTREME FIT] After any Spatial Extreme fit, create an object ResultFromFit, refactor code accordingly
parent b9b99df8
No related merge requests found
Showing with 60 additions and 36 deletions
+60 -36
import time import time
from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \ from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
AbstractMarginFunction AbstractMarginFunction
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
...@@ -18,7 +19,7 @@ class AbstractEstimator(object): ...@@ -18,7 +19,7 @@ class AbstractEstimator(object):
def __init__(self, dataset: AbstractDataset): def __init__(self, dataset: AbstractDataset):
self.dataset = dataset # type: AbstractDataset self.dataset = dataset # type: AbstractDataset
self.additional_information = dict() self.additional_information = dict()
self._params_fitted = None self._result_from_fit = None # type: ResultFromFit
self._margin_function_fitted = None self._margin_function_fitted = None
self._max_stable_model_fitted = None self._max_stable_model_fitted = None
...@@ -34,9 +35,9 @@ class AbstractEstimator(object): ...@@ -34,9 +35,9 @@ class AbstractEstimator(object):
self.additional_information[self.DURATION] = int((te - ts) * 1000) self.additional_information[self.DURATION] = int((te - ts) * 1000)
@property @property
def params_fitted(self): def fitted_values(self):
assert self.is_fitted assert self.is_fitted
return self._params_fitted return self._result_from_fit.fitted_values
# @property # @property
# def max_stable_fitted(self) -> AbstractMarginFunction: # def max_stable_fitted(self) -> AbstractMarginFunction:
...@@ -57,7 +58,7 @@ class AbstractEstimator(object): ...@@ -57,7 +58,7 @@ class AbstractEstimator(object):
@property @property
def is_fitted(self): def is_fitted(self):
return self._params_fitted is not None return self._result_from_fit is not None
@property @property
def train_split(self): def train_split(self):
......
...@@ -55,7 +55,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator): ...@@ -55,7 +55,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
def _fit(self): def _fit(self):
# Estimate both the margin and the max-stable structure # Estimate both the margin and the max-stable structure
self._params_fitted = self.max_stable_model.fitmaxstab( self._result_from_fit = self.max_stable_model.fitmaxstab(
maxima_gev=self.dataset.maxima_gev(split=self.train_split), maxima_gev=self.dataset.maxima_gev(split=self.train_split),
df_coordinates=self.dataset.df_coordinates(split=self.train_split), df_coordinates=self.dataset.df_coordinates(split=self.train_split),
fit_marge=True, fit_marge=True,
...@@ -63,7 +63,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator): ...@@ -63,7 +63,7 @@ class FullEstimatorInASingleStepWithSmoothMargin(AbstractFullEstimator):
margin_start_dict=self.linear_margin_function_to_fit.coef_dict margin_start_dict=self.linear_margin_function_to_fit.coef_dict
) )
# Create the fitted margin function # Create the fitted margin function
self.extract_fitted_models_from_fitted_params(self.linear_margin_function_to_fit, self._params_fitted) self.extract_fitted_models_from_fitted_params(self.linear_margin_function_to_fit, self.fitted_values)
class PointwiseAndThenUnitaryMsp(AbstractFullEstimator): class PointwiseAndThenUnitaryMsp(AbstractFullEstimator):
......
...@@ -37,8 +37,8 @@ class SmoothMarginEstimator(AbstractMarginEstimator): ...@@ -37,8 +37,8 @@ class SmoothMarginEstimator(AbstractMarginEstimator):
maxima_gev = self.dataset.maxima_gev(split=self.train_split) maxima_gev = self.dataset.maxima_gev(split=self.train_split)
df_coordinates_spatial = self.dataset.coordinates.df_spatial_coordinates(self.train_split) df_coordinates_spatial = self.dataset.coordinates.df_spatial_coordinates(self.train_split)
df_coordinates_temporal = self.dataset.coordinates.df_temporal_coordinates(self.train_split) df_coordinates_temporal = self.dataset.coordinates.df_temporal_coordinates(self.train_split)
self._params_fitted = self.margin_model.fitmargin_from_maxima_gev(maxima_gev=maxima_gev, self._result_from_fit = self.margin_model.fitmargin_from_maxima_gev(maxima_gev=maxima_gev,
df_coordinates_spatial=df_coordinates_spatial, df_coordinates_spatial=df_coordinates_spatial,
df_coordinates_temporal=df_coordinates_temporal) df_coordinates_temporal=df_coordinates_temporal)
self.extract_fitted_models_from_fitted_params(self.margin_model.margin_function_start_fit, self._params_fitted) self.extract_fitted_models_from_fitted_params(self.margin_model.margin_function_start_fit, self.fitted_values)
assert isinstance(self.margin_function_fitted, AbstractMarginFunction) assert isinstance(self.margin_function_fitted, AbstractMarginFunction)
...@@ -18,9 +18,10 @@ class MaxStableEstimator(AbstractMaxStableEstimator): ...@@ -18,9 +18,10 @@ class MaxStableEstimator(AbstractMaxStableEstimator):
def _fit(self): def _fit(self):
assert self.dataset.maxima_frech(split=self.train_split) is not None assert self.dataset.maxima_frech(split=self.train_split) is not None
self.max_stable_params_fitted = self.max_stable_model.fitmaxstab( self._result_from_fit = self.max_stable_model.fitmaxstab(
maxima_frech=self.dataset.maxima_frech(split=self.train_split), maxima_frech=self.dataset.maxima_frech(split=self.train_split),
df_coordinates=self.dataset.df_coordinates(split=self.train_split)) df_coordinates=self.dataset.df_coordinates(split=self.train_split))
self.max_stable_params_fitted = self.fitted_values
def _error(self, true_max_stable_params: dict): def _error(self, true_max_stable_params: dict):
absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name]) absolute_errors = {param_name: np.abs(param_true_value - self.max_stable_params_fitted[param_name])
......
from typing import Dict
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
from extreme_estimator.extreme_models.utils import safe_run_r_estimator, r, retrieve_fitted_values, get_coord, \ from extreme_estimator.extreme_models.utils import safe_run_r_estimator, r, get_coord, \
get_margin_formula get_margin_formula
from extreme_estimator.margin_fits.gev.gev_params import GevParams from extreme_estimator.margin_fits.gev.gev_params import GevParams
...@@ -59,7 +58,7 @@ class LinearMarginModel(AbstractMarginModel): ...@@ -59,7 +58,7 @@ class LinearMarginModel(AbstractMarginModel):
return cls(coordinates, params_sample=params, params_start_fit=params) return cls(coordinates, params_sample=params, params_start_fit=params)
def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, df_coordinates_spatial: pd.DataFrame, def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, df_coordinates_spatial: pd.DataFrame,
df_coordinates_temporal: pd.DataFrame) -> Dict[str, float]: df_coordinates_temporal: pd.DataFrame) -> ResultFromFit:
# The reshaping on the line below is only valid if we have a single observation per spatio-temporal point # The reshaping on the line below is only valid if we have a single observation per spatio-temporal point
if maxima_gev.shape[1] == 1: if maxima_gev.shape[1] == 1:
maxima_gev = maxima_gev.reshape([len(df_coordinates_temporal), len(df_coordinates_spatial)]) maxima_gev = maxima_gev.reshape([len(df_coordinates_temporal), len(df_coordinates_spatial)])
...@@ -75,9 +74,8 @@ class LinearMarginModel(AbstractMarginModel): ...@@ -75,9 +74,8 @@ class LinearMarginModel(AbstractMarginModel):
coef_dict = self.margin_function_start_fit.coef_dict coef_dict = self.margin_function_start_fit.coef_dict
fit_params['start'] = r.list(**coef_dict) fit_params['start'] = r.list(**coef_dict)
res = safe_run_r_estimator(function=r.fitspatgev, use_start=self.use_start_value, data=data, return safe_run_r_estimator(function=r.fitspatgev, use_start=self.use_start_value, data=data,
covariables=covariables, **fit_params) covariables=covariables, **fit_params)
return retrieve_fitted_values(res)
class ConstantMarginModel(LinearMarginModel): class ConstantMarginModel(LinearMarginModel):
......
...@@ -2,10 +2,10 @@ from enum import Enum ...@@ -2,10 +2,10 @@ from enum import Enum
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import rpy2.robjects as robjects
from extreme_estimator.extreme_models.abstract_model import AbstractModel from extreme_estimator.extreme_models.abstract_model import AbstractModel
from extreme_estimator.extreme_models.utils import r, safe_run_r_estimator, retrieve_fitted_values, get_coord, \ from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
from extreme_estimator.extreme_models.utils import r, safe_run_r_estimator, get_coord, \
get_margin_formula get_margin_formula
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
...@@ -21,8 +21,7 @@ class AbstractMaxStableModel(AbstractModel): ...@@ -21,8 +21,7 @@ class AbstractMaxStableModel(AbstractModel):
return {'cov.mod': self.cov_mod} return {'cov.mod': self.cov_mod}
def fitmaxstab(self, df_coordinates: pd.DataFrame, maxima_frech: np.ndarray = None, maxima_gev: np.ndarray = None, def fitmaxstab(self, df_coordinates: pd.DataFrame, maxima_frech: np.ndarray = None, maxima_gev: np.ndarray = None,
fit_marge=False, fit_marge=False, fit_marge_form_dict=None, margin_start_dict=None) -> ResultFromFit:
fit_marge_form_dict=None, margin_start_dict=None) -> dict:
assert isinstance(df_coordinates, pd.DataFrame) assert isinstance(df_coordinates, pd.DataFrame)
if fit_marge: if fit_marge:
assert fit_marge_form_dict is not None assert fit_marge_form_dict is not None
...@@ -63,8 +62,7 @@ class AbstractMaxStableModel(AbstractModel): ...@@ -63,8 +62,7 @@ class AbstractMaxStableModel(AbstractModel):
fit_params['fit.marge'] = fit_marge fit_params['fit.marge'] = fit_marge
# Run the fitmaxstab in R # Run the fitmaxstab in R
res = safe_run_r_estimator(function=r.fitmaxstab, use_start=self.use_start_value, data=data, coord=coord, **fit_params) return safe_run_r_estimator(function=r.fitmaxstab, use_start=self.use_start_value, data=data, coord=coord, **fit_params)
return retrieve_fitted_values(res)
def rmaxstab(self, nb_obs: int, coordinates_values: np.ndarray) -> np.ndarray: def rmaxstab(self, nb_obs: int, coordinates_values: np.ndarray) -> np.ndarray:
""" """
......
from typing import Dict
from rpy2 import robjects
class ResultFromFit(object):
    """
    Wrapper around the object returned by a fit function of the R package SpatialExtremes.

    At construction time every named element of the R result is extracted with
    ``rx2`` and stored in ``name_to_value``, a plain dict keyed by element name.
    The properties below read from that dict.
    """
    FITTED_VALUES_NAME = 'fitted.values'
    CONVERGENCE_NAME = 'convergence'

    def __init__(self, result_from_fit: robjects.ListVector) -> None:
        """
        Index the elements of ``result_from_fit`` by name.

        :param result_from_fit: object returned by the R fit function. An object
            that has no ``names`` attribute yields an empty mapping.
        """
        element_by_name = {}
        if hasattr(result_from_fit, 'names'):
            for element_name in result_from_fit.names:
                element_by_name[element_name] = result_from_fit.rx2(element_name)
        self.name_to_value = element_by_name

    @property
    def names(self):
        """Names of the elements available in the wrapped result (dict keys view)."""
        return self.name_to_value.keys()

    @property
    def convergence(self):
        """
        The ``'convergence'`` element, returned exactly as extracted from the R result.

        :raises KeyError: if the wrapped result has no ``'convergence'`` element.
        """
        return self.name_to_value[self.CONVERGENCE_NAME]

    @property
    def fitted_values(self) -> Dict[str, float]:
        """
        Map each fitted parameter name to its value.

        Only the first entry of each named element of ``'fitted.values'`` is kept.

        :raises KeyError: if the wrapped result has no ``'fitted.values'`` element.
        """
        r_fitted_values = self.name_to_value[self.FITTED_VALUES_NAME]
        first_value_by_name = {}
        for param_name in r_fitted_values.names:
            first_value_by_name[param_name] = r_fitted_values.rx2(param_name)[0]
        return first_value_by_name
...@@ -16,6 +16,8 @@ from rpy2.rinterface._rinterface import RRuntimeError ...@@ -16,6 +16,8 @@ from rpy2.rinterface._rinterface import RRuntimeError
from rpy2.robjects import numpy2ri from rpy2.robjects import numpy2ri
from rpy2.robjects import pandas2ri from rpy2.robjects import pandas2ri
from extreme_estimator.extreme_models.result_from_fit import ResultFromFit
r = ro.R() r = ro.R()
numpy2ri.activate() numpy2ri.activate()
pandas2ri.activate() pandas2ri.activate()
...@@ -41,7 +43,7 @@ class WarningMaximumAbsoluteValueTooHigh(Warning): ...@@ -41,7 +43,7 @@ class WarningMaximumAbsoluteValueTooHigh(Warning):
pass pass
def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_value=100, **parameters): def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_value=100, **parameters) -> ResultFromFit:
# Raise warning if the maximum absolute value is above a threshold # Raise warning if the maximum absolute value is above a threshold
assert isinstance(data, np.ndarray) assert isinstance(data, np.ndarray)
maximum_absolute_value = np.max(np.abs(data)) maximum_absolute_value = np.max(np.abs(data))
...@@ -54,6 +56,7 @@ def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_valu ...@@ -54,6 +56,7 @@ def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_valu
# First run without using start value # First run without using start value
# Then if it crashes, use start value # Then if it crashes, use start value
run_successful = False run_successful = False
res = None
while not run_successful: while not run_successful:
current_parameter = parameters.copy() current_parameter = parameters.copy()
if not use_start and 'start' in current_parameter: if not use_start and 'start' in current_parameter:
...@@ -70,15 +73,7 @@ def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_valu ...@@ -70,15 +73,7 @@ def safe_run_r_estimator(function, data, use_start=False, threshold_max_abs_valu
if isinstance(e, RRuntimeWarning): if isinstance(e, RRuntimeWarning):
print(e.__repr__()) print(e.__repr__())
print('WARNING') print('WARNING')
return res return ResultFromFit(res)
def retrieve_fitted_values(res: robjects.ListVector) -> Dict[str, float]:
# todo: maybe if the convergence was not successful I could try other starting point several times
# Retrieve the resulting fitted values
fitted_values = res.rx2('fitted.values')
fitted_values = {key: fitted_values.rx2(key)[0] for key in fitted_values.names}
return fitted_values
def get_coord(df_coordinates: pd.DataFrame): def get_coord(df_coordinates: pd.DataFrame):
......
...@@ -31,7 +31,7 @@ class TestMaxStableFitWithConstantMargin(TestUnitaryAbstract): ...@@ -31,7 +31,7 @@ class TestMaxStableFitWithConstantMargin(TestUnitaryAbstract):
full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model, full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model,
max_stable_model) max_stable_model)
full_estimator.fit() full_estimator.fit()
return full_estimator.params_fitted return full_estimator.fitted_values
def test_max_stable_fit_with_constant_margin(self): def test_max_stable_fit_with_constant_margin(self):
self.compare() self.compare()
...@@ -59,7 +59,7 @@ class TestMaxStableFitWithLinearMargin(TestUnitaryAbstract): ...@@ -59,7 +59,7 @@ class TestMaxStableFitWithLinearMargin(TestUnitaryAbstract):
full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model, full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model,
max_stable_model) max_stable_model)
full_estimator.fit() full_estimator.fit()
return full_estimator.params_fitted return full_estimator.fitted_values
def test_max_stable_fit_with_linear_margin(self): def test_max_stable_fit_with_linear_margin(self):
self.compare() self.compare()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment