From 61754e6efe5188def537fcf0fbc0056bddb48eca Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 7 Jan 2021 18:34:52 +0100
Subject: [PATCH] [contrasting] add coherence with elevation on the y axis

---
 .../altitudes_fit/main_altitudes_studies.py   |  4 +-
 .../plots/plot_coherence_curves.py            | 58 ++++++++++++-------
 2 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py b/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py
index 86732094..be76e1c3 100644
--- a/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py
+++ b/projects/altitude_spatial_model/altitudes_fit/main_altitudes_studies.py
@@ -31,8 +31,8 @@ def main():
         massif_names = None
         altitudes_list = altitudes_for_groups[1:2]
     elif fast:
-        massif_names = ['Vanoise', 'Haute-Maurienne', 'Vercors'][:]
-        altitudes_list = altitudes_for_groups[2:]
+        massif_names = ['Vanoise', 'Haute-Maurienne', 'Vercors'][:1]
+        altitudes_list = altitudes_for_groups[:]
     else:
         massif_names = None
         altitudes_list = altitudes_for_groups[:]
diff --git a/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py b/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py
index 91135f7a..21692283 100644
--- a/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py
+++ b/projects/altitude_spatial_model/altitudes_fit/plots/plot_coherence_curves.py
@@ -9,7 +9,7 @@ from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.one_fold_fi
 
 def plot_coherence_curves(massif_names, visualizer_list: List[AltitudesStudiesVisualizerForNonStationaryModels]):
     folder = 'Coherence'
-    elevation_as_xaxis = True
+    elevation_as_xaxis = False
     visualizer = visualizer_list[0]
     all_valid_names = set.union(*[v.get_valid_names(massif_names) for v in visualizer_list])
     for massif_name in all_valid_names:
@@ -21,20 +21,39 @@ def plot_coherence_curves(massif_names, visualizer_list: List[AltitudesStudiesVi
         labels = ['{} 1959'.format(elevational_str), '{} 2019'.format(elevational_str), 'Pointwise distributions']
         altitudinal_model = [True, True, False]
         years = [1959, 2019, None]
-        for i in range(4):
+        for i in list(range(4))[:]:
             # Load ax
             ax = plt.gca()
-            if i % 2 == 0:
+            if i % 2 == 1:
                 ax.set_yticks([])
-            ax = ax.twinx() if i % 2 == 0 else ax
-
+                ax2 = ax.twinx()
+            else:
+                ax2 = ax
+            if not elevation_as_xaxis and i < 2:
+                ax2.set_xticks([])
+                ax3 = ax2.twiny()
+            else:
+                ax3 = ax2
 
             for color, global_label, is_altitudinal, year in list(zip(colors, labels, altitudinal_model, years))[:]:
                 x_all_list, values_all_list, labels, all_bound_list = load_all_list(massif_name, visualizer_list,
                                                                                     is_altitudinal,
                                                                                     year)
                 label = labels[i]
-                plot_coherence_curve(ax, i, x_all_list, values_all_list, label, all_bound_list ,
+
+                # Set labels
+                fontsize_label = 15
+                if elevation_as_xaxis:
+                    ax3.set_xlabel('Elevation (m)', fontsize=fontsize_label)
+                    ax2.set_ylabel(label, fontsize=fontsize_label)
+                else:
+                    ax2.set_ylabel('Elevation(m)', fontsize=fontsize_label)
+                    if i == 3:
+                        ax.set_xlabel(label, fontsize=fontsize_label)
+                    else:
+                        ax3.set_xlabel(label, fontsize=fontsize_label)
+
+                plot_coherence_curve(ax3, i, x_all_list, values_all_list, all_bound_list,
                                      is_altitudinal, color, global_label, year, legend,
                                      elevation_as_xaxis)
             visualizer.plot_name = '{}/{}_{}'.format(folder, massif_name.replace('_', '-'), label)
@@ -42,12 +61,9 @@ def plot_coherence_curves(massif_names, visualizer_list: List[AltitudesStudiesVi
             plt.close()
 
 
-def plot_coherence_curve(ax, i, x_all_list, values_all_list, label, all_bound_list,
+def plot_coherence_curve(ax, i, x_all_list, values_all_list, all_bound_list,
                          is_altitudinal, color, global_label, year, legend, elevation_as_xaxis):
-
-    legend_line = True
-    if legend and i != 3:
-        return
+    legend_line = False
     # Plot with complete line
     for j, (x_list, value_list) in enumerate(list(zip(x_all_list, values_all_list))):
         value_list_i = value_list[i]
@@ -80,19 +96,21 @@ def plot_coherence_curve(ax, i, x_all_list, values_all_list, label, all_bound_li
         for j, (x_list, bounds) in enumerate(list(zip(x_all_list, all_bound_list))):
             if len(bounds) > 0:
                 lower_bound, upper_bound = bounds
+                f = ax.fill_between if elevation_as_xaxis else ax.fill_betweenx
                 if legend and not legend_line:
+                    print('here2')
                     model_name = 'piecewise elevational-temporal models in 2019' if is_altitudinal else 'pointwise distributions'
                     fill_label = "95\% confidence interval for the {}".format(model_name) if j == 0 else None
-                    ax.fill_between(x_list, lower_bound, upper_bound, color=color, alpha=0.2, label=fill_label)
+                    f(x_list, lower_bound, upper_bound, color=color, alpha=0.2, label=fill_label)
                 else:
-                    ax.fill_between(x_list, lower_bound, upper_bound, color=color, alpha=0.2)
-
-        if legend:
-            min, max = ax.get_ylim()
-            ax.set_ylim([min, 2 * max])
-            size = 15 if legend_line else 11
-            ax.legend(prop={'size': size})
-    ax.set_ylabel(label)
+                    f(x_list, lower_bound, upper_bound, color=color, alpha=0.2)
+
+    if legend:
+        print("here")
+        min, max = ax.get_ylim()
+        ax.set_ylim([min, 2 * max])
+        size = 15 if legend_line else 11
+        ax.legend(prop={'size': size})
 
 
 def load_all_list(massif_name, visualizer_list, altitudinal_model=True, year=2019):
-- 
GitLab