Commit 927b62f6 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[quantile regression project] refactor simulations. add transformed coordinate...

[quantile regression project] refactor simulations. add transformed coordinate to the quantile functions
parent ec20ef1a
No related merge requests found
Showing with 27 additions and 24 deletions
+27 -24
import numpy as np
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
...@@ -6,3 +8,5 @@ class AbstractFunction(object): ...@@ -6,3 +8,5 @@ class AbstractFunction(object):
def __init__(self, coordinates: AbstractCoordinates): def __init__(self, coordinates: AbstractCoordinates):
self.coordinates = coordinates self.coordinates = coordinates
def transform(self, coordinate: np.ndarray) -> np.ndarray:
return self.coordinates.transformation.transform_array(coordinate)
...@@ -10,7 +10,11 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo ...@@ -10,7 +10,11 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
class AbstractQuantileFunction(AbstractFunction): class AbstractQuantileFunction(AbstractFunction):
def get_quantile(self, coordinate: np.ndarray) -> float: def get_quantile(self, coordinate: np.ndarray, is_transformed: bool = True) -> float:
transformed_coordinate = coordinate if is_transformed else self.transform(coordinate)
return self._get_quantile(transformed_coordinate)
def _get_quantile(self, coordinate: np.ndarray):
raise NotImplementedError raise NotImplementedError
def visualize(self, show=True): def visualize(self, show=True):
...@@ -43,7 +47,7 @@ class QuantileFunctionFromParamFunction(AbstractQuantileFunction): ...@@ -43,7 +47,7 @@ class QuantileFunctionFromParamFunction(AbstractQuantileFunction):
super().__init__(coordinates) super().__init__(coordinates)
self.param_function = param_function self.param_function = param_function
def get_quantile(self, coordinate: np.ndarray) -> float: def _get_quantile(self, coordinate: np.ndarray) -> float:
return self.param_function.get_param_value(coordinate) return self.param_function.get_param_value(coordinate)
...@@ -54,6 +58,6 @@ class QuantileFunctionFromMarginFunction(AbstractQuantileFunction): ...@@ -54,6 +58,6 @@ class QuantileFunctionFromMarginFunction(AbstractQuantileFunction):
self.margin_function = margin_function self.margin_function = margin_function
self.quantile = quantile self.quantile = quantile
def get_quantile(self, coordinate: np.ndarray) -> float: def _get_quantile(self, coordinate: np.ndarray) -> float:
gev_params = self.margin_function.get_gev_params(coordinate) gev_params = self.margin_function.get_gev_params(coordinate)
return gev_params.quantile(self.quantile) return gev_params.quantile(self.quantile)
...@@ -32,8 +32,7 @@ class IndependentMarginFunction(AbstractMarginFunction): ...@@ -32,8 +32,7 @@ class IndependentMarginFunction(AbstractMarginFunction):
gev_params[gev_param_name] = param_function.get_param_value(transformed_coordinate) gev_params[gev_param_name] = param_function.get_param_value(transformed_coordinate)
return GevParams.from_dict(gev_params) return GevParams.from_dict(gev_params)
def transform(self, coordinate: np.ndarray) -> np.ndarray:
return self.coordinates.transformation.transform_array(coordinate)
...@@ -16,16 +16,12 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo ...@@ -16,16 +16,12 @@ from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoo
from spatio_temporal_dataset.coordinates.temporal_coordinates.generated_temporal_coordinates import \ from spatio_temporal_dataset.coordinates.temporal_coordinates.generated_temporal_coordinates import \
ConsecutiveTemporalCoordinates ConsecutiveTemporalCoordinates
from spatio_temporal_dataset.coordinates.transformed_coordinates.transformation.abstract_transformation import \ from spatio_temporal_dataset.coordinates.transformed_coordinates.transformation.abstract_transformation import \
CenteredScaledNormalization, IdentityTransformation CenteredScaledNormalization
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \ from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
AbstractSpatioTemporalObservations AbstractSpatioTemporalObservations
class Coordinates(object):
pass
class AbstractSimulation(object): class AbstractSimulation(object):
def __init__(self, nb_time_series, quantile, time_series_lengths=None, multiprocessing=False, def __init__(self, nb_time_series, quantile, time_series_lengths=None, multiprocessing=False,
...@@ -41,31 +37,31 @@ class AbstractSimulation(object): ...@@ -41,31 +37,31 @@ class AbstractSimulation(object):
raise NotImplementedError raise NotImplementedError
@cached_property @cached_property
def time_serie_length_to_observation_list(self) -> Dict[int, List[AbstractSpatioTemporalObservations]]: def time_series_length_to_observation_list(self) -> Dict[int, List[AbstractSpatioTemporalObservations]]:
d = OrderedDict() d = OrderedDict()
for length in self.time_series_lengths: for length in self.time_series_lengths:
d[length] = self.generate_all_observation(self.nb_time_series, length) d[length] = self.generate_all_observation(self.nb_time_series, length)
return d return d
@cached_property @cached_property
def time_serie_length_to_coordinates(self) -> Dict[int, AbstractCoordinates]: def time_series_length_to_coordinates(self) -> Dict[int, AbstractCoordinates]:
d = OrderedDict() d = OrderedDict()
for length in self.time_series_lengths: for length in self.time_series_lengths:
d[length] = ConsecutiveTemporalCoordinates.from_nb_temporal_steps(length, d[length] = ConsecutiveTemporalCoordinates.\
transformation_class=self.transformation_class) from_nb_temporal_steps(length, transformation_class=self.transformation_class)
return d return d
@cached_property @cached_property
def model_class_to_time_serie_length_to_estimator_fitted(self): def model_class_to_time_series_length_to_estimators(self):
d = OrderedDict() d = OrderedDict()
for model_class in self.models_classes: for model_class in self.models_classes:
d_sub = OrderedDict() d_sub = OrderedDict()
for time_serie_length, observation_list in self.time_serie_length_to_observation_list.items(): for time_series_length, observation_list in self.time_series_length_to_observation_list.items():
coordinates = self.time_serie_length_to_coordinates[time_serie_length] coordinates = self.time_series_length_to_coordinates[time_series_length]
estimators_fitted = [] estimators = []
for observations in observation_list: for observations in observation_list:
estimators_fitted.append(self.get_fitted_quantile_estimator(model_class, observations, coordinates)) estimators.append(self.get_fitted_quantile_estimator(model_class, observations, coordinates))
d_sub[time_serie_length] = estimators_fitted d_sub[time_series_length] = estimators
d[model_class] = d_sub d[model_class] = d_sub
return d return d
...@@ -83,7 +79,7 @@ class AbstractSimulation(object): ...@@ -83,7 +79,7 @@ class AbstractSimulation(object):
@cached_property @cached_property
def model_class_to_error_last_year_quantile(self): def model_class_to_error_last_year_quantile(self):
d = OrderedDict() d = OrderedDict()
for model_class, d_sub in self.model_class_to_time_serie_length_to_estimator_fitted.items(): for model_class, d_sub in self.model_class_to_time_series_length_to_estimators.items():
length_to_error_values = OrderedDict() length_to_error_values = OrderedDict()
for length, estimators_fitted in d_sub.items(): for length, estimators_fitted in d_sub.items():
errors = self.compute_errors(length, estimators_fitted) errors = self.compute_errors(length, estimators_fitted)
......
...@@ -23,7 +23,7 @@ class GevSimulation(AbstractSimulation): ...@@ -23,7 +23,7 @@ class GevSimulation(AbstractSimulation):
def time_series_lengths_to_margin_model(self) -> Dict[int, AbstractMarginModel]: def time_series_lengths_to_margin_model(self) -> Dict[int, AbstractMarginModel]:
d = OrderedDict() d = OrderedDict()
for length in self.time_series_lengths: for length in self.time_series_lengths:
coordinates = self.time_serie_length_to_coordinates[length] coordinates = self.time_series_length_to_coordinates[length]
d[length] = self.create_model(coordinates) d[length] = self.create_model(coordinates)
return d return d
...@@ -31,13 +31,13 @@ class GevSimulation(AbstractSimulation): ...@@ -31,13 +31,13 @@ class GevSimulation(AbstractSimulation):
raise NotImplementedError raise NotImplementedError
def generate_all_observation(self, nb_time_series, length) -> List[AbstractSpatioTemporalObservations]: def generate_all_observation(self, nb_time_series, length) -> List[AbstractSpatioTemporalObservations]:
coordinates = self.time_serie_length_to_coordinates[length] coordinates = self.time_series_length_to_coordinates[length]
margin_model = self.time_series_lengths_to_margin_model[length] margin_model = self.time_series_lengths_to_margin_model[length]
return [MarginAnnualMaxima.from_sampling(nb_obs=1, coordinates=coordinates, margin_model=margin_model) return [MarginAnnualMaxima.from_sampling(nb_obs=1, coordinates=coordinates, margin_model=margin_model)
for _ in range(nb_time_series)] for _ in range(nb_time_series)]
def compute_errors(self, length: int, estimators: List[AbstractQuantileEstimator]): def compute_errors(self, length: int, estimators: List[AbstractQuantileEstimator]):
coordinates = self.time_serie_length_to_coordinates[length] coordinates = self.time_series_length_to_coordinates[length]
last_coordinate = coordinates.coordinates_values()[-1] last_coordinate = coordinates.coordinates_values()[-1]
# Compute true value # Compute true value
margin_model = self.time_series_lengths_to_margin_model[length] margin_model = self.time_series_lengths_to_margin_model[length]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment