From ba77edf4afc3da2bf1bf6acd81e016e43983248e Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 13 Feb 2019 16:18:09 +0100
Subject: [PATCH] [SAFRAN] use same visualization colormap for margins & for
 safran display

---
 .../safran_study/main_visualize_safran.py     |  6 +--
 experiment/safran_study/safran_visualizer.py  | 35 +++------------
 .../abstract_margin_function.py               | 17 +++----
 .../margin_fits/plot/__init__.py              |  0
 .../margin_fits/plot/create_shifted_cmap.py   | 44 +++++++++++++++++++
 .../margin_fits/plot}/shifted_color_map.py    |  0
 6 files changed, 63 insertions(+), 39 deletions(-)
 create mode 100644 extreme_estimator/margin_fits/plot/__init__.py
 create mode 100644 extreme_estimator/margin_fits/plot/create_shifted_cmap.py
 rename {experiment/safran_study => extreme_estimator/margin_fits/plot}/shifted_color_map.py (100%)

diff --git a/experiment/safran_study/main_visualize_safran.py b/experiment/safran_study/main_visualize_safran.py
index dd0e9ea5..2f5bc83b 100644
--- a/experiment/safran_study/main_visualize_safran.py
+++ b/experiment/safran_study/main_visualize_safran.py
@@ -17,6 +17,6 @@ def load_all_safran(only_first_one=False):
 if __name__ == '__main__':
     for safran in load_all_safran(only_first_one=True):
         safran_visualizer = SafranVisualizer(safran)
-        # safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][1])
-        # safran_visualizer.visualize_smooth_margin_fit()
-        safran_visualizer.visualize_full_fit()
+        # safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
+        safran_visualizer.visualize_smooth_margin_fit()
+        # safran_visualizer.visualize_full_fit()
diff --git a/experiment/safran_study/safran_visualizer.py b/experiment/safran_study/safran_visualizer.py
index 447160a7..31e4824e 100644
--- a/experiment/safran_study/safran_visualizer.py
+++ b/experiment/safran_study/safran_visualizer.py
@@ -8,16 +8,15 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
 from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
     FullEstimatorInASingleStepWithSmoothMargin
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
-from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAllDimsMarginModel
-from extreme_estimator.extreme_models.max_stable_model.max_stable_models import ExtremalT, Smith
+from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
 from extreme_estimator.margin_fits.extreme_params import ExtremeParams
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
 from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams
 from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit
 from experiment.safran_study.safran import Safran
-from experiment.safran_study.shifted_color_map import shiftedColorMap
+from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, get_color_rbga
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
@@ -71,34 +70,14 @@ class SafranVisualizer(object):
             massif_name_to_value = df.loc[gev_param_name, :].to_dict()
             # Compute the middle point of the values for the color map
             values = list(massif_name_to_value.values())
-            vmin, vmax = min(values), max(values)
-            try:
-                midpoint = 1 - vmax / (vmax + abs(vmin))
-            except ZeroDivisionError:
-                pass
-            # Load the shifted cmap to center on a middle point
-
-            cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
-            if gev_param_name == ExtremeParams.SHAPE:
-                shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
-                norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
-            else:
-                shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
-                norm = mpl.colors.Normalize(vmin=vmin - 1, vmax=vmax)
-
-            m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
-
-            massif_name_to_fill_kwargs = {massif_name: {'color': m.to_rgba(value)} for massif_name, value in
-                                          massif_name_to_value.items()}
+            colors = get_color_rbga(ax, gev_param_name, values)
+            massif_name_to_fill_kwargs = {massif_name: {'color': color} for massif_name, color in
+                                          zip(self.safran.safran_massif_names, colors)}
+
+            print(massif_name_to_fill_kwargs)
 
             self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
 
-            # Add colorbar
-            # plt.axis('off')
-            divider = make_axes_locatable(ax)
-            cax = divider.append_axes('right', size='5%', pad=0.05)
-            cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
-            cb.set_label(gev_param_name)
 
         if self.show:
             plt.show()
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 85f1e7db..8ca34294 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -6,6 +6,7 @@ import numpy as np
 import pandas as pd
 
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
+from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.slicer.split import Split
 
@@ -124,15 +125,19 @@ class AbstractMarginFunction(object):
 
     # Visualization 2D
 
-    def visualize_2D(self, gev_value_name=GevParams.LOC, ax=None, show=True):
+    def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
         x = self.coordinates.x_coordinates
         y = self.coordinates.y_coordinates
         grid = self.grid_2D(x, y)
         if ax is None:
             ax = plt.gca()
         imshow_method = ax.imshow
-        imshow_method(grid[gev_value_name], extent=(x.min(), x.max(), y.min(), y.max()),
-                      interpolation='nearest', cmap=cm.gist_rainbow)
+        values = grid[gev_param_name]
+
+        norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, values)
+
+        imshow_method(values, extent=(x.min(), x.max(), y.min(), y.max()),
+                      interpolation='nearest', cmap=shifted_cmap)
         # X axis
         ax.set_xlabel('coordinate X')
         plt.setp(ax.get_xticklabels(), visible=True)
@@ -145,12 +150,8 @@ class AbstractMarginFunction(object):
         if show:
             plt.show()
 
-    def grid_2D(self, x, y):
-        # if self._grid_2D is None:
-        #     self._grid_2D = self.get_grid_2D(x, y)
-        return self.get_grid_2D(x, y)
 
-    def get_grid_2D(self, x, y):
+    def grid_2D(self, x, y):
         resolution = 100
         grid = []
         for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
diff --git a/extreme_estimator/margin_fits/plot/__init__.py b/extreme_estimator/margin_fits/plot/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
new file mode 100644
index 00000000..dc72f36e
--- /dev/null
+++ b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
@@ -0,0 +1,44 @@
+from typing import Dict
+
+import matplotlib as mpl
+import matplotlib.cm as cm
+import matplotlib.colorbar as cbar
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+from extreme_estimator.margin_fits.plot.shifted_color_map import shiftedColorMap
+from extreme_estimator.margin_fits.extreme_params import ExtremeParams
+from extreme_estimator.margin_fits.gev.gev_params import GevParams
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.slicer.split import Split
+
+
+def plot_extreme_param(ax, gev_param_name, values):
+    # Load the shifted cmap to center on a middle point
+    vmin, vmax = np.min(values), np.max(values)
+    cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
+    if gev_param_name == ExtremeParams.SHAPE and vmin < 0:
+        midpoint = 1 - vmax / (vmax + abs(vmin))
+        shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
+    else:
+        shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
+    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes('right', size='5%', pad=0.05)
+    cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
+    cb.set_label(gev_param_name)
+    return norm, shifted_cmap
+
+
+def get_color_rbga(ax, gev_param_name, values):
+    """
+    For some display it was necessary to transform dark blue values into white values
+    """
+    norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, values)
+    m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
+    colors = [m.to_rgba(value) for value in values]
+    if gev_param_name != ExtremeParams.SHAPE:
+        colors = [color if color != (0, 0, 1, 1) else (1, 1, 1, 1) for color in colors]
+    return colors
diff --git a/experiment/safran_study/shifted_color_map.py b/extreme_estimator/margin_fits/plot/shifted_color_map.py
similarity index 100%
rename from experiment/safran_study/shifted_color_map.py
rename to extreme_estimator/margin_fits/plot/shifted_color_map.py
-- 
GitLab