Commit b37591cb authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[quantile regression project] final modification. attempt to add...

[quantile regression project]  final modification. attempt to add multiprocessing. but results are not conclusive
parent 6c753d0d
No related merge requests found
Showing with 23 additions and 8 deletions
+23 -8
from multiprocessing.dummy import Pool
from typing import Dict, List from typing import Dict, List
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from collections import OrderedDict from collections import OrderedDict
...@@ -11,7 +12,7 @@ from extreme_fit.estimator.quantile_estimator.quantile_estimator_from_regression ...@@ -11,7 +12,7 @@ from extreme_fit.estimator.quantile_estimator.quantile_estimator_from_regression
from extreme_fit.model.margin_model.linear_margin_model.abstract_temporal_linear_margin_model import \ from extreme_fit.model.margin_model.linear_margin_model.abstract_temporal_linear_margin_model import \
AbstractTemporalLinearMarginModel AbstractTemporalLinearMarginModel
from extreme_fit.model.quantile_model.quantile_regression_model import AbstractQuantileRegressionModel from extreme_fit.model.quantile_model.quantile_regression_model import AbstractQuantileRegressionModel
from root_utils import get_display_name_from_object_type from root_utils import get_display_name_from_object_type, NB_CORES
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.coordinates.temporal_coordinates.generated_temporal_coordinates import \ from spatio_temporal_dataset.coordinates.temporal_coordinates.generated_temporal_coordinates import \
ConsecutiveTemporalCoordinates ConsecutiveTemporalCoordinates
...@@ -64,14 +65,26 @@ class AbstractSimulation(object): ...@@ -64,14 +65,26 @@ class AbstractSimulation(object):
@cached_property @cached_property
def model_class_to_time_series_length_to_estimators(self): def model_class_to_time_series_length_to_estimators(self):
d = OrderedDict() d = OrderedDict()
for model_class in self.models_classes: for i, model_class in enumerate(self.models_classes, 1):
d_sub = OrderedDict() d_sub = OrderedDict()
for time_series_length, observation_list in self.time_series_length_to_observation_list.items(): for time_series_length, observation_list in self.time_series_length_to_observation_list.items():
print(model_class, '{}/{}'.format(i, len(self.models_classes)), time_series_length)
coordinates = self.time_series_length_to_coordinates[time_series_length] coordinates = self.time_series_length_to_coordinates[time_series_length]
estimators = []
for observations in observation_list: arguments = [
estimators.append(self.get_fitted_quantile_estimator(model_class, observations, coordinates, [model_class, observations, coordinates, self.quantile_estimator]
self.quantile_estimator)) for observations in observation_list
]
if self.multiprocessing:
raise NotImplementedError('The multiprocessing seems slow compared to the other,'
'maybe it would be best to call an external function rather than'
'a method, but this methods is override in other classes...')
# with Pool(NB_CORES) as p:
# estimators = p.starmap(self.get_fitted_quantile_estimator, arguments)
else:
estimators = []
for argument in arguments:
estimators.append(self.get_fitted_quantile_estimator(*argument))
d_sub[time_series_length] = estimators d_sub[time_series_length] = estimators
d[model_class] = d_sub d[model_class] = d_sub
return d return d
......
...@@ -12,7 +12,7 @@ from spatio_temporal_dataset.coordinates.transformed_coordinates.transformation. ...@@ -12,7 +12,7 @@ from spatio_temporal_dataset.coordinates.transformed_coordinates.transformation.
nb_time_series = 10 nb_time_series = 10
quantile = 0.98 quantile = 0.98
time_series_lengths = [50, 100, 200] time_series_lengths = [50, 100, 200]
transformation_class = [IdentityTransformation, CenteredScaledNormalization][1] transformation_class = [IdentityTransformation, CenteredScaledNormalization][0]
model_classes = [ model_classes = [
NonStationaryLocationTemporalModel, NonStationaryLocationTemporalModel,
TemporalCoordinatesQuantileRegressionModel, TemporalCoordinatesQuantileRegressionModel,
...@@ -27,5 +27,6 @@ simulation = simulation_class(nb_time_series=nb_time_series, ...@@ -27,5 +27,6 @@ simulation = simulation_class(nb_time_series=nb_time_series,
quantile=quantile, quantile=quantile,
time_series_lengths=time_series_lengths, time_series_lengths=time_series_lengths,
model_classes=model_classes, model_classes=model_classes,
transformation_class=transformation_class) transformation_class=transformation_class,
multiprocessing=False)
simulation.plot_error_for_last_year_quantile() simulation.plot_error_for_last_year_quantile()
...@@ -47,6 +47,7 @@ class TestExpSimulations(unittest.TestCase): ...@@ -47,6 +47,7 @@ class TestExpSimulations(unittest.TestCase):
class TestExpSimulationsDailyDataModels(unittest.TestCase): class TestExpSimulationsDailyDataModels(unittest.TestCase):
DISPLAY = False DISPLAY = False
# Warning this method is quite long...
def test_stationary_run_daily_data_quantile_regression_model(self): def test_stationary_run_daily_data_quantile_regression_model(self):
simulation = StationaryExpSimulation(nb_time_series=1, quantile=0.5, time_series_lengths=[50, 60], simulation = StationaryExpSimulation(nb_time_series=1, quantile=0.5, time_series_lengths=[50, 60],
model_classes=[ConstantQuantileRegressionModelOnDailyData]) model_classes=[ConstantQuantileRegressionModelOnDailyData])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment