Commit 3914f93e authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] modify margin function

parent 00b74982
No related merge requests found
Showing with 101 additions and 96 deletions
+101 -96
...@@ -91,3 +91,7 @@ class ParametricMarginFunction(IndependentMarginFunction): ...@@ -91,3 +91,7 @@ class ParametricMarginFunction(IndependentMarginFunction):
@property @property
def form_dict(self) -> Dict[str, str]: def form_dict(self) -> Dict[str, str]:
raise NotImplementedError raise NotImplementedError
@property
def coef_dict(self) -> Dict[str, str]:
raise NotImplementedError
class AbstractModel(object): class AbstractModel(object):
def __init__(self, params_user=None): def __init__(self, params_user=None):
self.user_params_sample = params_user self.params_user = params_user
@property @property
def default_params(self): def default_params(self):
...@@ -9,14 +9,14 @@ class AbstractModel(object): ...@@ -9,14 +9,14 @@ class AbstractModel(object):
@property @property
def params_sample(self) -> dict: def params_sample(self) -> dict:
return self.merge_params(default_params=self.default_params, input_params=self.user_params_sample) return self.merge_params(default_params=self.default_params, params_user=self.params_user)
@staticmethod @staticmethod
def merge_params(default_params, input_params): def merge_params(default_params, params_user):
assert default_params is not None, 'some default_params need to be specified' assert default_params is not None, 'some default_params need to be specified'
merged_params = default_params.copy() merged_params = default_params.copy()
if input_params is not None: if params_user is not None:
assert isinstance(default_params, dict) and isinstance(input_params, dict) assert isinstance(default_params, dict) and isinstance(params_user, dict)
assert set(input_params.keys()).issubset(set(default_params.keys())) assert set(params_user.keys()).issubset(set(default_params.keys()))
merged_params.update(input_params) merged_params.update(params_user)
return merged_params return merged_params
...@@ -2,6 +2,7 @@ from abc import ABC ...@@ -2,6 +2,7 @@ from abc import ABC
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from cached_property import cached_property
from extreme_fit.model.abstract_model import AbstractModel from extreme_fit.model.abstract_model import AbstractModel
from extreme_fit.function.margin_function.abstract_margin_function \ from extreme_fit.function.margin_function.abstract_margin_function \
...@@ -18,23 +19,20 @@ class AbstractMarginModel(AbstractModel, ABC): ...@@ -18,23 +19,20 @@ class AbstractMarginModel(AbstractModel, ABC):
-margin_function -margin_function
""" """
def __init__(self, coordinates: AbstractCoordinates, params_user=None, def __init__(self, coordinates: AbstractCoordinates, params_user=None, params_class=GevParams):
params_class=GevParams):
super().__init__(params_user) super().__init__(params_user)
assert isinstance(coordinates, AbstractCoordinates), type(coordinates) assert isinstance(coordinates, AbstractCoordinates), type(coordinates)
self.coordinates = coordinates self.coordinates = coordinates
self.margin_function = None # type: AbstractMarginFunction
self.margin_function = None # type: AbstractMarginFunction
self.params_class = params_class self.params_class = params_class
self.load_margin_functions()
def load_margin_functions(self): @cached_property
raise NotImplementedError def margin_function(self) -> AbstractMarginFunction:
margin_function = self.load_margin_function()
assert margin_function is not None
return margin_function
def default_load_margin_functions(self, margin_function_class): def load_margin_function(self):
# todo: check it i could remove these attributes raise NotImplementedError
self.margin_function = margin_function_class(coordinates=self.coordinates,
default_params=self.params_class.from_dict(self.params_sample))
# Conversion class methods # Conversion class methods
......
...@@ -26,8 +26,7 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel): ...@@ -26,8 +26,7 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel):
params_initial_fit_bayesian=None, params_initial_fit_bayesian=None,
type_for_MLE="GEV", type_for_MLE="GEV",
params_class=GevParams): params_class=GevParams):
super().__init__(coordinates, params_user, starting_point, super().__init__(coordinates, params_user, starting_point, params_class)
params_class)
self.type_for_mle = type_for_MLE self.type_for_mle = type_for_MLE
self.params_initial_fit_bayesian = params_initial_fit_bayesian self.params_initial_fit_bayesian = params_initial_fit_bayesian
self.nb_iterations_for_bayesian_fit = nb_iterations_for_bayesian_fit self.nb_iterations_for_bayesian_fit = nb_iterations_for_bayesian_fit
......
...@@ -14,16 +14,16 @@ class LinearMarginModel(ParametricMarginModel): ...@@ -14,16 +14,16 @@ class LinearMarginModel(ParametricMarginModel):
params[(param_name, idx)] = coef params[(param_name, idx)] = coef
return cls(coordinates, params_user=params, params_class=params_class, **kwargs) return cls(coordinates, params_user=params, params_class=params_class, **kwargs)
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
assert param_name_to_dims is not None, 'LinearMarginModel cannot be used for sampling/fitting \n' \ assert param_name_to_dims is not None, 'LinearMarginModel cannot be used for sampling/fitting \n' \
'load_margin_functions needs to be implemented in child class' 'load_margin_functions needs to be implemented in child class'
# Load sample coef # Load sample coef
coef_sample = self.param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample) coef_sample = self.param_name_to_linear_coef(param_name_and_dim_to_coef=self.params_sample)
self.margin_function = LinearMarginFunction(coordinates=self.coordinates, return LinearMarginFunction(coordinates=self.coordinates,
param_name_to_coef=coef_sample, param_name_to_coef=coef_sample,
param_name_to_dims=param_name_to_dims, param_name_to_dims=param_name_to_dims,
starting_point=self.starting_point, starting_point=self.starting_point,
params_class=self.params_class) params_class=self.params_class)
@property @property
def default_params(self) -> dict: def default_params(self) -> dict:
...@@ -49,75 +49,75 @@ class LinearMarginModel(ParametricMarginModel): ...@@ -49,75 +49,75 @@ class LinearMarginModel(ParametricMarginModel):
class ConstantMarginModel(LinearMarginModel): class ConstantMarginModel(LinearMarginModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({}) return super().load_margin_function({})
class LinearShapeDim0MarginModel(LinearMarginModel): class LinearShapeDim0MarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [0]}) return super().load_margin_function({GevParams.SHAPE: [0]})
class LinearScaleDim0MarginModel(LinearMarginModel): class LinearScaleDim0MarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SCALE: [0]}) return super().load_margin_function({GevParams.SCALE: [0]})
class LinearShapeDim0and1MarginModel(LinearMarginModel): class LinearShapeDim0and1MarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [0, 1]}) return super().load_margin_function({GevParams.SHAPE: [0, 1]})
class LinearAllParametersDim0MarginModel(LinearMarginModel): class LinearAllParametersDim0MarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [0], return super().load_margin_function({GevParams.SHAPE: [0],
GevParams.LOC: [0], GevParams.LOC: [0],
GevParams.SCALE: [0]}) GevParams.SCALE: [0]})
class LinearMarginModelExample(LinearMarginModel): class LinearMarginModelExample(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [0], return super().load_margin_function({GevParams.SHAPE: [0],
GevParams.LOC: [1], GevParams.LOC: [1],
GevParams.SCALE: [0]}) GevParams.SCALE: [0]})
class LinearLocationAllDimsMarginModel(LinearMarginModel): class LinearLocationAllDimsMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.LOC: self.coordinates.coordinates_dims}) return super().load_margin_function({GevParams.LOC: self.coordinates.coordinates_dims})
class LinearShapeAllDimsMarginModel(LinearMarginModel): class LinearShapeAllDimsMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: self.coordinates.coordinates_dims}) return super().load_margin_function({GevParams.SHAPE: self.coordinates.coordinates_dims})
class LinearAllParametersAllDimsMarginModel(LinearMarginModel): class LinearAllParametersAllDimsMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: self.coordinates.coordinates_dims, return super().load_margin_function({GevParams.SHAPE: self.coordinates.coordinates_dims,
GevParams.LOC: self.coordinates.coordinates_dims, GevParams.LOC: self.coordinates.coordinates_dims,
GevParams.SCALE: self.coordinates.coordinates_dims}) GevParams.SCALE: self.coordinates.coordinates_dims})
class LinearStationaryMarginModel(LinearMarginModel): class LinearStationaryMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims, return super().load_margin_function({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
GevParams.LOC: self.coordinates.spatial_coordinates_dims, GevParams.LOC: self.coordinates.spatial_coordinates_dims,
GevParams.SCALE: self.coordinates.spatial_coordinates_dims}) GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
class LinearNonStationaryLocationMarginModel(LinearMarginModel): class LinearNonStationaryLocationMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, param_name_to_dims=None): def load_margin_function(self, margin_function_class: type = None, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims, return super().load_margin_function({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
GevParams.LOC: self.coordinates.coordinates_dims, GevParams.LOC: self.coordinates.coordinates_dims,
GevParams.SCALE: self.coordinates.spatial_coordinates_dims}) GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
...@@ -11,5 +11,5 @@ class NonStationaryRateTemporalModel(AbstractTemporalLinearMarginModel, Abstract ...@@ -11,5 +11,5 @@ class NonStationaryRateTemporalModel(AbstractTemporalLinearMarginModel, Abstract
super().__init__(*arg, **kwargs) super().__init__(*arg, **kwargs)
self.drop_duplicates = False self.drop_duplicates = False
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({ExpParams.RATE: [self.coordinates.idx_temporal_coordinates]}) return super().load_margin_function({ExpParams.RATE: [self.coordinates.idx_temporal_coordinates]})
...@@ -9,14 +9,14 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo ...@@ -9,14 +9,14 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class StationaryTemporalModel(AbstractTemporalLinearMarginModel): class StationaryTemporalModel(AbstractTemporalLinearMarginModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({}) return super().load_margin_function({})
class NonStationaryLocationTemporalModel(AbstractTemporalLinearMarginModel): class NonStationaryLocationTemporalModel(AbstractTemporalLinearMarginModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({GevParams.LOC: [self.coordinates.idx_temporal_coordinates]}) return super().load_margin_function({GevParams.LOC: [self.coordinates.idx_temporal_coordinates]})
@property @property
def mul(self): def mul(self):
...@@ -25,8 +25,8 @@ class NonStationaryLocationTemporalModel(AbstractTemporalLinearMarginModel): ...@@ -25,8 +25,8 @@ class NonStationaryLocationTemporalModel(AbstractTemporalLinearMarginModel):
class NonStationaryScaleTemporalModel(AbstractTemporalLinearMarginModel): class NonStationaryScaleTemporalModel(AbstractTemporalLinearMarginModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]}) return super().load_margin_function({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
@property @property
def sigl(self): def sigl(self):
...@@ -35,8 +35,8 @@ class NonStationaryScaleTemporalModel(AbstractTemporalLinearMarginModel): ...@@ -35,8 +35,8 @@ class NonStationaryScaleTemporalModel(AbstractTemporalLinearMarginModel):
class NonStationaryLogScaleTemporalModel(NonStationaryScaleTemporalModel): class NonStationaryLogScaleTemporalModel(NonStationaryScaleTemporalModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]}) return super().load_margin_function({GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
@property @property
def siglink(self): def siglink(self):
...@@ -45,8 +45,8 @@ class NonStationaryLogScaleTemporalModel(NonStationaryScaleTemporalModel): ...@@ -45,8 +45,8 @@ class NonStationaryLogScaleTemporalModel(NonStationaryScaleTemporalModel):
class NonStationaryShapeTemporalModel(AbstractTemporalLinearMarginModel): class NonStationaryShapeTemporalModel(AbstractTemporalLinearMarginModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [self.coordinates.idx_temporal_coordinates]}) return super().load_margin_function({GevParams.SHAPE: [self.coordinates.idx_temporal_coordinates]})
@property @property
def shl(self): def shl(self):
...@@ -55,9 +55,9 @@ class NonStationaryShapeTemporalModel(AbstractTemporalLinearMarginModel): ...@@ -55,9 +55,9 @@ class NonStationaryShapeTemporalModel(AbstractTemporalLinearMarginModel):
class NonStationaryLocationAndScaleTemporalModel(AbstractTemporalLinearMarginModel): class NonStationaryLocationAndScaleTemporalModel(AbstractTemporalLinearMarginModel):
def load_margin_functions(self, param_name_to_dims=None): def load_margin_function(self, param_name_to_dims=None):
super().load_margin_functions({GevParams.LOC: [self.coordinates.idx_temporal_coordinates], return super().load_margin_function({GevParams.LOC: [self.coordinates.idx_temporal_coordinates],
GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]}) GevParams.SCALE: [self.coordinates.idx_temporal_coordinates]})
@property @property
def mul(self): def mul(self):
......
...@@ -2,8 +2,10 @@ from abc import ABC ...@@ -2,8 +2,10 @@ from abc import ABC
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from cached_property import cached_property
from extreme_fit.distribution.gev.gev_params import GevParams from extreme_fit.distribution.gev.gev_params import GevParams
from extreme_fit.function.margin_function.abstract_margin_function import AbstractMarginFunction
from extreme_fit.function.margin_function.parametric_margin_function import \ from extreme_fit.function.margin_function.parametric_margin_function import \
ParametricMarginFunction ParametricMarginFunction
from extreme_fit.model.margin_model.abstract_margin_model import AbstractMarginModel from extreme_fit.model.margin_model.abstract_margin_model import AbstractMarginModel
...@@ -22,11 +24,16 @@ class ParametricMarginModel(AbstractMarginModel, ABC): ...@@ -22,11 +24,16 @@ class ParametricMarginModel(AbstractMarginModel, ABC):
""" """
:param starting_point: starting coordinate for the temporal trend :param starting_point: starting coordinate for the temporal trend
""" """
super().__init__(coordinates, params_user, params_class)
self.fit_method = fit_method self.fit_method = fit_method
self.starting_point = starting_point self.starting_point = starting_point
self.margin_function = None # type: ParametricMarginFunction
self.drop_duplicates = True self.drop_duplicates = True
super().__init__(coordinates, params_user, params_class)
@cached_property
def margin_function(self) -> ParametricMarginFunction:
margin_function = super().margin_function
assert isinstance(margin_function, ParametricMarginFunction)
return margin_function
def fitmargin_from_maxima_gev(self, data: np.ndarray, df_coordinates_spat: pd.DataFrame, def fitmargin_from_maxima_gev(self, data: np.ndarray, df_coordinates_spat: pd.DataFrame,
df_coordinates_temp: pd.DataFrame) -> ResultFromSpatialExtreme: df_coordinates_temp: pd.DataFrame) -> ResultFromSpatialExtreme:
......
...@@ -15,10 +15,10 @@ class SplineMarginModel(ParametricMarginModel): ...@@ -15,10 +15,10 @@ class SplineMarginModel(ParametricMarginModel):
params_user=None): params_user=None):
super().__init__(coordinates, params_user) super().__init__(coordinates, params_user)
def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None, def load_margin_function(self, param_name_to_dims: Dict[str, List[int]] = None,
param_name_to_coef: Dict[str, AbstractCoef] = None, param_name_to_coef: Dict[str, AbstractCoef] = None,
param_name_to_nb_knots: Dict[str, int] = None, param_name_to_nb_knots: Dict[str, int] = None,
degree=3): degree=3):
# Default parameters # Default parameters
# todo: for the default parameters: take inspiration from the linear_margin_model # todo: for the default parameters: take inspiration from the linear_margin_model
# also implement the class method thing # also implement the class method thing
...@@ -36,27 +36,24 @@ class SplineMarginModel(ParametricMarginModel): ...@@ -36,27 +36,24 @@ class SplineMarginModel(ParametricMarginModel):
if param_name_to_nb_knots is None: if param_name_to_nb_knots is None:
param_name_to_nb_knots = {param_name: 2 for param_name in GevParams.PARAM_NAMES} param_name_to_nb_knots = {param_name: 2 for param_name in GevParams.PARAM_NAMES}
# Load sample coef return SplineMarginFunction(coordinates=self.coordinates,
self.margin_function = SplineMarginFunction(coordinates=self.coordinates, param_name_to_dims=param_name_to_dims,
param_name_to_dims=param_name_to_dims, param_name_to_coef=param_name_to_coef,
param_name_to_coef=param_name_to_coef, param_name_to_nb_knots=param_name_to_nb_knots,
param_name_to_nb_knots=param_name_to_nb_knots, degree=degree)
degree=degree)
class ConstantSplineMarginModel(SplineMarginModel): class ConstantSplineMarginModel(SplineMarginModel):
def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None, def load_margin_function(self, param_name_to_dims: Dict[str, List[int]] = None,
param_name_to_coef: Dict[str, AbstractCoef] = None, param_name_to_coef: Dict[str, AbstractCoef] = None,
param_name_to_nb_knots: Dict[str, int] = None, degree=3): param_name_to_nb_knots: Dict[str, int] = None, degree=3):
super().load_margin_functions({}, param_name_to_coef, param_name_to_nb_knots, return super().load_margin_function({}, param_name_to_coef, param_name_to_nb_knots, degree)
degree)
class Degree1SplineMarginModel(SplineMarginModel): class Degree1SplineMarginModel(SplineMarginModel):
def load_margin_functions(self, param_name_to_dims: Dict[str, List[int]] = None, def load_margin_function(self, param_name_to_dims: Dict[str, List[int]] = None,
param_name_to_coef: Dict[str, AbstractCoef] = None, param_name_to_coef: Dict[str, AbstractCoef] = None,
param_name_to_nb_knots: Dict[str, int] = None, degree=3): param_name_to_nb_knots: Dict[str, int] = None, degree=3):
super().load_margin_functions(param_name_to_dims, param_name_to_coef, param_name_to_nb_knots, return super().load_margin_function(param_name_to_dims, param_name_to_coef, param_name_to_nb_knots, degree=1)
degree=1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment