diff --git a/extreme_estimator/extreme_models/margin_model/linear_margin_model.py b/extreme_estimator/extreme_models/margin_model/linear_margin_model.py index 27cc3c6fa9314744415b91b40712b8dcc20cfbfa..9408fb72316c5365560ad874ad16e2190fe89928 100644 --- a/extreme_estimator/extreme_models/margin_model/linear_margin_model.py +++ b/extreme_estimator/extreme_models/margin_model/linear_margin_model.py @@ -61,25 +61,25 @@ class ConstantMarginModel(LinearMarginModel): super().load_margin_functions({}) -class LinearShapeDim1MarginModel(LinearMarginModel): +class LinearShapeDim0MarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): super().load_margin_functions({GevParams.SHAPE: [0]}) -class LinearScaleDim1MarginModel(LinearMarginModel): +class LinearScaleDim0MarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): super().load_margin_functions({GevParams.SCALE: [0]}) -class LinearShapeDim1and2MarginModel(LinearMarginModel): +class LinearShapeDim0and1MarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): super().load_margin_functions({GevParams.SHAPE: [0, 1]}) -class LinearAllParametersDim1MarginModel(LinearMarginModel): +class LinearAllParametersDim0MarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): super().load_margin_functions({GevParams.SHAPE: [0], diff --git a/extreme_estimator/extreme_models/margin_model/spline_margin_model.py b/extreme_estimator/extreme_models/margin_model/spline_margin_model.py index d0f3b3f001e4a9ab98b44a100c310544bb594176..e7777b6ff14dcee1d67fd84a85b763efd7c2ecae 100644 --- a/extreme_estimator/extreme_models/margin_model/spline_margin_model.py +++ b/extreme_estimator/extreme_models/margin_model/spline_margin_model.py @@ -1,21 +1,11 @@ -import numpy as np from typing import Dict, List -import numpy as np - -from extreme_estimator.extreme_models.margin_model.margin_function.parametric_margin_function import \ - ParametricMarginFunction +from extreme_estimator.extreme_models.margin_model.margin_function.spline_margin_function import SplineMarginFunction from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef import AbstractCoef -from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction, \ - SplineParamFunction from extreme_estimator.extreme_models.margin_model.param_function.spline_coef import SplineCoef, KnotCoef, \ PolynomialCoef -from extreme_estimator.margin_fits.gev.gev_params import GevParams -from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates - -from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel -from extreme_estimator.extreme_models.margin_model.margin_function.spline_margin_function import SplineMarginFunction from extreme_estimator.extreme_models.margin_model.parametric_margin_model import ParametricMarginModel +from extreme_estimator.margin_fits.gev.gev_params import GevParams from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates @@ -30,6 +20,8 @@ class SplineMarginModel(ParametricMarginModel): gev_param_name_to_nb_knots: Dict[str, int] = None, degree=3): # Default parameters + # todo: for the default parameters: take inspiration from the linear_margin_model + # also implement the class method thing if gev_param_name_to_dims is None: gev_param_name_to_dims = {gev_param_name: self.coordinates.coordinates_dims for gev_param_name in GevParams.PARAM_NAMES}