From 809fd66289a5d7760859a48cf5a49abef3d011de Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 24 Feb 2021 17:07:36 +0100
Subject: [PATCH] [projections] fix and refactor temperature_to_year.py. Focus
 on the median merge function, rather than the mean merge function.

---
 .../adamont_data/cmip5/temperature_to_year.py | 25 ++++----
 .../one_fold_analysis/altitude_group.py       | 16 +++++-
 ...ation_temporal_for_projections_ensemble.py |  7 ++-
 .../visualizer_for_sensitivity.py             | 57 ++++++++++++-------
 4 files changed, 69 insertions(+), 36 deletions(-)

diff --git a/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py b/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
index c911fa5e..30bc22ba 100644
--- a/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
+++ b/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
@@ -36,7 +36,7 @@ def get_nb_data(gcm, scenario, temperature_min, temperature_max):
         return year_max - year_min + 1
 
 
-def plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max):
+def plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max, first_scenario):
 
     nb_data = [get_nb_data(gcm, scenario, mi, ma) for mi, ma in zip(temp_min, temp_max)]
     color = gcm_to_color[gcm]
@@ -48,26 +48,26 @@ def plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max):
     nb_data, temp_min = nb_data[ind], temp_min[ind]
 
     # For the legend
-    if scenario is AdamontScenario.rcp26:
+    if (len(nb_data) > 0) and first_scenario:
         ax.plot(temp_min[0], nb_data[0], color=color, linestyle='solid', label=gcm)
 
-    ax.plot(temp_min, nb_data, linestyle=linestyle, color=color)
+    ax.plot(temp_min, nb_data, linestyle=linestyle, color=color, marker='o')
 
 
 def plot_nb_data():
-    temp_max, temp_min = get_temp_min_and_temp_max()
+    temp_min, temp_max = get_temp_min_and_temp_max()
 
     ax = plt.gca()
     for gcm in get_gcm_list(adamont_version=2)[:]:
-        for scenario in rcp_scenarios[:]:
-            plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max)
+        for i, scenario in enumerate(rcp_scenarios[:2]):
+            plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max, first_scenario=i == 0)
 
     ax.legend()
-    ticks_labels = ['{}-{}'.format(mi, ma) for mi, ma in zip(temp_min, temp_max)]
+    ticks_labels = get_ticks_labels_for_temp_min_and_temp_max()
     ax.set_xticks(temp_min)
     ax.set_xticklabels(ticks_labels)
     ax.set_xlabel('Temperature interval')
-    ax.set_ylabel('Nb of Data')
+    ax.set_ylabel('Nb of Maxima')
     ax2 = ax.twinx()
     legend_elements = [
         Line2D([0], [0], color='k', lw=1, label=scenario_to_str(s),
@@ -79,10 +79,15 @@ def plot_nb_data():
 
 
 def get_temp_min_and_temp_max():
-    temp_min = np.arange(0.5, 2.5, 0.5)
-    temp_max = temp_min + 1.5
+    temp_min = np.arange(0, 3, 1)
+    temp_max = temp_min + 2
     return temp_min, temp_max
 
+def get_ticks_labels_for_temp_min_and_temp_max():
+    temp_min, temp_max = get_temp_min_and_temp_max()
+    return ['Maxima occured between \n' \
+            ' +${}^o\mathrm{C}$ and +${}^o\mathrm{C}$'.format(mi, ma, **{'C': '{C}'})
+     for mi, ma in zip(temp_min, temp_max)]
 
 if __name__ == '__main__':
     plot_nb_data()
diff --git a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py
index 5a3f8a16..60924af0 100644
--- a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py
+++ b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py
@@ -19,7 +19,7 @@ class AbstractAltitudeGroup(object):
 
     @property
     def altitudes(self):
-        return altitudes_for_groups[self.group_id-1]
+        return altitudes_for_groups[self.group_id - 1]
 
     @property
     def reference_altitude(self):
@@ -161,3 +161,17 @@ def get_altitude_group_from_altitudes(altitudes):
         return VeyHighAltitudeGroup()
     else:
         return DefaultAltitudeGroup()
+
+
+def get_linestyle_for_altitude_class(altitude_class):
+    assert issubclass(altitude_class, AbstractAltitudeGroup)
+    if altitude_class is LowAltitudeGroup:
+        return 'solid'
+    elif altitude_class is MidAltitudeGroup:
+        return 'dashed'
+    elif altitude_class is HighAltitudeGroup:
+        return 'dashdot'
+    elif altitude_class is VeyHighAltitudeGroup:
+        return 'dotted'
+    else:
+        raise NotImplementedError
diff --git a/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py b/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py
index 23f51598..798ce662 100644
--- a/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py
+++ b/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py
@@ -43,9 +43,9 @@ def main():
     set_seed_for_test()
     AbstractExtractEurocodeReturnLevel.ALPHA_CONFIDENCE_INTERVAL_UNCERTAINTY = 0.2
 
-    fast = None
+    fast = False
     sensitivity_plot = True
-    scenarios = rcp_scenarios if fast is False else [AdamontScenario.rcp85]
+    scenarios = rcp_scenarios[::-1] if fast is False else [AdamontScenario.rcp85]
 
     for scenario in scenarios:
         gcm_rcm_couples = get_gcm_rcm_couples(scenario)
@@ -58,7 +58,7 @@ def main():
             massif_names = ['Vanoise', 'Haute-Maurienne']
             gcm_rcm_couples = gcm_rcm_couples[4:6]
             AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
-            altitudes_list = altitudes_for_groups[2:]
+            altitudes_list = altitudes_for_groups[:]
         else:
             massif_names = None
             altitudes_list = altitudes_for_groups[:]
@@ -77,6 +77,7 @@ def main_loop(gcm_rcm_couples, altitudes_list, massif_names, study_class, ensemb
               temporal_covariate_for_fit, sensitivity_plot=False):
     assert isinstance(altitudes_list, List)
     assert isinstance(altitudes_list[0], List)
+    print('Scenario is', scenario)
     print('Covariate is {}'.format(temporal_covariate_for_fit))
 
     model_classes = ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS
diff --git a/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py b/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py
index f3e40b69..72392ee5 100644
--- a/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py
+++ b/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py
@@ -3,14 +3,14 @@ import matplotlib.pyplot as plt
 from typing import List, Dict
 
 from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import get_temp_min_and_temp_max, \
-    temperature_minmax_to_year_minmax
+    temperature_minmax_to_year_minmax, get_ticks_labels_for_temp_min_and_temp_max
 from extreme_data.meteo_france_data.scm_models_data.utils import Season
 from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
     AbstractSpatioTemporalPolynomialModel
 from extreme_fit.model.margin_model.utils import MarginFitMethod
 from projects.altitude_spatial_model.altitudes_fit.altitudes_studies import AltitudesStudies
 from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.altitude_group import \
-    get_altitude_group_from_altitudes, get_altitude_class_from_altitudes
+    get_altitude_group_from_altitudes, get_altitude_class_from_altitudes, get_linestyle_for_altitude_class
 from projects.altitude_spatial_model.altitudes_fit.plots.plot_histogram_altitude_studies import \
     plot_histogram_all_trends_against_altitudes, plot_shoe_plot_changes_against_altitude
 from projects.altitude_spatial_model.altitudes_fit.utils_altitude_studies_visualizer import compute_and_assign_max_abs
@@ -31,7 +31,9 @@ class VisualizerForSensivity(object):
                  display_only_model_that_pass_gof_test=False,
                  confidence_interval_based_on_delta_method=False,
                  remove_physically_implausible_models=False,
+                 merge_visualizer_str=IndependentEnsembleFit.Median_merge # if we choose the Mean merge, then it is almost impossible to obtain stationary trends
                  ):
+        self.merge_visualizer_str = merge_visualizer_str
         self.altitudes_list = altitudes_list
         self.massif_names = massif_names
         self.temp_min, self.temp_max = get_temp_min_and_temp_max()
@@ -64,7 +66,8 @@ class VisualizerForSensivity(object):
 
     def plot(self):
         # todo: before reactivating the subplot, i should ensure that we can modify the prefix
-        # so that we can have all the subplot, and not juste for the last t_min
+        # so that we can have all the subplot for the merge visualizer
+        # , and not just the plots for the last t_min
         # for visualizer in self.temp_min_to_visualizer.values():
         #     visualizer.plot()
         self.sensitivity_plot()
@@ -75,41 +78,51 @@ class VisualizerForSensivity(object):
             altitude_class = get_altitude_class_from_altitudes(altitudes)
             self.temperature_interval_plot(ax, altitude_class)
 
-        ticks_labels = ['+{} to +{}'.format(mi, ma) for mi, ma in zip(self.temp_min, self.temp_max)]
-        ax.set_ylabel('Percentages of massifs (%)')
+        ticks_labels = get_ticks_labels_for_temp_min_and_temp_max()
+        ax.set_ylabel('Percentages of massifs (\%)')
         ax.set_xlabel('Range of temperatures used to compute the trends ')
         ax.set_xticks(self.temp_min)
         ax.set_xticklabels(ticks_labels)
-        ax.legend()
-        mean_visualizer = self.first_mean_visualizer
-        mean_visualizer.plot_name = 'Sensitivity plot'
-        mean_visualizer.show_or_save_to_file(no_title=True)
+        ax.legend(prop={'size': 7}, loc='upper center', ncol=2)
+        ax.set_ylim((0, 122))
+        ax.set_yticks([i*10 for i in range(11)])
+        merge_visualizer = self.first_merge_visualizer
+        merge_visualizer.plot_name = 'Sensitivity plot'
+        merge_visualizer.show_or_save_to_file(no_title=True)
 
     @property
-    def first_mean_visualizer(self):
+    def first_merge_visualizer(self):
         altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0])
         visualizer_projection = list(self.temp_min_to_visualizer.values())[0]
-        return self.get_mean_visualizer(altitude_class, visualizer_projection)
+        return self.get_merge_visualizer(altitude_class, visualizer_projection)
 
-    def get_mean_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble):
+    def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble):
         independent_ensemble_fit = visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][
             IndependentEnsembleFit]
-        mean_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[IndependentEnsembleFit.Mean_merge]
-        mean_visualizer.studies.study.gcm_rcm_couple = (IndependentEnsembleFit.Mean_merge, "merge")
-        return mean_visualizer
+        merge_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[self.merge_visualizer_str]
+        merge_visualizer.studies.study.gcm_rcm_couple = (self.merge_visualizer_str, "merge")
+        return merge_visualizer
 
     def temperature_interval_plot(self, ax, altitude_class):
+        linestyle = get_linestyle_for_altitude_class(altitude_class)
+        increasing_key = 'increasing'
+        decreasing_key = 'decreasing'
         label_to_l = {
-            'increasing': [],
-            'decreasing': []
+            increasing_key: [],
+            decreasing_key: []
+        }
+        label_to_color = {
+            increasing_key: 'red',
+            decreasing_key: 'blue'
         }
         for v in self.temp_min_to_visualizer.values():
-            mean_visualizer = self.get_mean_visualizer(altitude_class, v)
-            _, *trends = mean_visualizer.all_trends(self.massif_names, with_significance=False)
-            label_to_l['decreasing'].append(trends[0])
-            label_to_l['increasing'].append(trends[2])
+            merge_visualizer = self.get_merge_visualizer(altitude_class, v)
+            _, *trends = merge_visualizer.all_trends(self.massif_names, with_significance=False)
+            label_to_l[decreasing_key].append(trends[0])
+            label_to_l[increasing_key].append(trends[2])
         altitude_str = altitude_class().formula
         for label, l in label_to_l.items():
             label_improved = 'with {} trends {}'.format(label, altitude_str)
-            ax.plot(self.temp_min, l, label=label_improved)
+            color = label_to_color[label]
+            ax.plot(self.temp_min, l, label=label_improved, color=color, linestyle=linestyle)
 
-- 
GitLab