from collections import OrderedDict import matplotlib.pyplot as plt from typing import List, Dict from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import get_interval_limits, \ get_year_min_and_year_max, get_ticks_labels_for_interval from extreme_data.meteo_france_data.scm_models_data.utils import Season from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \ AbstractSpatioTemporalPolynomialModel from extreme_fit.model.margin_model.utils import MarginFitMethod from extreme_trend.ensemble_fit.abstract_ensemble_fit import AbstractEnsembleFit from extreme_trend.ensemble_fit.independent_ensemble_fit.independent_ensemble_fit import IndependentEnsembleFit from extreme_trend.ensemble_fit.visualizer_for_projection_ensemble import VisualizerForProjectionEnsemble from extreme_trend.one_fold_fit.altitude_group import get_altitude_class_from_altitudes, \ get_linestyle_for_altitude_class from spatio_temporal_dataset.coordinates.temporal_coordinates.temperature_covariate import \ AnomalyTemperatureWithSplineTemporalCovariate class VisualizerForSensivity(object): def __init__(self, altitudes_list, gcm_rcm_couples, study_class, season, scenario, model_classes: List[AbstractSpatioTemporalPolynomialModel], ensemble_fit_classes=None, massif_names=None, fit_method=MarginFitMethod.extremes_fevd_mle, temporal_covariate_for_fit=None, display_only_model_that_pass_gof_test=False, confidence_interval_based_on_delta_method=False, remove_physically_implausible_models=False, is_temperature_interval=False, is_shift_interval=False, ): self.is_shift_interval = is_shift_interval self.temporal_covariate_for_fit = temporal_covariate_for_fit self.is_temperature_interval = is_temperature_interval self.merge_visualizer_str_list = (AbstractEnsembleFit.Median_merge, AbstractEnsembleFit.Mean_merge) self.altitudes_list = altitudes_list self.massif_names = massif_names self.left_limits, self.right_limits = get_interval_limits(self.is_temperature_interval, self.is_shift_interval) self.left_limit_to_right_limit = OrderedDict(zip(self.left_limits, self.right_limits)) self.right_limit_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble] for left_limit, right_limit in zip(self.left_limits, self.right_limits): print("Interval is", left_limit, right_limit) # Build gcm_to_year_min_and_year_max gcm_to_year_min_and_year_max = {} gcm_list = list(set([g for g, r in gcm_rcm_couples])) for gcm in gcm_list: year_min_and_year_max = get_year_min_and_year_max(gcm, scenario, left_limit, right_limit, self.is_temperature_interval) if year_min_and_year_max[0] is not None: gcm_to_year_min_and_year_max[gcm] = year_min_and_year_max visualizer = VisualizerForProjectionEnsemble( altitudes_list, gcm_rcm_couples, study_class, Season.annual, scenario, model_classes=model_classes, fit_method=fit_method, ensemble_fit_classes=ensemble_fit_classes, display_only_model_that_pass_gof_test=display_only_model_that_pass_gof_test, confidence_interval_based_on_delta_method=confidence_interval_based_on_delta_method, massif_names=massif_names, temporal_covariate_for_fit=temporal_covariate_for_fit, remove_physically_implausible_models=remove_physically_implausible_models, gcm_to_year_min_and_year_max=gcm_to_year_min_and_year_max ) self.right_limit_to_visualizer[right_limit] = visualizer def plot(self): # todo: before reactivating the subplot, i should ensure that we can modify the prefix # todo: so that we can have all the subplot for the merge visualizer # , and not just the plots for the last t_min # for visualizer in self.temp_min_to_visualizer.values(): # visualizer.plot() for merge_visualizer_str in self.merge_visualizer_str_list: self.sensitivity_plot(merge_visualizer_str) def sensitivity_plot(self, merge_visualizer_str): ax = plt.gca() for altitudes in self.altitudes_list: altitude_class = get_altitude_class_from_altitudes(altitudes) self.interval_plot(ax, altitude_class, merge_visualizer_str) ticks_labels = get_ticks_labels_for_interval(self.is_temperature_interval, self.is_shift_interval) ax.set_ylabel('Percentages of massifs (\%)') ax.set_xlabel('Interval used to compute the trends ') ax.set_xticks(self.right_limits) ax.set_xticklabels(ticks_labels) ax.legend(prop={'size': 7}, loc='upper center', ncol=2) ax.set_ylim((0, 122)) ax.set_yticks([i*10 for i in range(11)]) merge_visualizer = self.first_merge_visualizer(merge_visualizer_str) temp_cov = self.temporal_covariate_for_fit is AnomalyTemperatureWithSplineTemporalCovariate merge_visualizer.plot_name = 'Sensitivity plot with ' \ 'shift={} temp_interval={}, temp_cov={}'.format(self.is_shift_interval, self.is_temperature_interval, temp_cov) merge_visualizer.show_or_save_to_file(no_title=True) plt.close() def first_merge_visualizer(self, merge_visualizer_str): altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0]) visualizer_projection = list(self.right_limit_to_visualizer.values())[0] return self.get_merge_visualizer(altitude_class, visualizer_projection, merge_visualizer_str) def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble, merge_visualizer_str): independent_ensemble_fit = visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][ IndependentEnsembleFit] merge_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[merge_visualizer_str] merge_visualizer.studies.study.gcm_rcm_couple = (merge_visualizer_str, "merge") return merge_visualizer def interval_plot(self, ax, altitude_class, merge_visualizer_str): linestyle = get_linestyle_for_altitude_class(altitude_class) increasing_key = 'increasing' decreasing_key = 'decreasing' label_to_color = { increasing_key: 'red', decreasing_key: 'blue' } label_to_l = { increasing_key: [], decreasing_key: [] } for v in self.right_limit_to_visualizer.values(): merge_visualizer = self.get_merge_visualizer(altitude_class, v, merge_visualizer_str) _, *trends = merge_visualizer.all_trends(self.massif_names, with_significance=False, with_relative_change=True) label_to_l[decreasing_key].append(trends[0]) label_to_l[increasing_key].append(trends[2]) altitude_str = altitude_class().formula for label, l in label_to_l.items(): label_improved = 'with {} trends {}'.format(label, altitude_str) color = label_to_color[label] ax.plot(self.right_limits, l, label=label_improved, color=color, linestyle=linestyle)