From e315dceae6a248eac8f0893903b48a5a0715ce47 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 24 Feb 2021 13:05:38 +0100
Subject: [PATCH] [projections] add visualizer_for_sensitivity.py based on
 visualizer_for_projection_ensemble.py. add temperature_to_year.py file. Fix
 max abs recurrent issue (when it was equal to zero). add one test for one
 fold fit, that was that due to a scale parameter that was close to zero, and
 was causing instabilities in the nllh value. This is fixed by changing the
 check for implausible in one fold fit, and by saying that if the scale
 parameter is close to zero then the parameter is undefined.

---
 .../adamont_data/cmip5/temperature_to_year.py |  88 ++++++++++++++
 .../distribution/abstract_extreme_params.py   |   7 +-
 .../abstract_margin_estimator.py              |   6 +-
 .../one_fold_analysis/altitude_group.py       |   8 ++
 .../one_fold_analysis/one_fold_fit.py         |  25 +++-
 .../utils_altitude_studies_visualizer.py      |   3 +
 ...ation_temporal_for_projections_ensemble.py |  55 +++++----
 .../visualizer_for_projection_ensemble.py     |  36 ++++--
 .../visualizer_for_sensitivity.py             | 115 ++++++++++++++++++
 .../test_estimator/test_full_estimators.py    |   2 +
 .../test_one_fold_fit.py                      |  20 +++
 11 files changed, 328 insertions(+), 37 deletions(-)
 create mode 100644 extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
 create mode 100644 projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py

diff --git a/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py b/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
new file mode 100644
index 00000000..c911fa5e
--- /dev/null
+++ b/extreme_data/meteo_france_data/adamont_data/cmip5/temperature_to_year.py
@@ -0,0 +1,88 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.lines import Line2D
+
+from extreme_data.meteo_france_data.adamont_data.adamont_gcm_rcm_couples import gcm_to_color
+from extreme_data.meteo_france_data.adamont_data.adamont_scenario import AdamontScenario, rcp_scenarios, get_gcm_list, \
+    get_linestyle_from_scenario, scenario_to_str, adamont_scenarios_real
+from extreme_data.meteo_france_data.adamont_data.cmip5.climate_explorer_cimp5 import years_and_global_mean_temps
+
+
+def temperature_minmax_to_year_minmax(gcm, scenario, temperature_min, temperature_max):
+    years, global_mean_temps = years_and_global_mean_temps(gcm, scenario, year_min=2005, year_max=2100,
+                                                           rolling=30, anomaly=True)
+    years, global_mean_temps = np.array(years), np.array(global_mean_temps)
+    ind = temperature_min < global_mean_temps
+    ind &= global_mean_temps < temperature_max
+    years_to_select = years[ind]
+    ind2 = years_to_select[:-1] == years_to_select[1:] - 1
+    if not all(ind2):
+        i = list(ind2).index(False)
+        years_to_select = years_to_select[:i + 1]
+    # A minimum of 30 years of data is needed to find a trend
+    if len(years_to_select) >= 30:
+        year_min, year_max = years_to_select[0], years_to_select[-1]
+        assert (year_max - year_min + 1) == len(years_to_select)
+        return year_min, year_max
+    else:
+        return None, None
+
+
+def get_nb_data(gcm, scenario, temperature_min, temperature_max):
+    year_min, year_max = temperature_minmax_to_year_minmax(gcm, scenario, temperature_min, temperature_max)
+    if year_min is None:
+        return 0
+    else:
+        return year_max - year_min + 1
+
+
+def plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max):
+
+    nb_data = [get_nb_data(gcm, scenario, mi, ma) for mi, ma in zip(temp_min, temp_max)]
+    color = gcm_to_color[gcm]
+    linestyle = get_linestyle_from_scenario(scenario)
+
+    # Filter out the zero value
+    nb_data, temp_min = np.array(nb_data), np.array(temp_min)
+    ind = np.array(nb_data) > 0
+    nb_data, temp_min = nb_data[ind], temp_min[ind]
+
+    # For the legend
+    if scenario is AdamontScenario.rcp26:
+        ax.plot(temp_min[0], nb_data[0], color=color, linestyle='solid', label=gcm)
+
+    ax.plot(temp_min, nb_data, linestyle=linestyle, color=color)
+
+
+def plot_nb_data():
+    temp_max, temp_min = get_temp_min_and_temp_max()
+
+    ax = plt.gca()
+    for gcm in get_gcm_list(adamont_version=2)[:]:
+        for scenario in rcp_scenarios[:]:
+            plot_nb_data_one_line(ax, gcm, scenario, temp_min, temp_max)
+
+    ax.legend()
+    ticks_labels = ['{}-{}'.format(mi, ma) for mi, ma in zip(temp_min, temp_max)]
+    ax.set_xticks(temp_min)
+    ax.set_xticklabels(ticks_labels)
+    ax.set_xlabel('Temperature interval')
+    ax.set_ylabel('Nb of Data')
+    ax2 = ax.twinx()
+    legend_elements = [
+        Line2D([0], [0], color='k', lw=1, label=scenario_to_str(s),
+               linestyle=get_linestyle_from_scenario(s)) for s in adamont_scenarios_real
+    ]
+    ax2.legend(handles=legend_elements, loc='upper center')
+    ax2.set_yticks([])
+    plt.show()
+
+
+def get_temp_min_and_temp_max():
+    temp_min = np.arange(0.5, 2.5, 0.5)
+    temp_max = temp_min + 1.5
+    return temp_min, temp_max
+
+
+if __name__ == '__main__':
+    plot_nb_data()
diff --git a/extreme_fit/distribution/abstract_extreme_params.py b/extreme_fit/distribution/abstract_extreme_params.py
index 8c98f1b9..e892342d 100644
--- a/extreme_fit/distribution/abstract_extreme_params.py
+++ b/extreme_fit/distribution/abstract_extreme_params.py
@@ -1,7 +1,10 @@
+import numpy as np
+
 from extreme_fit.distribution.abstract_params import AbstractParams
 
 
 class AbstractExtremeParams(AbstractParams):
+    SMALL_SCALE_PARAMETERS_ARE_UNDEFINED = True
 
     def __init__(self, loc: float, scale: float, shape: float):
         self.location = loc
@@ -11,4 +14,6 @@ class AbstractExtremeParams(AbstractParams):
         # (sometimes it happens, when we want to find a quantile for every point of a 2D map
         # then it can happen that a corner point that was not used for fitting correspond to a negative scale,
         # in the case we set all the parameters as equal to np.nan, and we will not display those points)
-        self.has_undefined_parameters = self.scale <= 0
+        self.has_undefined_parameters = (self.scale <= 0)
+        if self.SMALL_SCALE_PARAMETERS_ARE_UNDEFINED:
+            self.has_undefined_parameters = self.has_undefined_parameters or (np.isclose(self.scale, 0))
\ No newline at end of file
diff --git a/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
index f8afbdcd..6c71b8c5 100644
--- a/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
+++ b/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
@@ -61,11 +61,13 @@ class LinearMarginEstimator(AbstractMarginEstimator):
     def function_from_fit(self) -> LinearMarginFunction:
         return load_margin_function(self, self.margin_model)
 
+    def coordinates_for_nllh(self, split=Split.all):
+        return pd.concat([self.df_coordinates_spat(split=split), self.df_coordinates_temp(split=split)], axis=1).values
+
     def nllh(self, split=Split.all):
         nllh = 0
         maxima_values = self.dataset.maxima_gev(split=split)
-        df = pd.concat([self.df_coordinates_spat(split=split), self.df_coordinates_temp(split=split)], axis=1)
-        coordinate_values = df.values
+        coordinate_values = self.coordinates_for_nllh(split=split)
         for maximum, coordinate in zip(maxima_values, coordinate_values):
             assert len(maximum) == 1, \
                 'So far, only one observation for each coordinate, but code would be easy to change'
diff --git a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py
index 30d0eb09..5a3f8a16 100644
--- a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py
+++ b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitude_group.py
@@ -17,6 +17,10 @@ class AbstractAltitudeGroup(object):
     def name(self):
         raise NotImplementedError
 
+    @property
+    def altitudes(self):
+        return altitudes_for_groups[self.group_id-1]
+
     @property
     def reference_altitude(self):
         raise NotImplementedError
@@ -141,6 +145,10 @@ class DefaultAltitudeGroup(AbstractAltitudeGroup):
         return 500
 
 
+def get_altitude_class_from_altitudes(altitudes):
+    return type(get_altitude_group_from_altitudes(altitudes))
+
+
 def get_altitude_group_from_altitudes(altitudes):
     s = set(altitudes)
     if s == set(altitudes_for_groups[0]):
diff --git a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py
index 65f900e2..a0451c44 100644
--- a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py
+++ b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py
@@ -180,12 +180,31 @@ class OneFoldFit(object):
             # Remove wrong shape
             estimators = [e for e in estimators if -0.5 < self._compute_shape_for_reference_altitude(e) < 0.5]
             # Remove models with undefined parameters for the coordinate of interest
-            coordinate = np.array([self.altitude_group.reference_altitude, self.last_year])
-            estimators = [e for e in estimators if not e.function_from_fit.get_params(coordinate).has_undefined_parameters]
+            well_defined_estimators = []
+            for e in estimators:
+                coordinate_values_for_the_fit = e.coordinates_for_nllh(Split.all)
+                coordinate_values_for_the_result = [np.array([self.altitude_group.reference_altitude, c])
+                                                             for c in self._covariate_before_and_after]
+                coordinate_values_to_check = list(coordinate_values_for_the_fit) + coordinate_values_for_the_result
+                has_undefined_parameters = False
+                for coordinate in coordinate_values_to_check:
+                    gev_params = e.function_from_fit.get_params(coordinate)
+                    if gev_params.has_undefined_parameters:
+                        has_undefined_parameters = True
+                        break
+                if not has_undefined_parameters:
+                    well_defined_estimators.append(e)
+            estimators = well_defined_estimators
+
             if len(estimators) == 0:
                 print(self.massif_name, " has only implausible models")
 
-        sorted_estimators = sorted([estimator for estimator in estimators], key=lambda e: e.aic())
+        try:
+            sorted_estimators = sorted([estimator for estimator in estimators], key=lambda e: e.aic())
+        except AssertionError as e:
+            print('Error for')
+            print(self.massif_name, self.altitude_group)
+            raise
         return sorted_estimators
 
     def _compute_shape_for_reference_altitude(self, estimator):
diff --git a/projects/altitude_spatial_model/altitudes_fit/utils_altitude_studies_visualizer.py b/projects/altitude_spatial_model/altitudes_fit/utils_altitude_studies_visualizer.py
index e0a556db..48463452 100644
--- a/projects/altitude_spatial_model/altitudes_fit/utils_altitude_studies_visualizer.py
+++ b/projects/altitude_spatial_model/altitudes_fit/utils_altitude_studies_visualizer.py
@@ -34,6 +34,9 @@ def compute_and_assign_max_abs(visualizer_list):
             max_abs = max([
                 max([abs(e) for e in v.method_name_and_order_to_d(method_name, order).values()
                      ]) for v in visualizer_list])
+            if max_abs == 0:
+                epsilon = 0.1
+                max_abs = epsilon
             method_name_and_order_to_max_abs[c] = max_abs
     # Assign the max abs dictionary
     for v in visualizer_list:
diff --git a/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py b/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py
index 268b38fc..23f51598 100644
--- a/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py
+++ b/projects/projected_snowfall/elevation_temporal_model_for_projections/main_elevation_temporal_for_projections_ensemble.py
@@ -4,13 +4,16 @@ from typing import List
 
 import matplotlib as mpl
 
+from projects.projected_snowfall.elevation_temporal_model_for_projections.visualizer_for_sensitivity import \
+    VisualizerForSensivity
+
 mpl.rcParams['text.usetex'] = True
 mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}']
 
 from extreme_fit.model.margin_model.polynomial_margin_model.utils import \
     ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS
 from projects.projected_snowfall.elevation_temporal_model_for_projections.visualizer_for_projection_ensemble import \
