Commit b4c10545 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[projection snowfall] modify interval limits. plot both merge & mean aggregation.

parent 51e32a62
No related merge requests found
Showing with 45 additions and 35 deletions
+45 -35
...@@ -53,15 +53,15 @@ def plot_nb_data_one_line(ax, gcm, scenario, left_limits, right_limits, first_sc ...@@ -53,15 +53,15 @@ def plot_nb_data_one_line(ax, gcm, scenario, left_limits, right_limits, first_sc
linestyle = get_linestyle_from_scenario(scenario) linestyle = get_linestyle_from_scenario(scenario)
# Filter out the zero value # Filter out the zero value
nb_data, left_limits = np.array(nb_data), np.array(left_limits) nb_data, right_limits = np.array(nb_data), np.array(right_limits)
ind = np.array(nb_data) > 0 ind = np.array(nb_data) > 0
nb_data, left_limits = nb_data[ind], left_limits[ind] nb_data, right_limits = nb_data[ind], right_limits[ind]
# For the legend # For the legend
if (len(nb_data) > 0) and first_scenario: if (len(nb_data) > 0) and first_scenario:
ax.plot(left_limits[0], nb_data[0], color=color, linestyle='solid', label=gcm) ax.plot(right_limits[0], nb_data[0], color=color, linestyle='solid', label=gcm)
ax.plot(left_limits, nb_data, linestyle=linestyle, color=color, marker='o') ax.plot(right_limits, nb_data, linestyle=linestyle, color=color, marker='o')
def plot_nb_data(is_temperature_interval, is_shift_interval): def plot_nb_data(is_temperature_interval, is_shift_interval):
...@@ -75,7 +75,7 @@ def plot_nb_data(is_temperature_interval, is_shift_interval): ...@@ -75,7 +75,7 @@ def plot_nb_data(is_temperature_interval, is_shift_interval):
ax.legend() ax.legend()
ticks_labels = get_ticks_labels_for_interval(is_temperature_interval, is_shift_interval) ticks_labels = get_ticks_labels_for_interval(is_temperature_interval, is_shift_interval)
ax.set_xticks(left_limit) ax.set_xticks(right_limit)
ax.set_xticklabels(ticks_labels) ax.set_xticklabels(ticks_labels)
# ax.set_xlabel('Interval') # ax.set_xlabel('Interval')
ax.set_ylabel('Nb of Maxima') ax.set_ylabel('Nb of Maxima')
...@@ -101,8 +101,8 @@ def get_interval_limits(is_temperature_interval, is_shift_interval): ...@@ -101,8 +101,8 @@ def get_interval_limits(is_temperature_interval, is_shift_interval):
year_max = [2050 + shift * i for i in range(nb)] year_max = [2050 + shift * i for i in range(nb)]
left_limit, right_limit = year_min, year_max left_limit, right_limit = year_min, year_max
if not is_shift_interval: if not is_shift_interval:
max_interval_right = max(right_limit) min_interval_left = min(left_limit)
right_limit = [max_interval_right for _ in left_limit] left_limit = [min_interval_left for _ in right_limit]
return left_limit, right_limit return left_limit, right_limit
......
...@@ -23,6 +23,7 @@ class VisualizerMerge(AltitudesStudiesVisualizerForNonStationaryModels): ...@@ -23,6 +23,7 @@ class VisualizerMerge(AltitudesStudiesVisualizerForNonStationaryModels):
merge_function=np.median): merge_function=np.median):
self.merge_function = merge_function self.merge_function = merge_function
self.visualizers = visualizers self.visualizers = visualizers
assert len(visualizers) > 0
super().__init__(studies=visualizers[0].studies, model_classes=model_classes, show=show, massif_names=massif_names, super().__init__(studies=visualizers[0].studies, model_classes=model_classes, show=show, massif_names=massif_names,
fit_method=fit_method, temporal_covariate_for_fit=temporal_covariate_for_fit, fit_method=fit_method, temporal_covariate_for_fit=temporal_covariate_for_fit,
display_only_model_that_pass_anderson_test=display_only_model_that_pass_anderson_test, display_only_model_that_pass_anderson_test=display_only_model_that_pass_anderson_test,
......
...@@ -13,6 +13,8 @@ from extreme_trend.ensemble_fit.independent_ensemble_fit.independent_ensemble_fi ...@@ -13,6 +13,8 @@ from extreme_trend.ensemble_fit.independent_ensemble_fit.independent_ensemble_fi
from extreme_trend.ensemble_fit.visualizer_for_projection_ensemble import VisualizerForProjectionEnsemble from extreme_trend.ensemble_fit.visualizer_for_projection_ensemble import VisualizerForProjectionEnsemble
from extreme_trend.one_fold_fit.altitude_group import get_altitude_class_from_altitudes, \ from extreme_trend.one_fold_fit.altitude_group import get_altitude_class_from_altitudes, \
get_linestyle_for_altitude_class get_linestyle_for_altitude_class
from spatio_temporal_dataset.coordinates.temporal_coordinates.temperature_covariate import \
AnomalyTemperatureWithSplineTemporalCovariate
class VisualizerForSensivity(object): class VisualizerForSensivity(object):
...@@ -26,19 +28,19 @@ class VisualizerForSensivity(object): ...@@ -26,19 +28,19 @@ class VisualizerForSensivity(object):
display_only_model_that_pass_gof_test=False, display_only_model_that_pass_gof_test=False,
confidence_interval_based_on_delta_method=False, confidence_interval_based_on_delta_method=False,
remove_physically_implausible_models=False, remove_physically_implausible_models=False,
merge_visualizer_str=AbstractEnsembleFit.Median_merge, # if we choose the Mean merge, then it is almost impossible to obtain stationary trends
is_temperature_interval=False, is_temperature_interval=False,
is_shift_interval=False, is_shift_interval=False,
): ):
self.is_shift_interval = is_shift_interval self.is_shift_interval = is_shift_interval
self.temporal_covariate_for_fit = temporal_covariate_for_fit
self.is_temperature_interval = is_temperature_interval self.is_temperature_interval = is_temperature_interval
self.merge_visualizer_str = merge_visualizer_str self.merge_visualizer_str_list = (AbstractEnsembleFit.Median_merge, AbstractEnsembleFit.Mean_merge)
self.altitudes_list = altitudes_list self.altitudes_list = altitudes_list
self.massif_names = massif_names self.massif_names = massif_names
self.left_limits, self.right_limits = get_interval_limits(self.is_temperature_interval, self.left_limits, self.right_limits = get_interval_limits(self.is_temperature_interval,
self.is_shift_interval) self.is_shift_interval)
self.left_limit_to_right_limit = OrderedDict(zip(self.left_limits, self.right_limits)) self.left_limit_to_right_limit = OrderedDict(zip(self.left_limits, self.right_limits))
self.left_limit_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble] self.right_limit_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble]
for left_limit, right_limit in zip(self.left_limits, self.right_limits): for left_limit, right_limit in zip(self.left_limits, self.right_limits):
print("Interval is", left_limit, right_limit) print("Interval is", left_limit, right_limit)
...@@ -63,7 +65,7 @@ class VisualizerForSensivity(object): ...@@ -63,7 +65,7 @@ class VisualizerForSensivity(object):
remove_physically_implausible_models=remove_physically_implausible_models, remove_physically_implausible_models=remove_physically_implausible_models,
gcm_to_year_min_and_year_max=gcm_to_year_min_and_year_max gcm_to_year_min_and_year_max=gcm_to_year_min_and_year_max
) )
self.left_limit_to_visualizer[left_limit] = visualizer self.right_limit_to_visualizer[right_limit] = visualizer
def plot(self): def plot(self):
# todo: before reactivating the subplot, i should ensure that we can modify the prefix # todo: before reactivating the subplot, i should ensure that we can modify the prefix
...@@ -71,53 +73,60 @@ class VisualizerForSensivity(object): ...@@ -71,53 +73,60 @@ class VisualizerForSensivity(object):
# , and not just the plots for the last t_min # , and not just the plots for the last t_min
# for visualizer in self.temp_min_to_visualizer.values(): # for visualizer in self.temp_min_to_visualizer.values():
# visualizer.plot() # visualizer.plot()
self.sensitivity_plot() for merge_visualizer_str in self.merge_visualizer_str_list:
self.sensitivity_plot(merge_visualizer_str)
def sensitivity_plot(self): def sensitivity_plot(self, merge_visualizer_str):
ax = plt.gca() ax = plt.gca()
for altitudes in self.altitudes_list: for altitudes in self.altitudes_list:
altitude_class = get_altitude_class_from_altitudes(altitudes) altitude_class = get_altitude_class_from_altitudes(altitudes)
self.interval_plot(ax, altitude_class) self.interval_plot(ax, altitude_class, merge_visualizer_str)
ticks_labels = get_ticks_labels_for_interval(self.is_temperature_interval, self.is_shift_interval) ticks_labels = get_ticks_labels_for_interval(self.is_temperature_interval, self.is_shift_interval)
ax.set_ylabel('Percentages of massifs (\%)') ax.set_ylabel('Percentages of massifs (\%)')
ax.set_xlabel('Interval used to compute the trends ') ax.set_xlabel('Interval used to compute the trends ')
ax.set_xticks(self.left_limits) ax.set_xticks(self.right_limits)
ax.set_xticklabels(ticks_labels) ax.set_xticklabels(ticks_labels)
ax.legend(prop={'size': 7}, loc='upper center', ncol=2) ax.legend(prop={'size': 7}, loc='upper center', ncol=2)
ax.set_ylim((0, 122)) ax.set_ylim((0, 122))
ax.set_yticks([i*10 for i in range(11)]) ax.set_yticks([i*10 for i in range(11)])
merge_visualizer = self.first_merge_visualizer merge_visualizer = self.first_merge_visualizer(merge_visualizer_str)
merge_visualizer.plot_name = 'Sensitivity plot' temp_cov = self.temporal_covariate_for_fit is AnomalyTemperatureWithSplineTemporalCovariate
merge_visualizer.plot_name = 'Sensitivity plot with ' \
'shift={} temp_interval={}, temp_cov={}'.format(self.is_shift_interval,
self.is_temperature_interval,
temp_cov)
merge_visualizer.show_or_save_to_file(no_title=True) merge_visualizer.show_or_save_to_file(no_title=True)
plt.close()
@property def first_merge_visualizer(self, merge_visualizer_str):
def first_merge_visualizer(self):
altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0]) altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0])
visualizer_projection = list(self.left_limit_to_visualizer.values())[0] visualizer_projection = list(self.right_limit_to_visualizer.values())[0]
return self.get_merge_visualizer(altitude_class, visualizer_projection) return self.get_merge_visualizer(altitude_class, visualizer_projection, merge_visualizer_str)
def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble): def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble,
merge_visualizer_str):
independent_ensemble_fit = visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][ independent_ensemble_fit = visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][
IndependentEnsembleFit] IndependentEnsembleFit]
merge_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[self.merge_visualizer_str] merge_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[merge_visualizer_str]
merge_visualizer.studies.study.gcm_rcm_couple = (self.merge_visualizer_str, "merge") merge_visualizer.studies.study.gcm_rcm_couple = (merge_visualizer_str, "merge")
return merge_visualizer return merge_visualizer
def interval_plot(self, ax, altitude_class): def interval_plot(self, ax, altitude_class, merge_visualizer_str):
linestyle = get_linestyle_for_altitude_class(altitude_class) linestyle = get_linestyle_for_altitude_class(altitude_class)
increasing_key = 'increasing' increasing_key = 'increasing'
decreasing_key = 'decreasing' decreasing_key = 'decreasing'
label_to_l = {
increasing_key: [],
decreasing_key: []
}
label_to_color = { label_to_color = {
increasing_key: 'red', increasing_key: 'red',
decreasing_key: 'blue' decreasing_key: 'blue'
} }
for v in self.left_limit_to_visualizer.values(): label_to_l = {
merge_visualizer = self.get_merge_visualizer(altitude_class, v) increasing_key: [],
decreasing_key: []
}
for v in self.right_limit_to_visualizer.values():
merge_visualizer = self.get_merge_visualizer(altitude_class, v, merge_visualizer_str)
_, *trends = merge_visualizer.all_trends(self.massif_names, with_significance=False, _, *trends = merge_visualizer.all_trends(self.massif_names, with_significance=False,
with_relative_change=True) with_relative_change=True)
label_to_l[decreasing_key].append(trends[0]) label_to_l[decreasing_key].append(trends[0])
...@@ -126,5 +135,5 @@ class VisualizerForSensivity(object): ...@@ -126,5 +135,5 @@ class VisualizerForSensivity(object):
for label, l in label_to_l.items(): for label, l in label_to_l.items():
label_improved = 'with {} trends {}'.format(label, altitude_str) label_improved = 'with {} trends {}'.format(label, altitude_str)
color = label_to_color[label] color = label_to_color[label]
ax.plot(self.left_limits, l, label=label_improved, color=color, linestyle=linestyle) ax.plot(self.right_limits, l, label=label_improved, color=color, linestyle=linestyle)
...@@ -83,9 +83,8 @@ def main(): ...@@ -83,9 +83,8 @@ def main():
massif_names=massif_names, massif_names=massif_names,
temporal_covariate_for_fit=temporal_covariate_for_fit, temporal_covariate_for_fit=temporal_covariate_for_fit,
remove_physically_implausible_models=remove_physically_implausible_models, remove_physically_implausible_models=remove_physically_implausible_models,
merge_visualizer_str=AbstractEnsembleFit.Median_merge,
is_temperature_interval=False, is_temperature_interval=False,
is_shift_interval=True, is_shift_interval=False,
) )
visualizer.plot() visualizer.plot()
......
...@@ -44,6 +44,7 @@ class MeanAlpsTemperatureCovariate(AbstractTemperatureCovariate): ...@@ -44,6 +44,7 @@ class MeanAlpsTemperatureCovariate(AbstractTemperatureCovariate):
def load_year_to_temperature_covariate(cls): def load_year_to_temperature_covariate(cls):
return load_year_to_mean_alps_temperatures() return load_year_to_mean_alps_temperatures()
class AnomalyTemperatureWithSplineTemporalCovariate(AbstractTemporalCovariateForFit): class AnomalyTemperatureWithSplineTemporalCovariate(AbstractTemporalCovariateForFit):
gcm_and_scenario_to_d = {} gcm_and_scenario_to_d = {}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment