Commit fcaa2d5e authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SCM][NON STATIONARITY] improve visualization. add memoization. add mu1 parameter trend.

parent eea97371
No related merge requests found
Showing with 81 additions and 26 deletions
+81 -26
......@@ -18,7 +18,7 @@ class CrocusVariable(AbstractVariable):
class CrocusSweVariable(CrocusVariable):
NAME = 'Snow Water Equivalent'
UNIT = 'kg/m2 or mm'
UNIT = 'kg per m2 or mm'
def __init__(self, dataset, altitude):
super().__init__(dataset, altitude, 'SWE_1DY_ISBA')
......
......@@ -22,7 +22,7 @@ class SafranSnowfallVariable(AbstractVariable):
"""
NAME = 'Snowfall'
UNIT = 'kg/m2 or mm'
UNIT = 'kg per m2 or mm'
def __init__(self, dataset, altitude, nb_consecutive_days_of_snowfall=1, keyword='Snowf'):
super().__init__(dataset, altitude)
......
......@@ -12,14 +12,15 @@ SCM_EXTENDED_STUDIES = [ExtendedSafranSnowfall, ExtendedCrocusSwe, ExtendedCrocu
SCM_STUDY_TO_EXTENDED_STUDY = OrderedDict(zip(SCM_STUDIES, SCM_EXTENDED_STUDIES))
def study_iterator(study_class, only_first_one=False, both_altitude=False, verbose=True):
def study_iterator(study_class, only_first_one=False, both_altitude=False, verbose=True, altitudes=None):
all_studies = []
is_safran_study = study_class in [SafranSnowfall, ExtendedSafranSnowfall]
nb_days = [1] if is_safran_study else [1]
if verbose:
print('Loading studies....')
for nb_day in nb_days:
for alti in [1800]:
altis = [1800] if altitudes is None else altitudes
for alti in altis:
if verbose:
print('alti: {}, nb_day: {}'.format(alti, nb_day))
study = study_class(altitude=alti, nb_consecutive_days=nb_day) if is_safran_study \
......@@ -96,12 +97,12 @@ def complete_analysis(only_first_one=False):
def trend_analysis():
save_to_file = False
save_to_file = True
only_first_one = True
for study_class in [CrocusDepth, SafranSnowfall, SafranRainfall, SafranTemperature][1:2]:
for study in study_iterator(study_class, only_first_one=only_first_one):
for study_class in [CrocusSwe, CrocusDepth, SafranSnowfall, SafranRainfall, SafranTemperature][:3]:
for study in study_iterator(study_class, only_first_one=only_first_one, altitudes=[1800, 2100, 2400, 2700]):
study_visualizer = StudyVisualizer(study, save_to_file=save_to_file)
study_visualizer.visualize_temporal_trend_relevance(complete_analysis=False)
study_visualizer.visualize_temporal_trend_relevance(complete_analysis=not only_first_one)
if __name__ == '__main__':
# annual_mean_vizu_compare_durand_study(safran=True, take_mean_value=True, altitude=2100)
......
from typing import Union
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from scipy.stats import chi2
from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
FullEstimatorInASingleStepWithSmoothMargin
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
FullEstimatorInASingleStepWithSmoothMargin, AbstractFullEstimator
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator, \
AbstractMarginEstimator
from extreme_estimator.extreme_models.margin_model.linear_margin_model import \
LinearAllParametersTwoFirstCoordinatesMarginModel, LinearAllTwoStatialCoordinatesLocationLinearMarginModel, \
LinearStationaryMarginModel, LinearNonStationaryLocationMarginModel
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
from utils import get_display_name_from_object_type
class AbstractNonStationaryTrendTest(object):
......@@ -18,37 +23,73 @@ class AbstractNonStationaryTrendTest(object):
self.estimator_class = estimator_class
self.stationary_margin_model_class = stationary_margin_model_class
self.non_stationary_margin_model_class = non_stationary_margin_model_class
def load_estimator(self, margin_model) -> AbstractEstimator:
# Compute a dictionary that maps couple (margin model class, starting point)
# to the corresponding fitted estimator
self._margin_model_class_and_starting_point_to_estimator = {}
def get_estimator(self, margin_model_class, starting_point) -> Union[
AbstractFullEstimator, AbstractMarginEstimator]:
if (margin_model_class, starting_point) not in self._margin_model_class_and_starting_point_to_estimator:
margin_model = margin_model_class(coordinates=self.dataset.coordinates, starting_point=starting_point)
estimator = self._load_estimator(margin_model)
estimator.fit()
self._margin_model_class_and_starting_point_to_estimator[(margin_model_class, starting_point)] = estimator
return self._margin_model_class_and_starting_point_to_estimator[(margin_model_class, starting_point)]
def _load_estimator(self, margin_model) -> Union[AbstractFullEstimator, AbstractMarginEstimator]:
return self.estimator_class(self.dataset, margin_model)
def get_metric(self, margin_model_class, starting_point):
margin_model = margin_model_class(coordinates=self.dataset.coordinates, starting_point=starting_point)
estimator = self.load_estimator(margin_model) # type: AbstractEstimator
estimator.fit()
estimator = self.get_estimator(margin_model_class, starting_point)
metric = estimator.result_from_fit.__getattribute__(self.RESULT_ATTRIBUTE_METRIC)
assert isinstance(metric, float)
return metric
def get_mu1(self, starting_point):
# for the non stationary model gives the mu1 parameters that was fitted
estimator = self.get_estimator(self.non_stationary_margin_model_class, starting_point)
margin_function = estimator.margin_function_fitted # type: LinearMarginFunction
assert isinstance(margin_function, LinearMarginFunction)
return margin_function.mu1_temporal_trend
def visualize(self, ax, complete_analysis=True):
# Define the year_min and year_max for the starting point
if complete_analysis:
year_min, year_max, step = 1960, 1990, 1
else:
year_min, year_max, step = 1960, 1990, 10
# Fit the stationary model
stationary_metric = self.get_metric(self.stationary_margin_model_class, starting_point=None)
# Fit the non stationary model
years = list(range(year_min, year_max + 1, step))
# Plot differences
stationary_metric = self.get_metric(self.stationary_margin_model_class, starting_point=None)
non_stationary_metrics = [self.get_metric(self.non_stationary_margin_model_class, year) for year in years]
difference = [m - stationary_metric for m in non_stationary_metrics]
# Plot some lines
ax.axhline(y=0, xmin=year_min, xmax=year_max)
# Significative line
significative_deviance = chi2.ppf(q=0.95, df=1)
ax.axhline(y=significative_deviance, xmin=year_min, xmax=year_max)
# todo: plot the line corresponding to the significance 0.05
ax.plot(years, difference, 'ro-')
color_difference = 'b'
ax.plot(years, difference, color_difference + 'o-')
ax.set_ylabel(self.RESULT_ATTRIBUTE_METRIC + ' difference', color=color_difference)
# Plot the mu1 parameter
mu1_trends = [self.get_mu1(starting_point=year) for year in years]
ax2 = ax.twinx()
color_mu1 = 'c'
ax.plot(years, mu1_trends, color_mu1 + 'o-')
ax2.set_ylabel('mu1 parameter', color=color_mu1)
# Plot zero line
ax.plot(years, [0 for _ in years], 'k-', label='zero line')
# Plot significative line corresponding to 0.05 relevance
alpha = 0.05
significative_deviance = chi2.ppf(q=1 - alpha, df=1)
ax.plot(years, [significative_deviance for _ in years], 'g-', label='significative line')
# Add some informations about the graph
ax.set_xlabel('year')
ax.set_title(self.display_name)
ax.legend()
@property
def display_name(self):
raise NotImplementedError
class IndependenceLocationTrendTest(AbstractNonStationaryTrendTest):
......@@ -65,6 +106,10 @@ class ConditionalIndedendenceLocationTrendTest(AbstractNonStationaryTrendTest):
stationary_margin_model_class=LinearStationaryMarginModel,
non_stationary_margin_model_class=LinearNonStationaryLocationMarginModel)
@property
def display_name(self):
return get_display_name_from_object_type('conditional independence')
class MaxStableLocationTrendTest(AbstractNonStationaryTrendTest):
......@@ -75,5 +120,9 @@ class MaxStableLocationTrendTest(AbstractNonStationaryTrendTest):
non_stationary_margin_model_class=LinearNonStationaryLocationMarginModel)
self.max_stable_model = max_stable_model
def load_estimator(self, margin_model) -> AbstractEstimator:
def _load_estimator(self, margin_model) -> AbstractEstimator:
return self.estimator_class(self.dataset, margin_model, self.max_stable_model)
@property
def display_name(self):
return get_display_name_from_object_type(type(self.max_stable_model))
......@@ -6,6 +6,7 @@ from extreme_estimator.extreme_models.margin_model.param_function.abstract_coef
from extreme_estimator.extreme_models.margin_model.param_function.linear_coef import LinearCoef
from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction, \
LinearParamFunction
from extreme_estimator.margin_fits.extreme_params import ExtremeParams
from extreme_estimator.margin_fits.gev.gev_params import GevParams
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
......@@ -58,6 +59,10 @@ class LinearMarginFunction(ParametricMarginFunction):
coef_dict.update(coef.coef_dict(dims, self.idx_to_coefficient_name(self.coordinates)))
return coef_dict
@property
def mu1_temporal_trend(self):
return self.coef_dict[LinearCoef.coef_template_str(ExtremeParams.LOC, AbstractCoordinates.COORDINATE_T).format(1)]
@property
def form_dict(self) -> Dict[str, str]:
form_dict = {}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment