From aafb602cd7666831963b663b9c2cf1013e1fd66d Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 3 Mar 2021 15:26:06 +0100
Subject: [PATCH] [projection snowfall] improve one_fold_fit_merge.py. improve
 interval for the sensitivity plot. Add is_temperature_interval and
 is_shift_interval to define the intervals.

---
 .../adamont_data/cmip5/temperature_to_year.py | 104 +++++++++++-------
 .../one_fold_fit_merge.py                     |  13 ++-
 .../visualizer_for_sensitivity.py             |  47 ++++----
 ...es_visualizer_for_non_stationary_models.py |  10 +-
 ...ation_temporal_for_projections_ensemble.py |   7 +-
 test/test_extreme_trend/test_one_fold_fit.py  |   6 +-
 6 files changed, 119 insertions(+), 68 deletions(-)

diff --git a/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py b/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
index 6b73a512..8cd7e69d 100644
--- a/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
+++ b/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
@@ -8,65 +8,76 @@ from extreme_data.meteo_france_data.adamont_data.adamont_scenario import Adamont
 from extreme_data.meteo_france_data.adamont_data.cmip5.climate_explorer_cimp5 import years_and_global_mean_temps
 
 
-def temperature_minmax_to_year_minmax(gcm, scenario, temperature_min, temperature_max):
-    years, global_mean_temps = years_and_global_mean_temps(gcm, scenario, year_min=2005, year_max=2100, anomaly=True,
+def get_year_min_and_year_max(gcm, scenario, left_limit, right_limit, is_temperature_interval):
+    if is_temperature_interval:
+        years_to_select = _get_year_min_and_year_max_for_temperature_interval(gcm, left_limit,
+                                                                              right_limit, scenario)
+        if len(years_to_select) == 0:
+            return None, None
+        year_min, year_max = years_to_select[0], years_to_select[-1]
+    else:
+        year_min, year_max = left_limit, right_limit
+
+    # A minimum of 30 years of data is needed to find a trend
+    if year_max - year_min + 1 >= 30:
+        return year_min, year_max
+    else:
+        return None, None
+
+
+def _get_year_min_and_year_max_for_temperature_interval(gcm, left_limits, right_limits, scenario):
+    years, global_mean_temps = years_and_global_mean_temps(gcm, scenario, year_min=2006, year_max=2100, anomaly=True,
                                                            spline=True)
     years, global_mean_temps = np.array(years), np.array(global_mean_temps)
-    ind = temperature_min < global_mean_temps
-    ind &= global_mean_temps < temperature_max
+    ind = left_limits < global_mean_temps
+    ind &= global_mean_temps < right_limits
     years_to_select = years[ind]
     ind2 = years_to_select[:-1] == years_to_select[1:] - 1
     if not all(ind2):
         i = list(ind2).index(False)
         years_to_select = years_to_select[:i + 1]
-    # A minimum of 30 years of data is needed to find a trend
-    if len(years_to_select) >= 30:
-        year_min, year_max = years_to_select[0], years_to_select[-1]
-        assert (year_max - year_min + 1) == len(years_to_select)
-        return year_min, year_max
-    else:
-        return None, None
+    return years_to_select
 
 
-def get_nb_data(gcm, scenario, temperature_min, temperature_max):
-    year_min, year_max = temperature_minmax_to_year_minmax(gcm, scenario, temperature_min, temperature_max)
+def get_nb_data(gcm, scenario, temperature_min, temperature_max, is_temperature_interval):
+    year_min, year_max = get_year_min_and_year_max(gcm, scenario, temperature_min, temperature_max, is_temperature_interval)
     if year_min is None:
         return 0
     else:
         return year_max - year_min + 1
 
 
-def plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max, first_scenario):
-
-    nb_data = [get_nb_data(gcm, scenario, mi, ma) for mi, ma in zip(temp_min, temp_max)]
+def plot_nb_data_one_line(ax, gcm, scenario, left_limits, right_limits, first_scenario, is_temperature_interval):
+    nb_data = [get_nb_data(gcm, scenario, left, right, is_temperature_interval) for left, right in zip(left_limits, right_limits)]
     color = gcm_to_color[gcm]
     linestyle = get_linestyle_from_scenario(scenario)
 
     # Filter out the zero value
-    nb_data, temp_min = np.array(nb_data), np.array(temp_min)
+    nb_data, left_limits = np.array(nb_data), np.array(left_limits)
     ind = np.array(nb_data) > 0
-    nb_data, temp_min = nb_data[ind], temp_min[ind]
+    nb_data, left_limits = nb_data[ind], left_limits[ind]
 
     # For the legend
     if (len(nb_data) > 0) and first_scenario:
-        ax.plot(temp_min[0], nb_data[0], color=color, linestyle='solid', label=gcm)
+        ax.plot(left_limits[0], nb_data[0], color=color, linestyle='solid', label=gcm)
 
-    ax.plot(temp_min, nb_data, linestyle=linestyle, color=color, marker='o')
+    ax.plot(left_limits, nb_data, linestyle=linestyle, color=color, marker='o')
 
 
-def plot_nb_data():
-    temp_min, temp_max = get_temp_min_and_temp_max()
+def plot_nb_data(is_temperature_interval, is_shift_interval):
+    left_limit, right_limit = get_interval_limits(is_temperature_interval, is_shift_interval)
 
     ax = plt.gca()
     for gcm in get_gcm_list(adamont_version=2)[:]:
         for i, scenario in enumerate(rcp_scenarios[:2]):
-            plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max, first_scenario=i == 0)
+            plot_nb_data_one_line(ax, gcm, scenario, left_limit, right_limit,
+                                  i == 0, is_temperature_interval)
 
     ax.legend()
-    ticks_labels = get_ticks_labels_for_temp_min_and_temp_max()
-    ax.set_xticks(temp_min)
+    ticks_labels = get_ticks_labels_for_interval(is_temperature_interval, is_shift_interval)
+    ax.set_xticks(left_limit)
     ax.set_xticklabels(ticks_labels)
-    ax.set_xlabel('Temperature interval')
+    # ax.set_xlabel('Interval')
     ax.set_ylabel('Nb of Maxima')
     ax2 = ax.twinx()
     legend_elements = [
@@ -75,19 +86,38 @@ def plot_nb_data():
     ]
     ax2.legend(handles=legend_elements, loc='upper center')
     ax2.set_yticks([])
-    plt.show()
+    # plt.show()
 
 
-def get_temp_min_and_temp_max():
-    temp_min = np.arange(0, 3, 1)
-    temp_max = temp_min + 2
-    return temp_min, temp_max
+def get_interval_limits(is_temperature_interval, is_shift_interval):
+    if is_temperature_interval:
+        temp_min = np.arange(0, 3, 1)
+        temp_max = temp_min + 2
+        left_limit, right_limit = temp_min, temp_max
+    else:
+        shift = 25
+        nb = 3
+        year_min = [2006 + shift * i for i in range(nb)]
+        year_max = [2050 + shift * i for i in range(nb)]
+        left_limit, right_limit = year_min, year_max
+    if not is_shift_interval:
+        max_interval_right = max(right_limit)
+        right_limit = [max_interval_right for _ in left_limit]
+    return left_limit, right_limit
+
+
+def get_ticks_labels_for_interval(is_temperature_interval, is_shift_interval):
+    left_limits, right_limits = get_interval_limits(is_temperature_interval, is_shift_interval)
+    ticks_labels = [' +${}^o\mathrm{C}$ and +${}^o\mathrm{C}$'.format(left_limit, right_limit, **{'C': '{C}'})
+                    if is_temperature_interval else '{} and {}'.format(left_limit, right_limit)
+                    for left_limit, right_limit in zip(left_limits, right_limits)]
+    prefix = 'Maxima occured between \n'
+    ticks_labels = [prefix + l for l in ticks_labels]
+    return ticks_labels
 
-def get_ticks_labels_for_temp_min_and_temp_max():
-    temp_min, temp_max = get_temp_min_and_temp_max()
-    return ['Maxima occured between \n' \
-            ' +${}^o\mathrm{C}$ and +${}^o\mathrm{C}$'.format(mi, ma, **{'C': '{C}'})
-     for mi, ma in zip(temp_min, temp_max)]
 
 if __name__ == '__main__':
-    plot_nb_data()
+    for shift_interval in [False, True]:
+        for temp_interval in [True, False]:
+            print("shift = {}, temp_inteval = {}".format(shift_interval, temp_interval))
+            plot_nb_data(is_temperature_interval=temp_interval, is_shift_interval=shift_interval)
diff --git a/extreme_trend/ensemble_fit/independent_ensemble_fit/one_fold_fit_merge.py b/extreme_trend/ensemble_fit/independent_ensemble_fit/one_fold_fit_merge.py
index b1230df4..4db04016 100644
--- a/extreme_trend/ensemble_fit/independent_ensemble_fit/one_fold_fit_merge.py
+++ b/extreme_trend/ensemble_fit/independent_ensemble_fit/one_fold_fit_merge.py
@@ -19,5 +19,14 @@ class OneFoldFitMerge(OneFoldFit):
     def get_moment(self, altitude, temporal_covariate, order=1):
         return self.merge_function([o.get_moment(altitude, temporal_covariate, order) for o in self.one_fold_fit_list])
 
-
-
+    def changes_of_moment(self, altitudes, order=1):
+        all_changes = [o.changes_of_moment(altitudes, order) for o in self.one_fold_fit_list]
+        merged_changes = list(self.merge_function(np.array(all_changes), axis=0))
+        assert len(all_changes[0]) == len(merged_changes)
+        return merged_changes
+
+    def relative_changes_of_moment(self, altitudes, order=1):
+        all_relative_changes = [o.relative_changes_of_moment(altitudes, order) for o in self.one_fold_fit_list]
+        merged_relative_changes = list(self.merge_function(np.array(all_relative_changes), axis=0))
+        assert len(all_relative_changes[0]) == len(merged_relative_changes)
+        return merged_relative_changes
diff --git a/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py b/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py
index 4b9993bf..222fbc79 100644
--- a/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py
+++ b/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py
@@ -2,8 +2,8 @@ from collections import OrderedDict
 import matplotlib.pyplot as plt
 from typing import List, Dict
 
-from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import get_temp_min_and_temp_max, \
-    temperature_minmax_to_year_minmax, get_ticks_labels_for_temp_min_and_temp_max
+from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import get_interval_limits, \
+    get_year_min_and_year_max, get_ticks_labels_for_interval
 from extreme_data.meteo_france_data.scm_models_data.utils import Season
 from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
     AbstractSpatioTemporalPolynomialModel
@@ -25,22 +25,28 @@ class VisualizerForSensivity(object):
                  display_only_model_that_pass_gof_test=False,
                  confidence_interval_based_on_delta_method=False,
                  remove_physically_implausible_models=False,
-                 merge_visualizer_str=IndependentEnsembleFit.Median_merge # if we choose the Mean merge, then it is almost impossible to obtain stationary trends
+                 merge_visualizer_str=IndependentEnsembleFit.Median_merge,  # if we choose the Mean merge, then it is almost impossible to obtain stationary trends
+                 is_temperature_interval=False,
+                 is_shift_interval=False,
                  ):
+        self.is_shift_interval = is_shift_interval
+        self.is_temperature_interval = is_temperature_interval
         self.merge_visualizer_str = merge_visualizer_str
         self.altitudes_list = altitudes_list
         self.massif_names = massif_names
-        self.temp_min, self.temp_max = get_temp_min_and_temp_max()
-        self.temp_min_to_temp_max = OrderedDict(zip(self.temp_min, self.temp_max))
-        self.temp_min_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble]
+        self.left_limits, self.right_limits = get_interval_limits(self.is_temperature_interval,
+                                                                  self.is_shift_interval)
+        self.left_limit_to_right_limit = OrderedDict(zip(self.left_limits, self.right_limits))
+        self.left_limit_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble]
 
-        for temp_min, temp_max in zip(self.temp_min, self.temp_max):
-            print("temp min and temp max", temp_min, temp_max)
-            # Build 
+        for left_limit, right_limit in zip(self.left_limits, self.right_limits):
+            print("Interval is", left_limit, right_limit)
+            # Build gcm_to_year_min_and_year_max
             gcm_to_year_min_and_year_max = {}
             gcm_list = list(set([g for g, r in gcm_rcm_couples]))
             for gcm in gcm_list:
-                year_min_and_year_max = temperature_minmax_to_year_minmax(gcm, scenario, temp_min, temp_max)
+                year_min_and_year_max = get_year_min_and_year_max(gcm, scenario, left_limit, right_limit,
+                                                                  self.is_temperature_interval)
                 if year_min_and_year_max[0] is not None:
                     gcm_to_year_min_and_year_max[gcm] = year_min_and_year_max
                 
