From 9f24f3a1771306783f76851028f54e0d68adf022 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 11 Feb 2019 17:48:54 +0100
Subject: [PATCH] [SAFRAN] fix range of colorbar

---
 safran_study/safran.py | 46 +++++++++++++++++-------------------------
 1 file changed, 19 insertions(+), 27 deletions(-)

diff --git a/safran_study/safran.py b/safran_study/safran.py
index 64210348..c1d80673 100644
--- a/safran_study/safran.py
+++ b/safran_study/safran.py
@@ -7,6 +7,7 @@ from collections import OrderedDict
 import matplotlib
 import matplotlib.pyplot as plt
 import pandas as pd
+from matplotlib import cm
 from mpl_toolkits.axes_grid1 import AxesGrid, make_axes_locatable
 from netCDF4 import Dataset
 
@@ -50,11 +51,10 @@ class Safran(object):
             massif_name = self.coordinate_id_to_massif_name[coordinate_id]
             fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
             ax.fill(*l, **fill_kwargs)
-        cax = ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates)
+        ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates)
 
         if show:
             plt.show()
-        return cax
 
     def visualize_gev_fit_with_cmap(self, show=True, axes=None):
         if axes is None:
@@ -67,45 +67,37 @@ class Safran(object):
             #                 cbar_location="right", cbar_mode="each",
             #                 cbar_size="7%", cbar_pad="2%")
 
-        for i, gev_param_name in enumerate(GevParams.GEV_PARAM_NAMES[-1:]):
+        for i, gev_param_name in enumerate(GevParams.GEV_PARAM_NAMES[:]):
             massif_name_to_value = self.df_gev_mle_each_massif.loc[gev_param_name, :].to_dict()
             # Compute the middle point of the values for the color map
             values = list(massif_name_to_value.values())
             vmin, vmax = min(values), max(values)
             midpoint = 1 - vmax / (vmax + abs(vmin))
-            scaling_factor = 2 * max(vmax, -vmin)
+            maxmax = max(vmax, -vmin)
+            scaling_factor = 2 * maxmax
             # print(gev_param_name, midpoint, vmin, vmax, scaling_factor)
             # Load the shifted cmap to center on a middle point
-            cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][0]
-            shifted_cmap = shiftedColorMap(cmap, midpoint=0.5, name='shifted')
-            for massif_name, value in massif_name_to_value.items():
-                if value < 0:
-                    print(massif_name, value)
-            # massif_name_to_fill_kwargs = {massif_name: {'color': shifted_cmap(0.5 + value / scaling_factor)} for massif_name, value in
-            #                               massif_name_to_value.items()}
-            massif_name_to_fill_kwargs = {massif_name: {'color': shifted_cmap(0.5 + value / scaling_factor)} for massif_name, value in
+
+            cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
+            if gev_param_name == GevParams.GEV_SHAPE:
+                shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
+                norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
+            else:
+                shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
+                norm = mpl.colors.Normalize(vmin=vmin-1, vmax=vmax)
+
+            m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
+
+            massif_name_to_fill_kwargs = {massif_name: {'color': m.to_rgba(value)} for massif_name, value in
                                           massif_name_to_value.items()}
             ax = axes[i]
-            cax = self.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
+            self.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
 
             divider = make_axes_locatable(ax)
             cax = divider.append_axes('right', size='5%', pad=0.05)
 
-            # cax, _ = cbar.make_axes(ax)
-            # todo: the good shape values are not the one displayed
-            norm = mpl.colors.Normalize(vmin=-1, vmax=1)
             cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
-            cb.set_label('Some Units')
-
-            # fig.colorbar(shifted_cmap, cax=cax, orientation='vertical')
-
-
-            # cbar = fig.colorbar(cax, ticks=[-1, 0, 1], orientation='horizontal')
-            # cbar.ax.set_xticklabels(['Low', 'Medium', 'High'])  # horizontal colorbar
-            # cbar = fig.colorbar(cax, ticks=[-1, 0, 1])
-            # cbar.ax.set_yticklabels(['< -1', '0', '> 1'])  # vertically oriented colorbar
-            title_str = gev_param_name
-            ax.set_title(title_str)
+            cb.set_label(gev_param_name)
 
         if show:
             plt.show()
-- 
GitLab