Commit ba77edf4 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SAFRAN] use same visualization colormap for margins & for safran display

parent fe1e74fa
No related merge requests found
Showing with 63 additions and 39 deletions
+63 -39
...@@ -17,6 +17,6 @@ def load_all_safran(only_first_one=False): ...@@ -17,6 +17,6 @@ def load_all_safran(only_first_one=False):
if __name__ == '__main__': if __name__ == '__main__':
for safran in load_all_safran(only_first_one=True): for safran in load_all_safran(only_first_one=True):
safran_visualizer = SafranVisualizer(safran) safran_visualizer = SafranVisualizer(safran)
# safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][1]) # safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
# safran_visualizer.visualize_smooth_margin_fit() safran_visualizer.visualize_smooth_margin_fit()
safran_visualizer.visualize_full_fit() # safran_visualizer.visualize_full_fit()
...@@ -8,16 +8,15 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable ...@@ -8,16 +8,15 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \ from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
FullEstimatorInASingleStepWithSmoothMargin FullEstimatorInASingleStepWithSmoothMargin
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAllDimsMarginModel from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAllDimsMarginModel
from extreme_estimator.extreme_models.max_stable_model.max_stable_models import ExtremalT, Smith from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
from extreme_estimator.margin_fits.extreme_params import ExtremeParams from extreme_estimator.margin_fits.extreme_params import ExtremeParams
from extreme_estimator.margin_fits.gev.gev_params import GevParams from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams
from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit
from experiment.safran_study.safran import Safran from experiment.safran_study.safran import Safran
from experiment.safran_study.shifted_color_map import shiftedColorMap from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, get_color_rbga
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
...@@ -71,34 +70,14 @@ class SafranVisualizer(object): ...@@ -71,34 +70,14 @@ class SafranVisualizer(object):
massif_name_to_value = df.loc[gev_param_name, :].to_dict() massif_name_to_value = df.loc[gev_param_name, :].to_dict()
# Compute the middle point of the values for the color map # Compute the middle point of the values for the color map
values = list(massif_name_to_value.values()) values = list(massif_name_to_value.values())
vmin, vmax = min(values), max(values) colors = get_color_rbga(ax, gev_param_name, values)
try: massif_name_to_fill_kwargs = {massif_name: {'color': color} for massif_name, color in
midpoint = 1 - vmax / (vmax + abs(vmin)) zip(self.safran.safran_massif_names, colors)}
except ZeroDivisionError:
pass print(massif_name_to_fill_kwargs)
# Load the shifted cmap to center on a middle point
cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
if gev_param_name == ExtremeParams.SHAPE:
shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
else:
shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
norm = mpl.colors.Normalize(vmin=vmin - 1, vmax=vmax)
m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
massif_name_to_fill_kwargs = {massif_name: {'color': m.to_rgba(value)} for massif_name, value in
massif_name_to_value.items()}
self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False) self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
# Add colorbar
# plt.axis('off')
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
cb.set_label(gev_param_name)
if self.show: if self.show:
plt.show() plt.show()
......
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
import pandas as pd import pandas as pd
from extreme_estimator.margin_fits.gev.gev_params import GevParams from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.slicer.split import Split from spatio_temporal_dataset.slicer.split import Split
...@@ -124,15 +125,19 @@ class AbstractMarginFunction(object): ...@@ -124,15 +125,19 @@ class AbstractMarginFunction(object):
# Visualization 2D # Visualization 2D
def visualize_2D(self, gev_value_name=GevParams.LOC, ax=None, show=True): def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
x = self.coordinates.x_coordinates x = self.coordinates.x_coordinates
y = self.coordinates.y_coordinates y = self.coordinates.y_coordinates
grid = self.grid_2D(x, y) grid = self.grid_2D(x, y)
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
imshow_method = ax.imshow imshow_method = ax.imshow
imshow_method(grid[gev_value_name], extent=(x.min(), x.max(), y.min(), y.max()), values = grid[gev_param_name]
interpolation='nearest', cmap=cm.gist_rainbow)
norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, values)
imshow_method(values, extent=(x.min(), x.max(), y.min(), y.max()),
interpolation='nearest', cmap=shifted_cmap)
# X axis # X axis
ax.set_xlabel('coordinate X') ax.set_xlabel('coordinate X')
plt.setp(ax.get_xticklabels(), visible=True) plt.setp(ax.get_xticklabels(), visible=True)
...@@ -145,12 +150,8 @@ class AbstractMarginFunction(object): ...@@ -145,12 +150,8 @@ class AbstractMarginFunction(object):
if show: if show:
plt.show() plt.show()
def grid_2D(self, x, y):
# if self._grid_2D is None:
# self._grid_2D = self.get_grid_2D(x, y)
return self.get_grid_2D(x, y)
def get_grid_2D(self, x, y): def grid_2D(self, x, y):
resolution = 100 resolution = 100
grid = [] grid = []
for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)): for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
......
from typing import Dict
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from extreme_estimator.margin_fits.plot.shifted_color_map import shiftedColorMap
from extreme_estimator.margin_fits.extreme_params import ExtremeParams
from extreme_estimator.margin_fits.gev.gev_params import GevParams
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.slicer.split import Split
def plot_extreme_param(ax, gev_param_name, values):
# Load the shifted cmap to center on a middle point
vmin, vmax = np.min(values), np.max(values)
cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
if gev_param_name == ExtremeParams.SHAPE and vmin < 0:
midpoint = 1 - vmax / (vmax + abs(vmin))
shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
else:
shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
cb.set_label(gev_param_name)
return norm, shifted_cmap
def get_color_rbga(ax, gev_param_name, values):
"""
For some display it was necessary to transform dark blue values into white values
"""
norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, values)
m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
colors = [m.to_rgba(value) for value in values]
if gev_param_name != ExtremeParams.SHAPE:
colors = [color if color != (0, 0, 1, 1) else (1, 1, 1, 1) for color in colors]
return colors
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment