From 75d1cdccbd319b5fdba65f3bbd673032537accf4 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 11 Sep 2019 19:06:19 +0200
Subject: [PATCH] [POSTER EVAN] improve visualization of the starting years for
 the poster. some refactorign for the Alps graph and cmap, from now on, we
 will only remove the last 20 starting years

---
 .../scm_models_data/abstract_study.py         | 34 ++++++-----
 .../abstract_hypercube_visualizer.py          |  2 +-
 .../altitude_hypercube_visualizer.py          | 49 +++++++++++-----
 .../main_files/main_full_hypercube.py         |  2 +-
 .../poster_EVAN2019/main_poster_EVAN2019.py   | 57 +++++++++++++++++--
 .../abstract_univariate_test.py               | 10 ++++
 .../margin_fits/plot/create_shifted_cmap.py   | 30 +++++-----
 7 files changed, 134 insertions(+), 50 deletions(-)

diff --git a/experiment/meteo_france_data/scm_models_data/abstract_study.py b/experiment/meteo_france_data/scm_models_data/abstract_study.py
index 3c7a2911..248de3d6 100644
--- a/experiment/meteo_france_data/scm_models_data/abstract_study.py
+++ b/experiment/meteo_france_data/scm_models_data/abstract_study.py
@@ -25,7 +25,8 @@ from experiment.meteo_france_data.scm_models_data.scm_constants import ALTITUDES
 from experiment.meteo_france_data.scm_models_data.visualization.utils import get_km_formatter
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
-from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted, create_colorbase_axis
+from extreme_estimator.margin_fits.plot.create_shifted_cmap import create_colorbase_axis, \
+    get_shifted_map, get_colors
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
     AbstractSpatialCoordinates
@@ -50,6 +51,7 @@ class AbstractStudy(object):
     """
     REANALYSIS_FLAT_FOLDER = 'SAFRAN_montagne-CROCUS_2019/alp_flat/reanalysis'
     REANALYSIS_ALLSLOPES_FOLDER = 'SAFRAN_montagne-CROCUS_2019/alp_allslopes/reanalysis'
+
     # REANALYSIS_FOLDER = 'SAFRAN_montagne-CROCUS_2019/postes/reanalysis'
 
     def __init__(self, variable_class: type, altitude: int = 1800, year_min=1000, year_max=3000,
@@ -68,8 +70,6 @@ class AbstractStudy(object):
         self.orientation = orientation
         self.slope = slope
 
-
-
     """ Time """
 
     @cached_property
@@ -244,7 +244,7 @@ class AbstractStudy(object):
         slope_mask = np.array(ORDERED_ALLSLOPES_SLOPES) == self.slope
         allslopes_mask = altitude_mask & orientation_mask & slope_mask
         # Exclude all the data corresponding to the 24th massif
-        massif_24_mask =np.array(ORDERED_ALLSLOPES_MASSIFNUM) == 30
+        massif_24_mask = np.array(ORDERED_ALLSLOPES_MASSIFNUM) == 30
         return allslopes_mask & ~massif_24_mask
 
     @cached_property
@@ -283,7 +283,7 @@ class AbstractStudy(object):
     @classmethod
     def visualize_study(cls, ax=None, massif_name_to_value: Union[None, Dict[str, float]] = None, show=True, fill=True,
                         replace_blue_by_white=True,
-                        label=None, add_text=False, cmap=None, vmax=100, vmin=0,
+                        label=None, add_text=False, cmap=None, add_colorbar=False, vmax=100, vmin=0,
                         default_color_for_missing_massif='gainsboro',
                         default_color_for_nan_values='w',
                         massif_name_to_color=None,
@@ -291,20 +291,23 @@ class AbstractStudy(object):
                         scaled=False,
                         fontsize=7,
                         axis_off=False,
-                        massif_name_to_hatch_boolean_list=None
+                        massif_name_to_hatch_boolean_list=None,
+                        norm=None,
                         ):
         if ax is None:
             ax = plt.gca()
 
-        if massif_name_to_color is None:
+        if massif_name_to_value is not None:
             massif_names, values = list(zip(*massif_name_to_value.items()))
+        else:
+            massif_names, values = None, None
+
+        if massif_name_to_color is None:
+            # Load the colors
             if cmap is None:
-                colors = get_color_rbga_shifted(ax, replace_blue_by_white, values, label=label)
-            else:
-                norm = Normalize(vmin, vmax)
-                create_colorbase_axis(ax, label, cmap, norm)
-                m = cm.ScalarMappable(norm=norm, cmap=cmap)
-                colors = [m.to_rgba(value) if not np.isnan(value) else default_color_for_nan_values for value in values]
+                cmap = get_shifted_map(vmin, vmax)
+            norm = Normalize(vmin, vmax)
+            colors = get_colors(values, cmap, vmin, vmax, replace_blue_by_white)
             massif_name_to_color = dict(zip(massif_names, colors))
         massif_name_to_fill_kwargs = {massif_name: {'color': color} for massif_name, color in
                                       massif_name_to_color.items()}
@@ -366,6 +369,11 @@ class AbstractStudy(object):
 
         if scaled:
             plt.axis('scaled')
+
+        # create the colorbar only at the end
+        if add_colorbar:
+            if len(set(values)) > 1:
+                create_colorbase_axis(ax, label, cmap, norm)
         if axis_off:
             plt.axis('off')
 
diff --git a/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/abstract_hypercube_visualizer.py b/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/abstract_hypercube_visualizer.py
index d5e2e7c2..4cfb5d70 100644
--- a/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/abstract_hypercube_visualizer.py
+++ b/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/abstract_hypercube_visualizer.py
@@ -126,7 +126,7 @@ class AbstractHypercubeVisualizer(object):
             if not op.exists(dirname):
                 os.makedirs(dirname, exist_ok=True)
             if tight:
-                plt.savefig(filepath, bbox_inches='tight', pad_inches=-0.03, dpi=1000)
+                plt.savefig(filepath, bbox_inches='tight', pad_inches=+0.03, dpi=1000)
             elif dpi is not None:
                 plt.savefig(filepath, dpi=dpi)
             else:
diff --git a/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/altitude_hypercube_visualizer.py b/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/altitude_hypercube_visualizer.py
index 3e7249c7..15750592 100644
--- a/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/altitude_hypercube_visualizer.py
+++ b/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/altitude_hypercube_visualizer.py
@@ -289,7 +289,8 @@ class AltitudeHypercubeVisualizer(AbstractHypercubeVisualizer):
     def visualize_trend_test_repartition_poster(self, reduction_function, axes=None, subtitle='', isin_parameters=None,
                                                 plot_title=None,
                                                 poster_plot=False,
-                                                write_text_on_massif=True):
+                                                write_text_on_massif=True,
+                                                display_trend_color=True):
         trend_type_to_serie = {k: v[0].replace(0.0, np.nan) for k, v in
                                self.trend_type_to_series(reduction_function, isin_parameters).items()}
 
@@ -324,7 +325,7 @@ class AltitudeHypercubeVisualizer(AbstractHypercubeVisualizer):
                     else:
                         mean_idx, variance_idx = 1, 2
 
-                        massif_to_value_for_trend_type = {k: "$t_0=$" + str(int(v)) for k, v in
+                        massif_to_value_for_trend_type = {k: int(v) for k, v in
                                                           self.trend_type_to_series(reduction_function,
                                                                                     isin_parameters)[
                                                               display_trend_type][3].items()
@@ -356,19 +357,37 @@ class AltitudeHypercubeVisualizer(AbstractHypercubeVisualizer):
                 AbstractGevTrendTest.nb_years_for_quantile_evolution)
                 for m in massif_to_strength}
         else:
-            massif_name_to_value = massif_to_year
-        self.study.visualize_study(None, massif_name_to_color=massif_to_color, show=False,
-                                   show_label=False, scaled=True, add_text=write_text_on_massif,
-                                   massif_name_to_value=massif_name_to_value,
-                                   fontsize=4,
-                                   axis_off=True,
-                                   massif_name_to_hatch_boolean_list=massif_name_to_hatch_boolean_list)
+            massif_name_to_value = {k: "$t_0=$" + str(int(v)) if display_trend_color else v for k, v in massif_to_year.items()}
+
+        title = self.set_trend_test_reparition_title(subtitle, set=not poster_plot, first_title=display_trend_color)
+
+        if display_trend_color:
+            self.study.visualize_study(None, massif_name_to_color=massif_to_color, show=False,
+                                       show_label=False, scaled=True, add_text=write_text_on_massif,
+                                       massif_name_to_value=massif_name_to_value,
+                                       fontsize=4,
+                                       axis_off=True,
+                                       massif_name_to_hatch_boolean_list=massif_name_to_hatch_boolean_list,
+                                       )
+        else:
+            VMIN = 1957
+            VMAX = 1998
+            assert VMIN < self.first_starting_year
+            assert VMAX > self.last_starting_year
+            self.study.visualize_study(None, show=False,
+                                       show_label=False, scaled=True, add_text=False,
+                                       massif_name_to_value=massif_name_to_value,
+                                       cmap=plt.cm.GnBu,
+                                       add_colorbar=True,
+                                       vmin=VMIN,
+                                       vmax=VMAX)
+
+
 
-        title = self.set_trend_test_reparition_title(subtitle, set=not poster_plot)
 
         return title
 
-    def set_trend_test_reparition_title(self, subtitle, set=True):
+    def set_trend_test_reparition_title(self, subtitle, set=True, first_title=True):
         # Global information
         title = 'Repartition of {} trends'.format(subtitle)
         if self.study.has_orientation:
@@ -377,7 +396,7 @@ class AltitudeHypercubeVisualizer(AbstractHypercubeVisualizer):
         if len(self.starting_years) > 1:
             title += ' until starting_year={}'.format(self.last_starting_year)
         title += ' with {} test'.format(get_display_name_from_object_type(self.trend_test_class))
-        if self.reduce_strength_array:
+        if first_title:
             title += '\nEvolution of the quantile {} every {} years'.format(AbstractGevTrendTest.quantile_for_strength,
                                                                             AbstractGevTrendTest.nb_years_for_quantile_evolution)
         else:
@@ -492,7 +511,8 @@ class AltitudeHypercubeVisualizer(AbstractHypercubeVisualizer):
                                                  isin_parameters=None,
                                                  show_or_save_to_file=True,
                                                  poster_plot=False,
-                                                 write_text_on_massif=True):
+                                                 write_text_on_massif=True,
+                                                 display_trend_color=False):
         last_title = ''
         for subtitle, reduction_function in self.subtitle_to_reduction_function(self.index_reduction,
                                                                                 level=self.massif_index_level,
@@ -501,7 +521,8 @@ class AltitudeHypercubeVisualizer(AbstractHypercubeVisualizer):
                                                                       isin_parameters=isin_parameters,
                                                                       plot_title=plot_title,
                                                                       poster_plot=poster_plot,
-                                                                      write_text_on_massif=write_text_on_massif)
+                                                                      write_text_on_massif=write_text_on_massif,
+                                                                      display_trend_color=display_trend_color)
         if show_or_save_to_file:
             self.show_or_save_to_file(specific_title=last_title, dpi=1000, tight=poster_plot)
 
diff --git a/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/main_files/main_full_hypercube.py b/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/main_files/main_full_hypercube.py
index ffda9d8b..4ede274d 100644
--- a/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/main_files/main_full_hypercube.py
+++ b/experiment/meteo_france_data/scm_models_data/visualization/hypercube_visualization/main_files/main_full_hypercube.py
@@ -22,7 +22,7 @@ def get_full_parameters(altitude=None, offset_starting_year=10):
         altitudes = [altitude]
     else:
         altitudes = ALL_ALTITUDES[3:-6]
-    first_starting_year = 1958 + offset_starting_year
+    first_starting_year = 1958
     last_starting_year = 2017 - offset_starting_year
     trend_test_class = GevLocationTrendTest
     return altitudes, first_starting_year, last_starting_year, nb_data_reduced_for_speed, only_first_one, save_to_file, trend_test_class
diff --git a/experiment/paper1_steps/poster_EVAN2019/main_poster_EVAN2019.py b/experiment/paper1_steps/poster_EVAN2019/main_poster_EVAN2019.py
index 4d65f079..063783cc 100644
--- a/experiment/paper1_steps/poster_EVAN2019/main_poster_EVAN2019.py
+++ b/experiment/paper1_steps/poster_EVAN2019/main_poster_EVAN2019.py
@@ -29,14 +29,18 @@ def main_poster_A_non_stationary_model_choice():
 
 def main_poster_B_starting_years_analysis():
     nb = 3
-    for altitude in POSTER_ALTITUDES[2:]:
+    for altitude in POSTER_ALTITUDES[:nb]:
         for trend_test_class in [GevLocationAndScaleTrendTest]:
             # 1958 as starting year
             vizualiser = get_full_altitude_visualizer(Altitude_Hypercube_Year_Visualizer, altitude=altitude,
                                                       exact_starting_year=1958, reduce_strength_array=False,
                                                       trend_test_class=trend_test_class,
                                                       )
+            for d in [True, False]:
+                vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+                                                                    display_trend_color=d)
             # vizualiser.save_to_file = False
+
             vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False)
             # Optimal common starting year
             vizualiser = get_full_altitude_visualizer(AltitudeHypercubeVisualizerWithoutTrendType, altitude=altitude,
@@ -48,13 +52,56 @@ def main_poster_B_starting_years_analysis():
             vizualiser = get_full_altitude_visualizer(Altitude_Hypercube_Year_Visualizer, altitude=altitude,
                                                       exact_starting_year=best_year, reduce_strength_array=False,
                                                       trend_test_class=trend_test_class)
-            vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False)
+            for d in [True, False]:
+                vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+                                                                    display_trend_color=d)
             # Individual most likely starting year for each massif
             vizualiser = get_full_altitude_visualizer(Altitude_Hypercube_Year_Visualizer, altitude=altitude,
                                                       reduce_strength_array=False,
                                                       trend_test_class=trend_test_class,
                                                       offset_starting_year=20)
-            vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=True)
+            for d in [True, False]:
+                vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+                                                                    display_trend_color=d)
+
+# def main_poster_B_test():
+#     nb = 3
+#     for altitude in POSTER_ALTITUDES[:1]:
+#         for trend_test_class in [GevLocationAndScaleTrendTest]:
+#             # # 1958 as starting year
+#             vizualiser = get_full_altitude_visualizer(Altitude_Hypercube_Year_Visualizer, altitude=altitude,
+#                                                       exact_starting_year=1958, reduce_strength_array=False,
+#                                                       trend_test_class=trend_test_class,
+#                                                       )
+#             # vizualiser.save_to_file = False
+#             vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+#                                                                 display_trend_color=False)
+#             vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+#                                                                 display_trend_color=True)
+#             # # Optimal common starting year
+#             vizualiser = get_full_altitude_visualizer(AltitudeHypercubeVisualizerWithoutTrendType, altitude=altitude,
+#                                                       reduce_strength_array=True,
+#                                                       trend_test_class=trend_test_class,
+#                                                       offset_starting_year=20)
+#             res = vizualiser.visualize_year_trend_test(subtitle_specified='CrocusSwe3Days')
+#             best_year = res[0][1]
+#             vizualiser = get_full_altitude_visualizer(Altitude_Hypercube_Year_Visualizer, altitude=altitude,
+#                                                       exact_starting_year=best_year, reduce_strength_array=False,
+#                                                       trend_test_class=trend_test_class)
+#             vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+#                                                                 display_trend_color=False)
+#             vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False,
+#                                                                 display_trend_color=True)
+#             # vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=False)
+#             # Individual most likely starting year for each massif
+#             # vizualiser = get_full_altitude_visualizer(Altitude_Hypercube_Year_Visualizer, altitude=altitude,
+#             #                                           reduce_strength_array=False,
+#             #                                           trend_test_class=trend_test_class,
+#             #                                           offset_starting_year=50)
+#             # # vizualiser.save_to_file = False
+#             # vizualiser.visualize_massif_trend_test_one_altitude(poster_plot=True, write_text_on_massif=True,
+#             #                                                     display_trend_color=False)
+
 
 
 def main_poster_C_orientation_analysis():
@@ -87,6 +134,6 @@ def main_poster_D_other_quantities_analysis():
 
 if __name__ == '__main__':
     # main_poster_A_non_stationary_model_choice()
-    # main_poster_B_starting_years_analysis()
-    main_poster_C_orientation_analysis()
+    main_poster_B_starting_years_analysis()
+    # main_poster_C_orientation_analysis()
     # main_poster_D_other_quantities_analysis()
diff --git a/experiment/trend_analysis/univariate_test/abstract_univariate_test.py b/experiment/trend_analysis/univariate_test/abstract_univariate_test.py
index 7cbbf78e..1376adde 100644
--- a/experiment/trend_analysis/univariate_test/abstract_univariate_test.py
+++ b/experiment/trend_analysis/univariate_test/abstract_univariate_test.py
@@ -6,6 +6,7 @@ from collections import OrderedDict
 
 import numpy as np
 from cached_property import cached_property
+from matplotlib import colors
 
 from experiment.trend_analysis.mann_kendall_test import mann_kendall_test
 from experiment.trend_analysis.abstract_score import MannKendall
@@ -53,6 +54,11 @@ class AbstractUnivariateTest(object):
     def three_main_trend_types(cls):
         return [cls.SIGNIFICATIVE_NEGATIVE_TREND, cls.NON_SIGNIFICATIVE_TREND, cls.SIGNIFICATIVE_POSITIVE_TREND]
 
+    @classmethod
+    def rgb_code_of_trend_colors(cls):
+        for name in ['lightgreen', 'lightcoral', 'darkgreen', 'darkred']:
+            print(name, colors.to_rgba(name)[:-1])
+
     @classmethod
     def display_trend_type_to_style(cls):
         d = OrderedDict()
@@ -143,3 +149,7 @@ class ExampleRandomTrendTest(AbstractUnivariateTest):
 
 class WarningScoreValue(Warning):
     pass
+
+
+if __name__ == '__main__':
+    AbstractUnivariateTest.rgb_code_of_trend_colors()
\ No newline at end of file
diff --git a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
index b605082f..6936b3fc 100644
--- a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
+++ b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
@@ -9,9 +9,8 @@ from extreme_estimator.margin_fits.extreme_params import ExtremeParams
 from extreme_estimator.margin_fits.plot.shifted_color_map import shiftedColorMap
 
 
-def plot_extreme_param(ax, label: str, values: np.ndarray):
+def get_shifted_map(vmin, vmax):
     # Load the shifted cmap to center on a middle point
-    vmin, vmax = np.min(values), np.max(values)
     if vmin < 0 < vmax:
         midpoint = 1 - vmax / (vmax + abs(vmin))
     elif vmin < 0 and vmax < 0:
@@ -22,29 +21,26 @@ def plot_extreme_param(ax, label: str, values: np.ndarray):
         raise ValueError('Unexpected values: vmin={}, vmax={}'.format(vmin, vmax))
     cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
     shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
-    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
-    norm = create_colorbase_axis(ax, label, shifted_cmap, norm)
-    return norm, shifted_cmap
+    return shifted_cmap
 
 
 def create_colorbase_axis(ax, label, cmap, norm):
     divider = make_axes_locatable(ax)
-    cax = divider.append_axes('right', size='5%', pad=0.03)
+    cax = divider.append_axes('right', size='5%', pad=0.0)
     cb = cbar.ColorbarBase(cax, cmap=cmap, norm=norm)
     if isinstance(label, str):
         cb.set_label(label)
     return norm
 
 
-def get_color_rbga_shifted(ax, replace_blue_by_white: bool, values: np.ndarray, label=None):
-    """
-    For some display it was necessary to transform dark blue values into white values
-    """
-    norm, shifted_cmap = plot_extreme_param(ax, label, values)
-    m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
+def get_norm(vmin, vmax):
+    return mpl.colors.Normalize(vmin=vmin, vmax=vmax)
+
+
+def get_colors(values, cmap, vmin, vmax, replace_blue_by_white=False):
+    norm = get_norm(vmin, vmax)
+    m = cm.ScalarMappable(norm=norm, cmap=cmap)
     colors = [m.to_rgba(value) for value in values]
-    # We do not want any blue values for parameters other than the Shape
-    # So when the value corresponding to the blue color is 1, then we set the color to white, i.e. (1,1,1,1)
     if replace_blue_by_white:
         colors = [color if color[2] != 1 else (1, 1, 1, 1) for color in colors]
     return colors
@@ -55,7 +51,10 @@ def imshow_shifted(ax, gev_param_name, values, visualization_extend, mask_2D=Non
     if mask_2D is not None:
         condition |= mask_2D
     masked_array = np.ma.masked_where(condition, values)
-    norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, masked_array)
+    vmin, vmax = np.min(masked_array), np.max(masked_array)
+    shifted_cmap = get_shifted_map(vmin, vmax)
+    norm = get_norm(vmin, vmax)
+    create_colorbase_axis(ax, gev_param_name, shifted_cmap, norm)
     shifted_cmap.set_bad(color='white')
     if gev_param_name != ExtremeParams.SHAPE:
         epsilon = 1e-2 * (np.max(values) - np.min(values))
@@ -64,4 +63,3 @@ def imshow_shifted(ax, gev_param_name, values, visualization_extend, mask_2D=Non
         masked_array[-1, -1] = value - epsilon
     # IMPORTANT: Origin for all the plots is at the bottom left corner
     ax.imshow(masked_array, extent=visualization_extend, cmap=shifted_cmap, origin='lower')
-
-- 
GitLab