From 2901ed6697b602cd13d13cbfd003b49c34a9f385 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 7 Jan 2021 15:44:58 +0100
Subject: [PATCH] [contrasting] decompose the coherence plot into 4 subplots.

---
 .../altitudes_fit/main_altitudes_studies.py   |  10 +-
 ...es_visualizer_for_non_stationary_models.py |   2 +-
 .../one_fold_analysis/one_fold_fit.py         |   7 +-
 .../plots/plot_coherence_curves.py            | 130 +++++++++---------
 4 files changed, 77 insertions(+), 72 deletions(-)

diff --git a/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py b/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py
index 0532fa44..86732094 100644
--- a/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py
+++ b/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py
@@ -26,7 +26,7 @@ def main():
     study_classes = [SafranSnowfall1Day, SafranSnowfall3Days, SafranSnowfall5Days, SafranSnowfall7Days][:1]
     seasons = [Season.annual, Season.winter, Season.spring, Season.automn][:1]
 
-    fast = False
+    fast = True
     if fast is None:
         massif_names = None
         altitudes_list = altitudes_for_groups[1:2]
@@ -55,9 +55,9 @@ def main_loop(altitudes_list, massif_names, seasons, study_classes):
 
 
 def plot_visualizers(massif_names, visualizer_list):
-    plot_histogram_all_trends_against_altitudes(massif_names, visualizer_list)
-    for relative in [True, False]:
-        plot_shoe_plot_changes_against_altitude(massif_names, visualizer_list, relative=relative)
+    # plot_histogram_all_trends_against_altitudes(massif_names, visualizer_list)
+    # for relative in [True, False]:
+        # plot_shoe_plot_changes_against_altitude(massif_names, visualizer_list, relative=relative)
         # plot_shoe_plot_changes_against_altitude_for_maxima_and_total(massif_names, visualizer_list, relative=relative)
     # plot_coherence_curves(massif_names, visualizer_list)
     plot_coherence_curves(['Vanoise'], visualizer_list)
@@ -69,7 +69,7 @@ def plot_visualizer(massif_names, visualizer):
     # visualizer.studies.plot_maxima_time_series(['Vanoise'])
 
     # Plot the results for the model that minimizes the individual aic
-    plot_individual_aic(visualizer)
+    # plot_individual_aic(visualizer)
 
 
     # Plot the results for the model that minimizes the total aic
diff --git a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitudes_studies_visualizer_for_non_stationary_models.py b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitudes_studies_visualizer_for_non_stationary_models.py
index dc2a2beb..8b7f57f6 100644
--- a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitudes_studies_visualizer_for_non_stationary_models.py
+++ b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/altitudes_studies_visualizer_for_non_stationary_models.py
@@ -148,7 +148,7 @@ class AltitudesStudiesVisualizerForNonStationaryModels(StudyVisualizer):
         massif_name_to_text = self.massif_name_to_best_name
         if 'change' in method_name:
             plot_name = plot_name.replace(str_for_last_year, '')
-            plot_name += ' between {} and {}'.format(2019 - 50, 2019)
+            plot_name += ' between {} and {}'.format(2019 - OneFoldFit.nb_years, 2019)
             if 'relative' not in method_name:
                 # Put the relative score as text on the plot for the change.
                 massif_name_to_text = {m: ('+' if v > 0 else '') + str(int(v)) + '\%' for m, v in
diff --git a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py
index 5f5880dd..35c892f1 100644
--- a/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py
+++ b/projects/altitude_spatial_model/altitudes_fit/one_fold_analysis/one_fold_fit.py
@@ -33,6 +33,7 @@ class OneFoldFit(object):
     best_estimator_minimizes_total_aic = False
     return_period = 100
     quantile_level = 1 - (1 / return_period)
+    nb_years = 60
 
     def __init__(self, massif_name: str, dataset: AbstractDataset, models_classes,
                  fit_method=MarginFitMethod.extremes_fevd_mle, temporal_covariate_for_fit=None,
@@ -100,7 +101,7 @@ class OneFoldFit(object):
     def relative_change_in_return_level_for_reference_altitude(self) -> float:
         return self.relative_changes_of_moment(altitudes=[self.altitude_plot], order=None)[0]
 
-    def changes_of_moment(self, altitudes, year=2019, nb_years=60, order=1):
+    def changes_of_moment(self, altitudes, year=2019, nb_years=nb_years, order=1):
         changes = []
         for altitude in altitudes:
             mean_after = self.get_moment(altitude, year, order)
@@ -109,7 +110,7 @@ class OneFoldFit(object):
             changes.append(change)
         return changes
 
-    def relative_changes_of_moment(self, altitudes, year=2019, nb_years=60, order=1):
+    def relative_changes_of_moment(self, altitudes, year=2019, nb_years=nb_years, order=1):
         relative_changes = []
         for altitude in altitudes:
             mean_after = self.get_moment(altitude, year, order)
@@ -293,7 +294,7 @@ class OneFoldFit(object):
 
     def sign_of_change(self, estimator):
         return_levels = []
-        for year in [2019 - 60, 2019]:
+        for year in [2019 - self.nb_years, 2019]:
             coordinate = np.array([self.altitude_plot, year])
             return_level = estimator.function_from_fit.get_params(
                 coordinate=coordinate,
diff --git a/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py b/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py
index 87cdf9ae..91135f7a 100644
--- a/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py
+++ b/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py
@@ -9,86 +9,90 @@ from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.one_fold_fi
 
 def plot_coherence_curves(massif_names, visualizer_list: List[AltitudesStudiesVisualizerForNonStationaryModels]):
     folder = 'Coherence'
+    elevation_as_xaxis = True
     visualizer = visualizer_list[0]
     all_valid_names = set.union(*[v.get_valid_names(massif_names) for v in visualizer_list])
     for massif_name in all_valid_names:
 
         # For plotting the legend
         legend = False
-        if legend:
-            ax = plt.gca()
-            axes = [ax for _ in range(4)]
-        else:
-            _, axes = plt.subplots(2, 2)
-            axes = axes.flatten()
-        for i, ax in enumerate(axes):
-            if i % 2 == 1:
-                ax.set_yticks([])
-        axes = [ax if i % 2 == 0 else ax.twinx() for i, ax in enumerate(axes)]
         colors = ['blue', 'red', 'green']
         elevational_str = 'Piecewise elevational-temporal models in'
         labels = ['{} 1959'.format(elevational_str), '{} 2019'.format(elevational_str), 'Pointwise distributions']
         altitudinal_model = [True, True, False]
         years = [1959, 2019, None]
-        for color, global_label, boolean, year in list(zip(colors, labels, altitudinal_model, years))[:]:
-            plot_coherence_curve(axes, massif_name, visualizer_list, boolean, color, global_label, year, legend)
-        visualizer.plot_name = '{}/{}'.format(folder, massif_name.replace('_', '-'))
-        visualizer.show_or_save_to_file(add_classic_title=False, no_title=True, dpi=200)
-        plt.close()
+        for i in range(4):
+            # Load ax
+            ax = plt.gca()
+            if i % 2 == 0:
+                ax.set_yticks([])
+            ax = ax.twinx() if i % 2 == 0 else ax
+
 
+            for color, global_label, is_altitudinal, year in list(zip(colors, labels, altitudinal_model, years))[:]:
+                x_all_list, values_all_list, labels, all_bound_list = load_all_list(massif_name, visualizer_list,
+                                                                                    is_altitudinal,
+                                                                                    year)
+                label = labels[i]
+                plot_coherence_curve(ax, i, x_all_list, values_all_list, label, all_bound_list ,
+                                     is_altitudinal, color, global_label, year, legend,
+                                     elevation_as_xaxis)
+            visualizer.plot_name = '{}/{}_{}'.format(folder, massif_name.replace('_', '-'), label)
+            visualizer.show_or_save_to_file(add_classic_title=False, no_title=True, dpi=200)
+            plt.close()
 
-def plot_coherence_curve(axes, massif_name, visualizer_list: List[AltitudesStudiesVisualizerForNonStationaryModels],
-                         is_altitudinal, color, global_label, year, legend):
-    x_all_list, values_all_list, labels, all_bound_list = load_all_list(massif_name, visualizer_list, is_altitudinal,
-                                                                        year)
+
+def plot_coherence_curve(ax, i, x_all_list, values_all_list, label, all_bound_list,
+                         is_altitudinal, color, global_label, year, legend, elevation_as_xaxis):
 
     legend_line = True
-    for i, label in enumerate(labels):
-        if legend and i != 3:
-            continue
-        ax = axes[i]
-        # Plot with complete line
-        for j, (x_list, value_list) in enumerate(list(zip(x_all_list, values_all_list))):
-            value_list_i = value_list[i]
-            label_plot = global_label if j == 0 else None
-            if is_altitudinal:
-                if legend and legend_line:
-                    ax.plot(x_list, value_list_i, linestyle='solid', color=color, label=label_plot, linewidth=5)
-                else:
-                    ax.plot(x_list, value_list_i, linestyle='solid', color=color)
+    if legend and i != 3:
+        return
+    # Plot with complete line
+    for j, (x_list, value_list) in enumerate(list(zip(x_all_list, values_all_list))):
+        value_list_i = value_list[i]
+        label_plot = global_label if j == 0 else None
+        args = [x_list, value_list_i] if elevation_as_xaxis else [value_list_i, x_list]
+
+        if is_altitudinal:
+            if legend and legend_line:
+                ax.plot(*args, linestyle='solid', color=color, label=label_plot, linewidth=5)
             else:
-                if legend and legend_line:
-                    ax.plot(x_list, value_list_i, linestyle='None', color=color, label=label_plot, marker='o', markersize=10)
+                ax.plot(*args, linestyle='solid', color=color)
+        else:
+            if legend and legend_line:
+                ax.plot(*args, linestyle='None', color=color, label=label_plot, marker='o', markersize=10)
+            else:
+                ax.plot(*args, linestyle='None', color=color, marker='o')
+                ax.plot(*args, linestyle='dotted', color=color)
+
+    # Plot with dotted line
+    for x_list_before, value_list_before, x_list_after, value_list_after in zip(x_all_list, values_all_list,
+                                                                                x_all_list[1:],
+                                                                                values_all_list[1:]):
+        x_list = [x_list_before[-1], x_list_after[0]]
+        value_list_dotted = [value_list_before[i][-1], value_list_after[i][0]]
+        args = [x_list, value_list_dotted] if elevation_as_xaxis else [value_list_dotted, x_list]
+        ax.plot(*args, linestyle='dotted', color=color)
+
+    # Plot confidence interval
+    if i == 3 and year in [None, 2019]:
+        for j, (x_list, bounds) in enumerate(list(zip(x_all_list, all_bound_list))):
+            if len(bounds) > 0:
+                lower_bound, upper_bound = bounds
+                if legend and not legend_line:
+                    model_name = 'piecewise elevational-temporal models in 2019' if is_altitudinal else 'pointwise distributions'
+                    fill_label = "95\% confidence interval for the {}".format(model_name) if j == 0 else None
+                    ax.fill_between(x_list, lower_bound, upper_bound, color=color, alpha=0.2, label=fill_label)
                 else:
-                    ax.plot(x_list, value_list_i, linestyle='None', color=color, marker='o')
-                    ax.plot(x_list, value_list_i, linestyle='dotted', color=color)
-
-        # Plot with dotted line
-        for x_list_before, value_list_before, x_list_after, value_list_after in zip(x_all_list, values_all_list,
-                                                                                    x_all_list[1:],
-                                                                                    values_all_list[1:]):
-            x_list = [x_list_before[-1], x_list_after[0]]
-            value_list_dotted = [value_list_before[i][-1], value_list_after[i][0]]
-            ax.plot(x_list, value_list_dotted, linestyle='dotted', color=color)
-
-        # Plot confidence interval
-        if i == 3 and year in [None, 2019]:
-            for j, (x_list, bounds) in enumerate(list(zip(x_all_list, all_bound_list))):
-                if len(bounds) > 0:
-                    lower_bound, upper_bound = bounds
-                    if legend and not legend_line:
-                        model_name = 'piecewise elevational-temporal models in 2019' if is_altitudinal else 'pointwise distributions'
-                        fill_label = "95\% confidence interval for the {}".format(model_name) if j == 0 else None
-                        ax.fill_between(x_list, lower_bound, upper_bound, color=color, alpha=0.2, label=fill_label)
-                    else:
-                        ax.fill_between(x_list, lower_bound, upper_bound, color=color, alpha=0.2)
-
-            if legend:
-                min, max = ax.get_ylim()
-                ax.set_ylim([min, 2 * max])
-                size = 15 if legend_line else 11
-                ax.legend(prop={'size': size})
-        ax.set_ylabel(label)
+                    ax.fill_between(x_list, lower_bound, upper_bound, color=color, alpha=0.2)
+
+        if legend:
+            min, max = ax.get_ylim()
+            ax.set_ylim([min, 2 * max])
+            size = 15 if legend_line else 11
+            ax.legend(prop={'size': size})
+    ax.set_ylabel(label)
 
 
 def load_all_list(massif_name, visualizer_list, altitudinal_model=True, year=2019):
-- 
GitLab