@@ -56,7 +62,7 @@ class VisualizerForSensivity(object):
                 remove_physically_implausible_models=remove_physically_implausible_models,
                 gcm_to_year_min_and_year_max=gcm_to_year_min_and_year_max
             )
-            self.temp_min_to_visualizer[temp_min] = visualizer
+            self.left_limit_to_visualizer[left_limit] = visualizer
 
     def plot(self):
         # todo: before reactivating the subplot, i should ensure that we can modify the prefix
@@ -70,12 +76,12 @@ class VisualizerForSensivity(object):
         ax = plt.gca()
         for altitudes in self.altitudes_list:
             altitude_class = get_altitude_class_from_altitudes(altitudes)
-            self.temperature_interval_plot(ax, altitude_class)
+            self.interval_plot(ax, altitude_class)
 
-        ticks_labels = get_ticks_labels_for_temp_min_and_temp_max()
+        ticks_labels = get_ticks_labels_for_interval(self.is_temperature_interval, self.is_shift_interval)
         ax.set_ylabel('Percentages of massifs (\%)')
-        ax.set_xlabel('Range of temperatures used to compute the trends ')
-        ax.set_xticks(self.temp_min)
+        ax.set_xlabel('Interval used to compute the trends ')
+        ax.set_xticks(self.left_limits)
         ax.set_xticklabels(ticks_labels)
         ax.legend(prop={'size': 7}, loc='upper center', ncol=2)
         ax.set_ylim((0, 122))
@@ -87,7 +93,7 @@ class VisualizerForSensivity(object):
     @property
     def first_merge_visualizer(self):
         altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0])
-        visualizer_projection = list(self.temp_min_to_visualizer.values())[0]
+        visualizer_projection = list(self.left_limit_to_visualizer.values())[0]
         return self.get_merge_visualizer(altitude_class, visualizer_projection)
 
     def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble):
@@ -97,7 +103,7 @@ class VisualizerForSensivity(object):
         merge_visualizer.studies.study.gcm_rcm_couple = (self.merge_visualizer_str, "merge")
         return merge_visualizer
 
-    def temperature_interval_plot(self, ax, altitude_class):
+    def interval_plot(self, ax, altitude_class):
         linestyle = get_linestyle_for_altitude_class(altitude_class)
         increasing_key = 'increasing'
         decreasing_key = 'decreasing'
@@ -109,14 +115,15 @@ class VisualizerForSensivity(object):
             increasing_key: 'red',
             decreasing_key: 'blue'
         }
-        for v in self.temp_min_to_visualizer.values():
+        for v in self.left_limit_to_visualizer.values():
             merge_visualizer = self.get_merge_visualizer(altitude_class, v)
-            _, *trends = merge_visualizer.all_trends(self.massif_names, with_significance=False)
+            _, *trends = merge_visualizer.all_trends(self.massif_names, with_significance=False,
+                                                     with_relative_change=True)
             label_to_l[decreasing_key].append(trends[0])
             label_to_l[increasing_key].append(trends[2])
         altitude_str = altitude_class().formula
         for label, l in label_to_l.items():
             label_improved = 'with {} trends {}'.format(label, altitude_str)
             color = label_to_color[label]
-            ax.plot(self.temp_min, l, label=label_improved, color=color, linestyle=linestyle)
+            ax.plot(self.left_limits, l, label=label_improved, color=color, linestyle=linestyle)
 
diff --git a/extreme_trend/one_fold_fit/altitudes_studies_visualizer_for_non_stationary_models.py b/extreme_trend/one_fold_fit/altitudes_studies_visualizer_for_non_stationary_models.py
index 995470b7..6263daed 100644
--- a/extreme_trend/one_fold_fit/altitudes_studies_visualizer_for_non_stationary_models.py
+++ b/extreme_trend/one_fold_fit/altitudes_studies_visualizer_for_non_stationary_models.py
@@ -455,7 +455,7 @@ class AltitudesStudiesVisualizerForNonStationaryModels(StudyVisualizer):
         self.studies.show_or_save_to_file(plot_name=plot_name, show=self.show)
         plt.close()
 
-    def all_trends(self, massif_names, with_significance=True):
+    def all_trends(self, massif_names, with_significance=True, with_relative_change=False):
         """return percents which contain decrease, significant decrease, increase, significant increase percentages"""
         valid_massif_names = self.get_valid_names(massif_names)
 
@@ -464,10 +464,14 @@ class AltitudesStudiesVisualizerForNonStationaryModels(StudyVisualizer):
         for one_fold in [one_fold for m, one_fold in self.massif_name_to_one_fold_fit.items()
                          if m in valid_massif_names]:
             # Compute nb of non stationary models
-            if one_fold.change_in_return_level_for_reference_altitude == 0:
+            if with_relative_change:
+                change_value = one_fold.relative_change_in_return_level_for_reference_altitude
+            else:
+                change_value = one_fold.change_in_return_level_for_reference_altitude
+            if change_value == 0:
                 continue
             # Compute nbs
-            idx = 0 if one_fold.change_in_return_level_for_reference_altitude < 0 else 2
+            idx = 0 if change_value < 0 else 2
             nbs[idx] += 1
             if with_significance and one_fold.is_significant:
                 nbs[idx + 1] += 1
diff --git a/projects/projected_extreme_snowfall/main_elevation_temporal_for_projections_ensemble.py b/projects/projected_extreme_snowfall/main_elevation_temporal_for_projections_ensemble.py
index 09f9bdb3..5e1eac74 100644
--- a/projects/projected_extreme_snowfall/main_elevation_temporal_for_projections_ensemble.py
+++ b/projects/projected_extreme_snowfall/main_elevation_temporal_for_projections_ensemble.py
@@ -38,11 +38,12 @@ def main():
     start = time.time()
     study_class = AdamontSnowfall
     ensemble_fit_class = [IndependentEnsembleFit]
-    temporal_covariate_for_fit = [TimeTemporalCovariate, AnomalyTemperatureWithSplineTemporalCovariate][1]
+    temporal_covariate_for_fit = [TimeTemporalCovariate,
+                                  AnomalyTemperatureWithSplineTemporalCovariate][0]
     set_seed_for_test()
     AbstractExtractEurocodeReturnLevel.ALPHA_CONFIDENCE_INTERVAL_UNCERTAINTY = 0.2
 
-    fast = False
+    fast = True
     sensitivity_plot = True
     scenarios = rcp_scenarios[::-1] if fast is False else [AdamontScenario.rcp85]
 
@@ -57,7 +58,7 @@ def main():
             massif_names = ['Vanoise', 'Haute-Maurienne']
             gcm_rcm_couples = gcm_rcm_couples[4:6]
             AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
-            altitudes_list = altitudes_for_groups[:]
+            altitudes_list = altitudes_for_groups[:1]
         else:
             massif_names = None
             altitudes_list = altitudes_for_groups[:]
diff --git a/test/test_extreme_trend/test_one_fold_fit.py b/test/test_extreme_trend/test_one_fold_fit.py
index 31da639f..30829f58 100644
--- a/test/test_extreme_trend/test_one_fold_fit.py
+++ b/test/test_extreme_trend/test_one_fold_fit.py
@@ -2,7 +2,7 @@ import unittest
 
 from extreme_data.meteo_france_data.adamont_data.adamont.adamont_safran import AdamontSnowfall
 from extreme_data.meteo_france_data.adamont_data.adamont_scenario import AdamontScenario
-from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import temperature_minmax_to_year_minmax
+from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import get_year_min_and_year_max
 from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
 from extreme_fit.model.margin_model.polynomial_margin_model.gev_altitudinal_models import StationaryAltitudinal
 from extreme_fit.model.margin_model.polynomial_margin_model.models_based_on_pariwise_analysis.gev_with_constant_shape_wrt_altitude import \
@@ -84,8 +84,8 @@ class TestOneFoldFit(unittest.TestCase):
         self.altitudes = [3000, 3300, 3600]
         gcm_rcm_couple = ('HadGEM2-ES', 'RegCM4-6')
         scenario = AdamontScenario.rcp85
-        year_min, year_max = temperature_minmax_to_year_minmax(gcm_rcm_couple[0], scenario, temperature_min=1.0,
-                                                               temperature_max=2.5)
+        year_min, year_max = get_year_min_and_year_max(gcm_rcm_couple[0], scenario, left_limit=1.0,
+                                                       right_limit=2.5, is_temperature_interval=True)
         dataset = self.load_dataset(AdamontSnowfall,
                                     scenario=scenario, gcm_rcm_couple=gcm_rcm_couple,
                                     year_min=year_min, year_max=year_max)
-- 
GitLab