-    MetaVisualizerForProjectionEnsemble
+    VisualizerForProjectionEnsemble
 import matplotlib
 from extreme_fit.model.utils import set_seed_for_test
 
@@ -34,62 +37,72 @@ from extreme_data.meteo_france_data.scm_models_data.utils import Season
 
 def main():
     start = time.time()
-    study_classes = [AdamontSnowfall][:1]
+    study_class = AdamontSnowfall
     ensemble_fit_class = [IndependentEnsembleFit]
     temporal_covariate_for_fit = [TimeTemporalCovariate, AnomalyTemperatureTemporalCovariate][1]
     set_seed_for_test()
     AbstractExtractEurocodeReturnLevel.ALPHA_CONFIDENCE_INTERVAL_UNCERTAINTY = 0.2
 
     fast = None
-    scenarios = rcp_scenarios[:1] if fast is False else [AdamontScenario.rcp26]
+    sensitivity_plot = True
+    scenarios = rcp_scenarios if fast is False else [AdamontScenario.rcp85]
 
     for scenario in scenarios:
         gcm_rcm_couples = get_gcm_rcm_couples(scenario)
         if fast is None:
             massif_names = None
-            gcm_rcm_couples = gcm_rcm_couples[:]
+            gcm_rcm_couples = None
             AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
-            altitudes_list = altitudes_for_groups[:1]
+            altitudes_list = altitudes_for_groups[3:]
         elif fast:
+            massif_names = ['Vanoise', 'Haute-Maurienne']
+            gcm_rcm_couples = gcm_rcm_couples[4:6]
             AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
-            massif_names = None
-            gcm_rcm_couples = [('EC-EARTH', 'RACMO22E')]
-            altitudes_list = altitudes_for_groups[1:2]
+            altitudes_list = altitudes_for_groups[2:]
         else:
             massif_names = None
             altitudes_list = altitudes_for_groups[:]
 
         assert isinstance(gcm_rcm_couples, list)
 
-        main_loop(gcm_rcm_couples, altitudes_list, massif_names, study_classes, ensemble_fit_class, scenario,
-                  temporal_covariate_for_fit)
+        main_loop(gcm_rcm_couples, altitudes_list, massif_names, study_class, ensemble_fit_class, scenario,
+                  temporal_covariate_for_fit, sensitivity_plot=sensitivity_plot)
 
     end = time.time()
     duration = str(datetime.timedelta(seconds=end - start))
     print('Total duration', duration)
 
 
-def main_loop(gcm_rcm_couples, altitudes_list, massif_names, study_classes, ensemble_fit_classes, scenario,
-              temporal_covariate_for_fit):
+def main_loop(gcm_rcm_couples, altitudes_list, massif_names, study_class, ensemble_fit_classes, scenario,
+              temporal_covariate_for_fit, sensitivity_plot=False):
     assert isinstance(altitudes_list, List)
     assert isinstance(altitudes_list[0], List)
     print('Covariate is {}'.format(temporal_covariate_for_fit))
-    for study_class in study_classes:
-        print('Inner loop', study_class)
-        model_classes = ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS
-        assert scenario in rcp_scenarios
 
-        visualizer = MetaVisualizerForProjectionEnsemble(
+    model_classes = ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS
+    assert scenario in rcp_scenarios
+    remove_physically_implausible_models = True
+
+    if sensitivity_plot:
+        visualizer = VisualizerForSensivity(
+            altitudes_list, gcm_rcm_couples, study_class, Season.annual, scenario,
+            model_classes=model_classes,
+            ensemble_fit_classes=ensemble_fit_classes,
+            massif_names=massif_names,
+            temporal_covariate_for_fit=temporal_covariate_for_fit,
+            remove_physically_implausible_models=remove_physically_implausible_models,
+        )
+    else:
+        visualizer = VisualizerForProjectionEnsemble(
             altitudes_list, gcm_rcm_couples, study_class, Season.annual, scenario,
             model_classes=model_classes,
             ensemble_fit_classes=ensemble_fit_classes,
             massif_names=massif_names,
             temporal_covariate_for_fit=temporal_covariate_for_fit,
-            remove_physically_implausible_models=True,
+            remove_physically_implausible_models=remove_physically_implausible_models,
+            gcm_to_year_min_and_year_max=None,
         )
-        visualizer.plot()
-        del visualizer
-        time.sleep(2)
+    visualizer.plot()
 
 
 if __name__ == '__main__':
diff --git a/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_projection_ensemble.py b/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_projection_ensemble.py
index 9d51602b..e7d9e38e 100644
--- a/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_projection_ensemble.py
+++ b/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_projection_ensemble.py
@@ -5,7 +5,7 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
 from extreme_fit.model.margin_model.utils import MarginFitMethod
 from projects.altitude_spatial_model.altitudes_fit.altitudes_studies import AltitudesStudies
 from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.altitude_group import \
-    get_altitude_group_from_altitudes
+    get_altitude_group_from_altitudes, get_altitude_class_from_altitudes
 from projects.altitude_spatial_model.altitudes_fit.plots.plot_histogram_altitude_studies import \
     plot_histogram_all_trends_against_altitudes, plot_shoe_plot_changes_against_altitude
 from projects.altitude_spatial_model.altitudes_fit.utils_altitude_studies_visualizer import compute_and_assign_max_abs
@@ -13,7 +13,7 @@ from projects.projected_snowfall.elevation_temporal_model_for_projections.indepe
     IndependentEnsembleFit
 
 
-class MetaVisualizerForProjectionEnsemble(object):
+class VisualizerForProjectionEnsemble(object):
 
     def __init__(self, altitudes_list, gcm_rcm_couples, study_class, season, scenario,
                  model_classes: List[AbstractSpatioTemporalPolynomialModel],
@@ -24,25 +24,39 @@ class MetaVisualizerForProjectionEnsemble(object):
                  display_only_model_that_pass_gof_test=False,
                  confidence_interval_based_on_delta_method=False,
                  remove_physically_implausible_models=False,
+                 gcm_to_year_min_and_year_max=None,
                  ):
         self.gcm_rcm_couples = gcm_rcm_couples
         self.massif_names = massif_names
         self.ensemble_fit_classes = ensemble_fit_classes
 
         # Load all studies
-        altitude_group_to_gcm_couple_to_studies = {}
+        altitude_class_to_gcm_couple_to_studies = {}
         for altitudes in altitudes_list:
-            altitude_group = get_altitude_group_from_altitudes(altitudes)
+            altitude_class = get_altitude_class_from_altitudes(altitudes)
             gcm_rcm_couple_to_studies = {}
             for gcm_rcm_couple in gcm_rcm_couples:
+                if gcm_to_year_min_and_year_max is None:
+                    kwargs_study = {}
+                else:
+                    gcm = gcm_rcm_couple[0]
+                    if gcm not in gcm_to_year_min_and_year_max:
+                        # It means that for this gcm and scenario,
+                        # there is not enough data (less than 30 years) for the fit
+                        continue
+                    year_min, year_max = gcm_to_year_min_and_year_max[gcm]
+                    kwargs_study = {'year_min': year_min, 'year_max': year_max}
                 studies = AltitudesStudies(study_class, altitudes, season=season,
-                                           scenario=scenario, gcm_rcm_couple=gcm_rcm_couple)
+                                           scenario=scenario, gcm_rcm_couple=gcm_rcm_couple,
+                                           **kwargs_study)
                 gcm_rcm_couple_to_studies[gcm_rcm_couple] = studies
-            altitude_group_to_gcm_couple_to_studies[altitude_group] = gcm_rcm_couple_to_studies
+            if len(gcm_rcm_couple_to_studies) == 0:
+                print('No valid studies for the following couples:', self.gcm_rcm_couples)
+            altitude_class_to_gcm_couple_to_studies[altitude_class] = gcm_rcm_couple_to_studies
 
         # Load ensemble fit
-        self.altitude_group_to_ensemble_class_to_ensemble_fit = {}
-        for altitude_group, gcm_rcm_couple_to_studies in altitude_group_to_gcm_couple_to_studies.items():
+        self.altitude_class_to_ensemble_class_to_ensemble_fit = {}
+        for altitude_class, gcm_rcm_couple_to_studies in altitude_class_to_gcm_couple_to_studies.items():
             ensemble_class_to_ensemble_fit = {}
             for ensemble_fit_class in ensemble_fit_classes:
                 ensemble_fit = ensemble_fit_class(massif_names, gcm_rcm_couple_to_studies, model_classes,
@@ -51,7 +65,7 @@ class MetaVisualizerForProjectionEnsemble(object):
                                                   confidence_interval_based_on_delta_method,
                                                   remove_physically_implausible_models)
                 ensemble_class_to_ensemble_fit[ensemble_fit_class] = ensemble_fit
-            self.altitude_group_to_ensemble_class_to_ensemble_fit[altitude_group] = ensemble_class_to_ensemble_fit
+            self.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class] = ensemble_class_to_ensemble_fit
 
     def plot(self):
         if IndependentEnsembleFit in self.ensemble_fit_classes:
@@ -74,6 +88,8 @@ class MetaVisualizerForProjectionEnsemble(object):
         # Aggregated at gcm_rcm_level plots
         merge_keys = [IndependentEnsembleFit.Median_merge, IndependentEnsembleFit.Mean_merge]
         keys = self.gcm_rcm_couples + merge_keys
+        # Only plot Mean for speed
+        keys = [IndependentEnsembleFit.Mean_merge]
         for key in keys:
             visualizer_list = [independent_ensemble_fit.gcm_rcm_couple_to_visualizer[key]
                                if key in self.gcm_rcm_couples
@@ -92,7 +108,7 @@ class MetaVisualizerForProjectionEnsemble(object):
     def ensemble_fits(self, ensemble_class):
         return [ensemble_class_to_ensemble_fit[ensemble_class]
                 for ensemble_class_to_ensemble_fit
-                in self.altitude_group_to_ensemble_class_to_ensemble_fit.values()]
+                in self.altitude_class_to_ensemble_class_to_ensemble_fit.values()]
 
     def plot_together(self):
         pass
diff --git a/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py b/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py
new file mode 100644
index 00000000..f3e40b69
--- /dev/null
+++ b/projects/projected_snowfall/elevation_temporal_model_for_projections/visualizer_for_sensitivity.py
@@ -0,0 +1,115 @@
+from collections import OrderedDict
+import matplotlib.pyplot as plt
+from typing import List, Dict
+
+from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import get_temp_min_and_temp_max, \
+    temperature_minmax_to_year_minmax
+from extreme_data.meteo_france_data.scm_models_data.utils import Season
+from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
+    AbstractSpatioTemporalPolynomialModel
+from extreme_fit.model.margin_model.utils import MarginFitMethod
+from projects.altitude_spatial_model.altitudes_fit.altitudes_studies import AltitudesStudies
+from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.altitude_group import \
+    get_altitude_group_from_altitudes, get_altitude_class_from_altitudes
+from projects.altitude_spatial_model.altitudes_fit.plots.plot_histogram_altitude_studies import \
+    plot_histogram_all_trends_against_altitudes, plot_shoe_plot_changes_against_altitude
+from projects.altitude_spatial_model.altitudes_fit.utils_altitude_studies_visualizer import compute_and_assign_max_abs
+from projects.projected_snowfall.elevation_temporal_model_for_projections.independent_ensemble_fit.independent_ensemble_fit import \
+    IndependentEnsembleFit
+from projects.projected_snowfall.elevation_temporal_model_for_projections.visualizer_for_projection_ensemble import \
+    VisualizerForProjectionEnsemble
+
+
+class VisualizerForSensivity(object):
+
+    def __init__(self, altitudes_list, gcm_rcm_couples, study_class, season, scenario,
+                 model_classes: List[AbstractSpatioTemporalPolynomialModel],
+                 ensemble_fit_classes=None,
+                 massif_names=None,
+                 fit_method=MarginFitMethod.extremes_fevd_mle,
+                 temporal_covariate_for_fit=None,
+                 display_only_model_that_pass_gof_test=False,
+                 confidence_interval_based_on_delta_method=False,
+                 remove_physically_implausible_models=False,
+                 ):
+        self.altitudes_list = altitudes_list
+        self.massif_names = massif_names
+        self.temp_min, self.temp_max = get_temp_min_and_temp_max()
+        self.temp_min_to_temp_max = OrderedDict(zip(self.temp_min, self.temp_max))
+        self.temp_min_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble]
+
+        for temp_min, temp_max in zip(self.temp_min, self.temp_max):
+            print(temp_min, temp_max)
+            # Build 
+            gcm_to_year_min_and_year_max = {}
+            gcm_list = list(set([g for g, r in gcm_rcm_couples]))
+            for gcm in gcm_list:
+                year_min_and_year_max = temperature_minmax_to_year_minmax(gcm, scenario, temp_min, temp_max)
+                if year_min_and_year_max[0] is not None:
+                    gcm_to_year_min_and_year_max[gcm] = year_min_and_year_max
+                
+            visualizer = VisualizerForProjectionEnsemble(
+                altitudes_list, gcm_rcm_couples, study_class, Season.annual, scenario,
+                model_classes=model_classes,
+                fit_method=fit_method,
+                ensemble_fit_classes=ensemble_fit_classes,
+                display_only_model_that_pass_gof_test=display_only_model_that_pass_gof_test,
+                confidence_interval_based_on_delta_method=confidence_interval_based_on_delta_method,
+                massif_names=massif_names,
+                temporal_covariate_for_fit=temporal_covariate_for_fit,
+                remove_physically_implausible_models=remove_physically_implausible_models,
+                gcm_to_year_min_and_year_max=gcm_to_year_min_and_year_max
+            )
+            self.temp_min_to_visualizer[temp_min] = visualizer
+
+    def plot(self):
+        # todo: before reactivating the subplot, i should ensure that we can modify the prefix
+        # so that we can have all the subplot, and not juste for the last t_min
+        # for visualizer in self.temp_min_to_visualizer.values():
+        #     visualizer.plot()
+        self.sensitivity_plot()
+
+    def sensitivity_plot(self):
+        ax = plt.gca()
+        for altitudes in self.altitudes_list:
+            altitude_class = get_altitude_class_from_altitudes(altitudes)
+            self.temperature_interval_plot(ax, altitude_class)
+
+        ticks_labels = ['+{} to +{}'.format(mi, ma) for mi, ma in zip(self.temp_min, self.temp_max)]
+        ax.set_ylabel('Percentages of massifs (%)')
+        ax.set_xlabel('Range of temperatures used to compute the trends ')
+        ax.set_xticks(self.temp_min)
+        ax.set_xticklabels(ticks_labels)
+        ax.legend()
+        mean_visualizer = self.first_mean_visualizer
+        mean_visualizer.plot_name = 'Sensitivity plot'
+        mean_visualizer.show_or_save_to_file(no_title=True)
+
+    @property
+    def first_mean_visualizer(self):
+        altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0])
+        visualizer_projection = list(self.temp_min_to_visualizer.values())[0]
+        return self.get_mean_visualizer(altitude_class, visualizer_projection)
+
+    def get_mean_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble):
+        independent_ensemble_fit = visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][
+            IndependentEnsembleFit]
+        mean_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[IndependentEnsembleFit.Mean_merge]
+        mean_visualizer.studies.study.gcm_rcm_couple = (IndependentEnsembleFit.Mean_merge, "merge")
+        return mean_visualizer
+
+    def temperature_interval_plot(self, ax, altitude_class):
+        label_to_l = {
+            'increasing': [],
+            'decreasing': []
+        }
+        for v in self.temp_min_to_visualizer.values():
+            mean_visualizer = self.get_mean_visualizer(altitude_class, v)
+            _, *trends = mean_visualizer.all_trends(self.massif_names, with_significance=False)
+            label_to_l['decreasing'].append(trends[0])
+            label_to_l['increasing'].append(trends[2])
+        altitude_str = altitude_class().formula
+        for label, l in label_to_l.items():
+            label_improved = 'with {} trends {}'.format(label, altitude_str)
+            ax.plot(self.temp_min, l, label=label_improved)
+
diff --git a/test/test_extreme_fit/test_estimator/test_full_estimators.py b/test/test_extreme_fit/test_estimator/test_full_estimators.py
index 9322f034..abbb240b 100644
--- a/test/test_extreme_fit/test_estimator/test_full_estimators.py
+++ b/test/test_extreme_fit/test_estimator/test_full_estimators.py
@@ -1,6 +1,7 @@
 import unittest
 from itertools import product
 
+from extreme_fit.distribution.abstract_extreme_params import AbstractExtremeParams
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
 from test.test_utils import load_test_max_stable_models, load_smooth_margin_models, load_test_1D_and_2D_spatial_coordinates, \
     load_test_full_estimators
@@ -17,6 +18,7 @@ class TestFullEstimators(unittest.TestCase):
         self.max_stable_models = load_test_max_stable_models()
 
     def test_full_estimators(self):
+        AbstractExtremeParams.SMALL_SCALE_PARAMETERS_ARE_UNDEFINED = False
         for coordinates in self.spatial_coordinates:
             smooth_margin_models = load_smooth_margin_models(coordinates=coordinates)
             for margin_model, max_stable_model in product(smooth_margin_models, self.max_stable_models):
diff --git a/test/test_projects/test_altitude_spatial/test_one_fold_fit.py b/test/test_projects/test_altitude_spatial/test_one_fold_fit.py
index 33daf73f..f9397686 100644
--- a/test/test_projects/test_altitude_spatial/test_one_fold_fit.py
+++ b/test/test_projects/test_altitude_spatial/test_one_fold_fit.py
@@ -2,6 +2,7 @@ import unittest
 
 from extreme_data.meteo_france_data.adamont_data.adamont.adamont_snowfall import AdamontSnowfall
 from extreme_data.meteo_france_data.adamont_data.adamont_scenario import AdamontScenario
+from extreme_data.meteo_france_data.adamont_data.cmip5.temperature_to_year import temperature_minmax_to_year_minmax
 from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
 from extreme_fit.model.margin_model.linear_margin_model.temporal_linear_margin_models import StationaryTemporalModel
 from extreme_fit.model.margin_model.polynomial_margin_model.gev_altitudinal_models import StationaryAltitudinal
@@ -11,6 +12,7 @@ from extreme_fit.model.margin_model.polynomial_margin_model.models_based_on_pari
 from extreme_fit.model.margin_model.polynomial_margin_model.utils import \
     ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS
 from projects.altitude_spatial_model.altitudes_fit.altitudes_studies import AltitudesStudies
+from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.altitude_group import VeyHighAltitudeGroup
 from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.one_fold_fit import OneFoldFit
 from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_covariate_for_fit import \
     TimeTemporalCovariate, AnomalyTemperatureTemporalCovariate
@@ -75,6 +77,24 @@ class TestOneFoldFit(unittest.TestCase):
                                   remove_physically_implausible_models=True)
         self.assertFalse(one_fold_fit.has_at_least_one_valid_model)
 
+    def test_assertion_error_for_a_specific_case(self):
+        self.massif_name = "Thabor"
+        self.model_classes = ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS[:]
+        self.altitudes = [3000, 3300, 3600]
+        gcm_rcm_couple = ('HadGEM2-ES', 'RegCM4-6')
+        scenario = AdamontScenario.rcp85
+        year_min, year_max = temperature_minmax_to_year_minmax(gcm_rcm_couple[0], scenario, temperature_min=1.0,
+                                                               temperature_max=2.5)
+        dataset = self.load_dataset(AdamontSnowfall,
+                                    scenario=scenario, gcm_rcm_couple=gcm_rcm_couple,
+                                    year_min=year_min, year_max=year_max)
+        one_fold_fit = OneFoldFit(self.massif_name, dataset,
+                                  models_classes=self.model_classes,
+                                  temporal_covariate_for_fit=AnomalyTemperatureTemporalCovariate,
+                                  altitude_class=VeyHighAltitudeGroup,
+                                  only_models_that_pass_goodness_of_fit_test=False,
+                                  remove_physically_implausible_models=True)
+        self.assertTrue(one_fold_fit.has_at_least_one_valid_model)
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab