Commit 1114b27b authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[quantile regression project] add test_non_stationary_run for simulations

parent 2bc5363c
No related merge requests found
Showing with 35 additions and 9 deletions
+35 -9
......@@ -12,7 +12,7 @@ from extreme_fit.function.margin_function.abstract_margin_function import Abstra
from extreme_fit.function.param_function.linear_coef import LinearCoef
from extreme_fit.function.param_function.param_function import LinearParamFunction
from extreme_fit.model.margin_model.linear_margin_model.abstract_temporal_linear_margin_model import \
AbstractTemporalLinearMarginModel
AbstractTemporalLinearMarginModel, TemporalMarginFitMethod
from extreme_fit.model.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel
from extreme_fit.model.quantile_model.quantile_regression_model import AbstractQuantileRegressionModel
from extreme_fit.model.result_from_model_fit.abstract_result_from_model_fit import AbstractResultFromModelFit
......@@ -23,7 +23,8 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
class QuantileEstimatorFromMargin(LinearMarginEstimator, AbstractQuantileEstimator):
def __init__(self, dataset: AbstractDataset, quantile, margin_model_class: type):
super().__init__(dataset=dataset, quantile=quantile, margin_model=margin_model_class(dataset.coordinates))
margin_model = margin_model_class(dataset.coordinates, fit_method=TemporalMarginFitMethod.extremes_fevd_mle)
super().__init__(dataset=dataset, quantile=quantile, margin_model=margin_model)
@cached_property
def function_from_fit(self) -> AbstractQuantileFunction:
......
......@@ -7,12 +7,12 @@ from extreme_fit.distribution.gev.gev_params import GevParams
class LinearMarginModel(ParametricMarginModel):
@classmethod
def from_coef_list(cls, coordinates, gev_param_name_to_coef_list):
def from_coef_list(cls, coordinates, gev_param_name_to_coef_list, **kwargs):
params = {}
for gev_param_name in GevParams.PARAM_NAMES:
for idx, coef in enumerate(gev_param_name_to_coef_list[gev_param_name], -1):
params[(gev_param_name, idx)] = coef
return cls(coordinates, params_sample=params, params_start_fit=params)
return cls(coordinates, params_sample=params, params_start_fit=params, **kwargs)
def load_margin_functions(self, gev_param_name_to_dims=None):
assert gev_param_name_to_dims is not None, 'LinearMarginModel cannot be used for sampling/fitting \n' \
......
......@@ -7,7 +7,10 @@ from cached_property import cached_property
from extreme_fit.distribution.gev.gev_params import GevParams
from extreme_fit.estimator.quantile_estimator.abstract_quantile_estimator import AbstractQuantileEstimator
from extreme_fit.model.margin_model.abstract_margin_model import AbstractMarginModel
from extreme_fit.model.margin_model.linear_margin_model.temporal_linear_margin_models import StationaryTemporalModel
from extreme_fit.model.margin_model.linear_margin_model.abstract_temporal_linear_margin_model import \
TemporalMarginFitMethod
from extreme_fit.model.margin_model.linear_margin_model.temporal_linear_margin_models import StationaryTemporalModel, \
NonStationaryLocationTemporalModel
from projects.quantile_regression_vs_evt.AbstractSimulation import AbstractSimulation
from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
AbstractSpatioTemporalObservations
......@@ -52,4 +55,17 @@ class StationarySimulation(GevSimulation):
GevParams.SHAPE: [0],
GevParams.SCALE: [1],
}
return StationaryTemporalModel.from_coef_list(coordinates, gev_param_name_to_coef_list)
return StationaryTemporalModel.from_coef_list(coordinates, gev_param_name_to_coef_list,
fit_method=TemporalMarginFitMethod.extremes_fevd_mle)
class NonStationaryLocationSimulation(GevSimulation):
def create_model(self, coordinates):
gev_param_name_to_coef_list = {
GevParams.LOC: [0, 1],
GevParams.SHAPE: [0],
GevParams.SCALE: [1],
}
return NonStationaryLocationTemporalModel.from_coef_list(coordinates, gev_param_name_to_coef_list,
fit_method=TemporalMarginFitMethod.extremes_fevd_mle)
import unittest
from extreme_fit.model.margin_model.linear_margin_model.temporal_linear_margin_models import StationaryTemporalModel
from extreme_fit.model.quantile_model.quantile_regression_model import ConstantQuantileRegressionModel
from projects.quantile_regression_vs_evt.GevSimulation import GevSimulation, StationarySimulation
from extreme_fit.model.margin_model.linear_margin_model.temporal_linear_margin_models import StationaryTemporalModel, \
NonStationaryLocationTemporalModel
from extreme_fit.model.quantile_model.quantile_regression_model import ConstantQuantileRegressionModel, \
TemporalCoordinatesQuantileRegressionModel
from projects.quantile_regression_vs_evt.GevSimulation import GevSimulation, StationarySimulation, \
NonStationaryLocationSimulation
class TestGevSimulations(unittest.TestCase):
......@@ -13,6 +16,12 @@ class TestGevSimulations(unittest.TestCase):
model_classes=[StationaryTemporalModel, ConstantQuantileRegressionModel])
simulation.plot_error_for_last_year_quantile(self.DISPLAY)
def test_non_stationary_run(self):
simulation = NonStationaryLocationSimulation(nb_time_series=1, quantile=0.5, time_series_lengths=[50, 60],
model_classes=[NonStationaryLocationTemporalModel,
TemporalCoordinatesQuantileRegressionModel])
simulation.plot_error_for_last_year_quantile(self.DISPLAY)
if __name__ == '__main__':
unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment