Commit ba77edf4 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SAFRAN] use same visualization colormap for margins & for safran display

parent fe1e74fa
No related merge requests found
Showing with 63 additions and 39 deletions
+63 -39
......@@ -17,6 +17,6 @@ def load_all_safran(only_first_one=False):
if __name__ == '__main__':
for safran in load_all_safran(only_first_one=True):
safran_visualizer = SafranVisualizer(safran)
# safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][1])
# safran_visualizer.visualize_smooth_margin_fit()
safran_visualizer.visualize_full_fit()
# safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
safran_visualizer.visualize_smooth_margin_fit()
# safran_visualizer.visualize_full_fit()
......@@ -8,16 +8,15 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
FullEstimatorInASingleStepWithSmoothMargin
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
from extreme_estimator.estimator.max_stable_estimator.abstract_max_stable_estimator import MaxStableEstimator
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAllDimsMarginModel
from extreme_estimator.extreme_models.max_stable_model.max_stable_models import ExtremalT, Smith
from extreme_estimator.extreme_models.max_stable_model.max_stable_models import Smith
from extreme_estimator.margin_fits.extreme_params import ExtremeParams
from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams
from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit
from experiment.safran_study.safran import Safran
from experiment.safran_study.shifted_color_map import shiftedColorMap
from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, get_color_rbga
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
......@@ -71,34 +70,14 @@ class SafranVisualizer(object):
massif_name_to_value = df.loc[gev_param_name, :].to_dict()
# Compute the middle point of the values for the color map
values = list(massif_name_to_value.values())
vmin, vmax = min(values), max(values)
try:
midpoint = 1 - vmax / (vmax + abs(vmin))
except ZeroDivisionError:
pass
# Load the shifted cmap to center on a middle point
cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
if gev_param_name == ExtremeParams.SHAPE:
shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
else:
shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
norm = mpl.colors.Normalize(vmin=vmin - 1, vmax=vmax)
m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
massif_name_to_fill_kwargs = {massif_name: {'color': m.to_rgba(value)} for massif_name, value in
massif_name_to_value.items()}
colors = get_color_rbga(ax, gev_param_name, values)
massif_name_to_fill_kwargs = {massif_name: {'color': color} for massif_name, color in
zip(self.safran.safran_massif_names, colors)}
print(massif_name_to_fill_kwargs)
self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
# Add colorbar
# plt.axis('off')
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
cb.set_label(gev_param_name)
if self.show:
plt.show()
......
......@@ -6,6 +6,7 @@ import numpy as np
import pandas as pd
from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.slicer.split import Split
......@@ -124,15 +125,19 @@ class AbstractMarginFunction(object):
# Visualization 2D
def visualize_2D(self, gev_value_name=GevParams.LOC, ax=None, show=True):
def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
x = self.coordinates.x_coordinates
y = self.coordinates.y_coordinates
grid = self.grid_2D(x, y)
if ax is None:
ax = plt.gca()
imshow_method = ax.imshow
imshow_method(grid[gev_value_name], extent=(x.min(), x.max(), y.min(), y.max()),
interpolation='nearest', cmap=cm.gist_rainbow)
values = grid[gev_param_name]
norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, values)
imshow_method(values, extent=(x.min(), x.max(), y.min(), y.max()),
interpolation='nearest', cmap=shifted_cmap)
# X axis
ax.set_xlabel('coordinate X')
plt.setp(ax.get_xticklabels(), visible=True)
......@@ -145,12 +150,8 @@ class AbstractMarginFunction(object):
if show:
plt.show()
def grid_2D(self, x, y):
# if self._grid_2D is None:
# self._grid_2D = self.get_grid_2D(x, y)
return self.get_grid_2D(x, y)
def get_grid_2D(self, x, y):
def grid_2D(self, x, y):
resolution = 100
grid = []
for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
......
from typing import Dict
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from extreme_estimator.margin_fits.plot.shifted_color_map import shiftedColorMap
from extreme_estimator.margin_fits.extreme_params import ExtremeParams
from extreme_estimator.margin_fits.gev.gev_params import GevParams
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.slicer.split import Split
def plot_extreme_param(ax, gev_param_name, values):
# Load the shifted cmap to center on a middle point
vmin, vmax = np.min(values), np.max(values)
cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
if gev_param_name == ExtremeParams.SHAPE and vmin < 0:
midpoint = 1 - vmax / (vmax + abs(vmin))
shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
else:
shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
cb = cbar.ColorbarBase(cax, cmap=shifted_cmap, norm=norm)
cb.set_label(gev_param_name)
return norm, shifted_cmap
def get_color_rbga(ax, gev_param_name, values):
"""
For some display it was necessary to transform dark blue values into white values
"""
norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, values)
m = cm.ScalarMappable(norm=norm, cmap=shifted_cmap)
colors = [m.to_rgba(value) for value in values]
if gev_param_name != ExtremeParams.SHAPE:
colors = [color if color != (0, 0, 1, 1) else (1, 1, 1, 1) for color in colors]
return colors
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment