From 3c816ae8dad1b29dfcd9fbcca074eef5278b09cf Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Tue, 21 May 2019 17:56:21 +0200
Subject: [PATCH] [SCM][HYPERCUBE] Add spatial trend visualization with
 hypercube

---
 .../meteo_france_SCM_study/abstract_study.py  | 22 ++++++--
 .../hypercube_visualizer.py                   | 50 +++++++++++++++----
 .../main_studies_visualizer.py                | 38 ++++++++++----
 .../study_visualization/study_visualizer.py   |  5 +-
 .../abstract_trend_test.py                    | 10 ++++
 .../abstract_margin_function.py               |  2 +-
 .../margin_fits/plot/create_shifted_cmap.py   | 12 +++--
 7 files changed, 107 insertions(+), 32 deletions(-)

diff --git a/experiment/meteo_france_SCM_study/abstract_study.py b/experiment/meteo_france_SCM_study/abstract_study.py
index 42973271..25dfd43c 100644
--- a/experiment/meteo_france_SCM_study/abstract_study.py
+++ b/experiment/meteo_france_SCM_study/abstract_study.py
@@ -11,6 +11,8 @@ import numpy as np
 import pandas as pd
 from PIL import Image
 from PIL import ImageDraw
+from matplotlib import cm
+from matplotlib.colors import Normalize
 from netCDF4 import Dataset
 
 from experiment.meteo_france_SCM_study.abstract_variable import AbstractVariable
@@ -18,7 +20,7 @@ from experiment.meteo_france_SCM_study.scm_constants import ALTITUDES, ZS_INT_23
 from experiment.meteo_france_SCM_study.visualization.utils import get_km_formatter
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
-from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted
+from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted, create_colorbase_axis
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
     AbstractSpatialCoordinates
@@ -235,12 +237,18 @@ class AbstractStudy(object):
     """ Visualization methods """
 
     def visualize_study(self, ax=None, massif_name_to_value=None, show=True, fill=True, replace_blue_by_white=True,
-                        label=None, add_text=False):
+                        label=None, add_text=False, cmap=None):
         if massif_name_to_value is None:
             massif_name_to_fill_kwargs = None
         else:
             massif_names, values = list(zip(*massif_name_to_value.items()))
-            colors = get_color_rbga_shifted(ax, replace_blue_by_white, values, label=label)
+            if cmap is None:
+                colors = get_color_rbga_shifted(ax, replace_blue_by_white, values, label=label)
+            else:
+                norm = Normalize(0, 100)
+                create_colorbase_axis(ax, label, cmap, norm)
+                m = cm.ScalarMappable(norm=norm, cmap=cmap)
+                colors = [m.to_rgba(value) for value in values]
             massif_name_to_fill_kwargs = {massif_name: {'color': color} for massif_name, color in
                                           zip(massif_names, colors)}
 
@@ -258,8 +266,12 @@ class AbstractStudy(object):
             # Potentially, fill the inside of the polygon with some color
             if fill and coordinate_id in self.coordinate_id_to_massif_name:
                 massif_name = self.coordinate_id_to_massif_name[coordinate_id]
-                fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
-                ax.fill(*coords_list, **fill_kwargs)
+                if massif_name_to_fill_kwargs is not None and massif_name in massif_name_to_fill_kwargs:
+                    fill_kwargs = massif_name_to_fill_kwargs[massif_name]
+                    ax.fill(*coords_list, **fill_kwargs)
+                # else:
+                #     fill_kwargs = {}
+
                 # x , y = list(self.massifs_coordinates.df_all_coordinates.loc[massif_name])
                 # x , y= coords_list[0][0], coords_list[0][1]
                 # print(x, y)
diff --git a/experiment/meteo_france_SCM_study/visualization/studies_visualization/hypercube_visualizer.py b/experiment/meteo_france_SCM_study/visualization/studies_visualization/hypercube_visualizer.py
index 99d5c518..e8af0e1a 100644
--- a/experiment/meteo_france_SCM_study/visualization/studies_visualization/hypercube_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/studies_visualization/hypercube_visualizer.py
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 
 from experiment.meteo_france_SCM_study.visualization.study_visualization.study_visualizer import StudyVisualizer
-from utils import cached_property, VERSION_TIME
+from utils import cached_property, VERSION_TIME, get_display_name_from_object_type
 
 
 class HypercubeVisualizer(object):
@@ -17,18 +17,19 @@ class HypercubeVisualizer(object):
     """
 
     def __init__(self, tuple_to_study_visualizer: Dict[Tuple, StudyVisualizer],
-                 trend_class,
+                 trend_test_class,
                  fast=False,
                  save_to_file=False):
-        self.nb_data_for_fast_mode = 2 if fast else None
+        self.nb_data_for_fast_mode = 7 if fast else None
         self.save_to_file = save_to_file
-        self.trend_class = trend_class
+        self.trend_test_class = trend_test_class
         self.tuple_to_study_visualizer = tuple_to_study_visualizer  # type: Dict[Tuple, StudyVisualizer]
 
     # Main attributes defining the hypercube
 
-    def tuple_to_massif_names(self, tuple):
-        return self.tuple_to_study_visualizer[tuple].study.study_massif_names
+    @property
+    def trend_test_name(self):
+        return get_display_name_from_object_type(self.trend_test_class)
 
     @cached_property
     def starting_years(self):
@@ -40,13 +41,13 @@ class HypercubeVisualizer(object):
     @cached_property
     def tuple_to_df_trend_type(self):
         df_spatio_temporal_trend_types = [
-            study_visualizer.df_trend_spatio_temporal(self.trend_class, self.starting_years,
+            study_visualizer.df_trend_spatio_temporal(self.trend_test_class, self.starting_years,
                                                       self.nb_data_for_fast_mode)
             for study_visualizer in self.tuple_to_study_visualizer.values()]
         return dict(zip(self.tuple_to_study_visualizer.keys(), df_spatio_temporal_trend_types))
 
     @cached_property
-    def df_hypercube(self):
+    def df_hypercube(self) -> pd.DataFrame:
         keys = list(self.tuple_to_df_trend_type.keys())
         values = list(self.tuple_to_df_trend_type.values())
         df = pd.concat(values, keys=keys, axis=0)
@@ -88,12 +89,12 @@ class AltitudeHypercubeVisualizer(HypercubeVisualizer):
     def altitudes(self):
         return list(self.tuple_to_study_visualizer.keys())
 
-    def visualize_trend_test(self, ax=None, marker='o'):
+    def visualize_altitude_trend_test(self, ax=None, marker='o'):
         if ax is None:
             fig, ax = plt.subplots(1, 1, figsize=self.study_visualizer.figsize)
 
         # Plot weighted percentages over the years
-        for trend_type, style in self.trend_class.trend_type_to_style().items():
+        for trend_type, style in self.trend_test_class.trend_type_to_style().items():
             altitude_percentages = (self.df_hypercube == trend_type)
             # Take the mean with respect to the years
             altitude_percentages = altitude_percentages.mean(axis=1)
@@ -116,10 +117,37 @@ class AltitudeHypercubeVisualizer(HypercubeVisualizer):
         ax.legend()
 
         variable_name = self.study.variable_class.NAME
-        title = 'Evolution of {} trends (significative or not) wrt to the altitude'.format(variable_name)
+        name = get_display_name_from_object_type(self.trend_test_class)
+        title = 'Evolution of {} trends (significative or not) wrt to the altitude with {}'.format(variable_name,name)
         ax.set_title(title)
         self.show_or_save_to_file(specific_title=title)
 
+    def visualize_spatial_trend_test(self, axes=None):
+        if axes is None:
+            nb_trend_type = len(self.trend_test_class.trend_type_to_style())
+            fig, axes = plt.subplots(1, nb_trend_type, figsize=self.study_visualizer.figsize)
+
+        # Plot weighted percentages over the years
+        for ax, (trend_type, style) in zip(axes, self.trend_test_class.trend_type_to_style().items()):
+            spatial_percentages = (self.df_hypercube == trend_type)
+            # Take the mean with respect to the years
+            spatial_percentages = spatial_percentages.mean(axis=1)
+            # Take the mean with respect the altitude
+            spatial_percentages = spatial_percentages.mean(axis=0, level=1) * 100
+            # Plot values
+            massif_to_value = dict(spatial_percentages)
+            cmap = self.trend_test_class.get_cmap_from_trend_type(trend_type)
+            self.study.visualize_study(ax, massif_to_value, show=False, cmap=cmap, label=None)
+            ax.set_title(trend_type)
+
+        # Global information
+        name = get_display_name_from_object_type(self.trend_test_class)
+        title = 'Repartition of trends (significative or not) with {}'.format(name)
+        title +=  '\n(in % averaged on altitudes & averaged on starting years)'
+        StudyVisualizer.clean_axes_write_title_on_the_left(axes, title, left_border=None)
+        plt.suptitle(title)
+        self.show_or_save_to_file(specific_title=title)
+
 
 class QuantitityAltitudeHypercubeVisualizer(HypercubeVisualizer):
     pass
diff --git a/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py b/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py
index ad30484b..62186a2b 100644
--- a/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py
@@ -59,28 +59,44 @@ def altitude_trends_significant():
         visualizer.trend_tests_percentage_evolution_with_altitude(trend_test_classes, starting_year_to_weights=None)
 
 
-def altitude_hypercube_test():
-    save_to_file = False
+def altitude_trend_with_hypercube():
+    save_to_file = True
     only_first_one = False
     fast = False
     altitudes = ALL_ALTITUDES[3:-6]
+    # altitudes = ALL_ALTITUDES[2:4]
+    for study_class in SCM_STUDIES[:]:
+        for trend_test_class in [MannKendallTrendTest, GevLocationTrendTest, GevScaleTrendTest, GevShapeTrendTest][:]:
+            visualizers = [StudyVisualizer(study, temporal_non_stationarity=True, verbose=False, multiprocessing=True)
+                           for study in study_iterator(study_class=study_class, only_first_one=only_first_one,
+                                                       altitudes=altitudes)]
+            altitude_to_visualizer = OrderedDict(zip(altitudes, visualizers))
+            visualizer = AltitudeHypercubeVisualizer(altitude_to_visualizer, save_to_file=save_to_file,
+                                                     trend_class=trend_test_class, fast=fast)
+            visualizer.visualize_altitude_trend_test()
+
+def spatial_trend_with_hypercube():
+    save_to_file = False
+    only_first_one = False
+    fast = True
+    # altitudes = ALL_ALTITUDES[3:-6]
     altitudes = ALL_ALTITUDES[2:4]
     for study_class in SCM_STUDIES[:1]:
-        trend_test_class = [MannKendallTrendTest, GevLocationTrendTest, GevScaleTrendTest, GevShapeTrendTest][0]
-        visualizers = [StudyVisualizer(study, temporal_non_stationarity=True, verbose=False, multiprocessing=True)
-                       for study in study_iterator(study_class=study_class, only_first_one=only_first_one,
+        for trend_test_class in [MannKendallTrendTest, GevLocationTrendTest, GevScaleTrendTest, GevShapeTrendTest][:1]:
+            visualizers = [StudyVisualizer(study, temporal_non_stationarity=True, verbose=False, multiprocessing=True)
+                           for study in study_iterator(study_class=study_class, only_first_one=only_first_one,
                                                        altitudes=altitudes)]
-        altitude_to_visualizer = OrderedDict(zip(altitudes, visualizers))
-        visualizer = AltitudeHypercubeVisualizer(altitude_to_visualizer, save_to_file=save_to_file,
-                                         trend_class=trend_test_class, fast=fast)
-        visualizer.visualize_trend_test()
-        # print(visualizer.df_hypercube)
+            altitude_to_visualizer = OrderedDict(zip(altitudes, visualizers))
+            visualizer = AltitudeHypercubeVisualizer(altitude_to_visualizer, save_to_file=save_to_file,
+                                                     trend_test_class=trend_test_class, fast=fast)
+            visualizer.visualize_spatial_trend_test()
+
 
 
 def main_run():
     # altitude_trends()
     # altitude_trends_significant()
-    altitude_hypercube_test()
+    spatial_trend_with_hypercube()
 
 
 if __name__ == '__main__':
diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
index 32b5898c..98f0aa6f 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
@@ -634,7 +634,10 @@ class StudyVisualizer(object):
 
     @staticmethod
     def clean_axes_write_title_on_the_left(axes, title, left_border=True):
-        if left_border:
+        if left_border is None:
+            clean_axes = axes
+            ax0 = axes[0]
+        elif left_border:
             ax0, *clean_axes = axes
         else:
             *clean_axes, ax0 = axes
diff --git a/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py b/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py
index c2afeb6e..eec14042 100644
--- a/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py
+++ b/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py
@@ -1,4 +1,5 @@
 import random
+import matplotlib.pyplot as plt
 from collections import OrderedDict
 
 import numpy as np
@@ -28,6 +29,15 @@ class AbstractTrendTest(object):
         d[cls.NO_TREND] = 'k--'
         return d
 
+    @classmethod
+    def get_cmap_from_trend_type(cls, trend_type):
+        if 'positive' in trend_type:
+            return plt.cm.Greens
+        elif 'negative' in trend_type:
+            return plt.cm.Reds
+        else:
+            return plt.cm.binary
+
     def __init__(self, years_after_change_point, maxima_after_change_point):
         self.years_after_change_point = years_after_change_point
         self.maxima_after_change_point = maxima_after_change_point
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index a2b5b3d0..34cc9b0d 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -7,7 +7,7 @@ import pandas as pd
 
 from experiment.meteo_france_SCM_study.visualization.utils import create_adjusted_axes
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
-from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, imshow_shifted
+from extreme_estimator.margin_fits.plot.create_shifted_cmap import imshow_shifted
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.slicer.split import Split
 from utils import cached_property
diff --git a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
index 194466ec..b605082f 100644
--- a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
+++ b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
@@ -23,11 +23,17 @@ def plot_extreme_param(ax, label: str, values: np.ndarray):
     cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
     shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
+    norm = create_colorbase_axis(ax, label, shifted_cmap, norm)
+    return norm, shifted_cmap
+
+
+def create_colorbase_axis(ax, label, cmap, norm):
     divider = make_axes_locatable(ax)
     cax = divider.append_axes('right', size='5%', pad=0.03)
-    cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
-    cb.set_label(label)
-    return norm, shifted_cmap
+    cb = cbar.ColorbarBase(cax, cmap=cmap, norm=norm)
+    if isinstance(label, str):
+        cb.set_label(label)
+    return norm
 
 
 def get_color_rbga_shifted(ax, replace_blue_by_white: bool, values: np.ndarray, label=None):
-- 
GitLab