From bee5048c0b2065dc2f5f1c9f0819bcc97b25fe18 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 18 Mar 2019 18:59:48 +0100
Subject: [PATCH] [SCM] improve visualization of the margin. add a 2D optional
 mask & visualization bounds to visualize a 2D margin

---
 .../meteo_france_SCM_study/abstract_study.py  | 32 +++++++++++++----
 .../main_study_visualizer.py                  |  2 +-
 .../study_visualization/study_visualizer.py   | 35 ++++++++++++++-----
 .../abstract_margin_function.py               | 35 +++++++++++++++----
 .../param_function/param_function.py          |  4 ++-
 .../margin_fits/plot/create_shifted_cmap.py   |  9 +++--
 6 files changed, 91 insertions(+), 26 deletions(-)

diff --git a/experiment/meteo_france_SCM_study/abstract_study.py b/experiment/meteo_france_SCM_study/abstract_study.py
index e46e3f62..927bce4a 100644
--- a/experiment/meteo_france_SCM_study/abstract_study.py
+++ b/experiment/meteo_france_SCM_study/abstract_study.py
@@ -158,14 +158,10 @@ class AbstractStudy(object):
 
         if ax is None:
             ax = plt.gca()
-        df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv'))
-        coord_tuples = [(row_massif['idx'], row_massif[AbstractCoordinates.COORDINATE_X],
-                         row_massif[AbstractCoordinates.COORDINATE_Y])
-                        for _, row_massif in df_massif.iterrows()]
 
-        for _, coordinate_id in enumerate(set([t[0] for t in coord_tuples])):
+        for coordinate_id, coords_list in self.idx_to_coords_list.items():
             # Retrieve the list of coords (x,y) that define the contour of the massif of id coordinate_id
-            coords_list = [coords for idx, *coords in coord_tuples if idx == coordinate_id]
+
             # if j == 0:
             #     mask_outside_polygon(poly_verts=l, ax=ax)
             # Plot the contour of the massif
@@ -200,6 +196,30 @@ class AbstractStudy(object):
         if show:
             plt.show()
 
+    @property
+    def idx_to_coords_list(self):
+        df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv'))
+        coord_tuples = [(row_massif['idx'], row_massif[AbstractCoordinates.COORDINATE_X],
+                         row_massif[AbstractCoordinates.COORDINATE_Y])
+                        for _, row_massif in df_massif.iterrows()]
+        all_idxs = set([t[0] for t in coord_tuples])
+        return {idx: [coords for idx_loop, *coords in coord_tuples if idx == idx_loop] for idx in all_idxs}
+
+    @property
+    def all_coords_list(self):
+        all_values = []
+        for e in self.idx_to_coords_list.values():
+            all_values.extend(e)
+        return list(zip(*all_values))
+
+    @property
+    def visualization_x_limits(self):
+        return min(self.all_coords_list[0]), max(self.all_coords_list[0])
+
+    @property
+    def visualization_y_limits(self):
+        return min(self.all_coords_list[1]), max(self.all_coords_list[1])
+
     """ Some properties """
 
     @property
diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py
index d03ff3df..5b9dbe4f 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/main_study_visualizer.py
@@ -71,7 +71,7 @@ def normal_visualization():
             study_visualizer = StudyVisualizer(study, save_to_file=save_to_file)
             # study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
             # study_visualizer.visualize_annual_mean_values()
-            study_visualizer.visualize_linear_margin_fit(only_first_max_stable=only_first_one)
+            study_visualizer.visualize_linear_margin_fit(only_first_max_stable=None)
 
 
 def complete_analysis(only_first_one=False):
diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
index d39d41ab..b6f093e3 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
@@ -13,6 +13,7 @@ from experiment.utils import average_smoothing_with_sliding_window
 from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
     FullEstimatorInASingleStepWithSmoothMargin
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
+from extreme_estimator.extreme_models.margin_model.param_function.param_function import ParamFunction
 from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearAllParametersAllDimsMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import CovarianceFunction
 from extreme_estimator.margin_fits.abstract_params import AbstractParams
@@ -52,7 +53,10 @@ class StudyVisualizer(object):
         else:
             self.figsize = (16.0, 10.0)
         self.subplot_space = 0.5
-        self.coef_zoom_map = 0
+        self.coef_zoom_map = 1
+
+        # Remove some assert
+        ParamFunction.OUT_OF_BOUNDS_ASSERT = False
 
     @property
     def observations(self):
@@ -214,6 +218,8 @@ class StudyVisualizer(object):
         max_stable_models = load_test_max_stable_models(default_covariance_function=default_covariance_function)
         if only_first_max_stable:
             max_stable_models = max_stable_models[:1]
+        if only_first_max_stable is None:
+            max_stable_models = []
         fig, axes = plt.subplots(len(max_stable_models) + 2, len(GevParams.SUMMARY_NAMES), figsize=self.figsize)
         fig.subplots_adjust(hspace=self.subplot_space, wspace=self.subplot_space)
         margin_class = LinearAllParametersAllDimsMarginModel
@@ -238,20 +244,33 @@ class StudyVisualizer(object):
         estimator.fit()
 
         margin_fct = estimator.margin_function_fitted
+
+        # margin_fct.visualization_x_limits = self.study.
+        margin_fct._visualization_x_limits = self.study.visualization_x_limits
+        margin_fct._visualization_y_limits = self.study.visualization_y_limits
+        # Example of mask 2D
+        mask_2D = np.zeros([margin_fct.resolution, margin_fct.resolution], dtype=bool)
+        lim = 5
+        mask_2D[lim:-lim, lim:-lim] = True
+
+        margin_fct.mask_2D = mask_2D
         axes = margin_fct.visualize_function(show=False, axes=axes, title='')
 
+        self.visualize_contour_and_move_axes_limits(axes)
+        self.clean_axes_write_title_on_the_left(axes, title)
+
+    def visualize_contour_and_move_axes_limits(self, axes):
         def get_lim_array(ax_with_lim_to_measure):
             return np.array([np.array(ax_with_lim_to_measure.get_xlim()), np.array(ax_with_lim_to_measure.get_ylim())])
 
         for ax in axes:
-            old_lim = get_lim_array(ax)
+            # old_lim = get_lim_array(ax)
             self.study.visualize_study(ax, fill=False, show=False)
-            new_lim = get_lim_array(ax)
-            assert 0 <= self.coef_zoom_map <= 1
-            updated_lim = new_lim * self.coef_zoom_map + (1 - self.coef_zoom_map) * old_lim
-            for i, method in enumerate([ax.set_xlim, ax.set_ylim]):
-                method(updated_lim[i, 0], updated_lim[i, 1])
-        self.clean_axes_write_title_on_the_left(axes, title)
+            # new_lim = get_lim_array(ax)
+            # assert 0 <= self.coef_zoom_map <= 1
+            # updated_lim = new_lim * self.coef_zoom_map + (1 - self.coef_zoom_map) * old_lim
+            # for i, method in enumerate([ax.set_xlim, ax.set_ylim]):
+            #     method(updated_lim[i, 0], updated_lim[i, 1])
 
     @staticmethod
     def clean_axes_write_title_on_the_left(axes, title):
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 3ae470f8..166e87b9 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -15,9 +15,10 @@ from utils import cached_property
 class AbstractMarginFunction(object):
     """ Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)"""
 
-    def __init__(self, coordinates: AbstractCoordinates, resolution=100):
+    def __init__(self, coordinates: AbstractCoordinates):
         self.coordinates = coordinates
-        self.resolution = resolution
+        self.resolution = 100
+        self.mask_2D = None
 
         # Visualization parameters
         self.visualization_axes = None
@@ -31,6 +32,10 @@ class AbstractMarginFunction(object):
         self._grid_2D = None
         self._grid_1D = None
 
+        # Visualization limits
+        self._visualization_x_limits = None
+        self._visualization_y_limits = None
+
     @property
     def x(self):
         return self.coordinates.x_coordinates
@@ -145,7 +150,7 @@ class AbstractMarginFunction(object):
             ax = plt.gca()
 
         # Special display
-        imshow_shifted(ax, gev_param_name, self.grid_2D[gev_param_name], self.x, self.y)
+        imshow_shifted(ax, gev_param_name, self.grid_2D[gev_param_name], self.visualization_extend, self.mask_2D)
 
         # X axis
         ax.set_xlabel('coordinate X')
@@ -159,13 +164,29 @@ class AbstractMarginFunction(object):
         if show:
             plt.show()
 
+    @property
+    def visualization_x_limits(self):
+        if self._visualization_x_limits is None:
+            return self.x.min(), self.x.max()
+        else:
+            return self._visualization_x_limits
+
+    @property
+    def visualization_y_limits(self):
+        if self._visualization_y_limits is None:
+            return self.y.min(), self.y.max()
+        else:
+            return self._visualization_y_limits
+
+    @property
+    def visualization_extend(self):
+        return self.visualization_x_limits + self.visualization_y_limits
+
     @cached_property
     def grid_2D(self):
-        x = self.x
-        y = self.y
         grid = []
-        for i, xi in enumerate(np.linspace(x.min(), x.max(), self.resolution)):
-            for j, yj in enumerate(np.linspace(y.min(), y.max(), self.resolution)):
+        for xi in np.linspace(*self.visualization_x_limits, self.resolution):
+            for yj in np.linspace(*self.visualization_y_limits, self.resolution):
                 grid.append(self.get_gev_params(np.array([xi, yj])).summary_dict)
         grid = {value_name: np.array([g[value_name] for g in grid]).reshape([self.resolution, self.resolution])
                 for value_name in GevParams.SUMMARY_NAMES}
diff --git a/extreme_estimator/extreme_models/margin_model/param_function/param_function.py b/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
index ad8555ef..f8b882d6 100644
--- a/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
+++ b/extreme_estimator/extreme_models/margin_model/param_function/param_function.py
@@ -4,6 +4,7 @@ from extreme_estimator.extreme_models.margin_model.param_function.linear_coef im
 
 
 class ParamFunction(object):
+    OUT_OF_BOUNDS_ASSERT = True
 
     def get_gev_param_value(self, coordinate: np.ndarray) -> float:
         pass
@@ -31,7 +32,8 @@ class LinearOneAxisParamFunction(ParamFunction):
 
     def get_gev_param_value(self, coordinate: np.ndarray) -> float:
         t = coordinate[self.linear_axis]
-        assert self.t_min <= t <= self.t_max, 'Out of bounds'
+        if self.OUT_OF_BOUNDS_ASSERT:
+            assert self.t_min <= t <= self.t_max, 'Out of bounds'
         return t * self.coef
 
 
diff --git a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
index 128954fa..194466ec 100644
--- a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
+++ b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
@@ -44,8 +44,11 @@ def get_color_rbga_shifted(ax, replace_blue_by_white: bool, values: np.ndarray,
     return colors
 
 
-def imshow_shifted(ax, gev_param_name, values, x, y):
-    masked_array = np.ma.masked_where(np.isnan(values), values)
+def imshow_shifted(ax, gev_param_name, values, visualization_extend, mask_2D=None):
+    condition = np.isnan(values)
+    if mask_2D is not None:
+        condition |= mask_2D
+    masked_array = np.ma.masked_where(condition, values)
     norm, shifted_cmap = plot_extreme_param(ax, gev_param_name, masked_array)
     shifted_cmap.set_bad(color='white')
     if gev_param_name != ExtremeParams.SHAPE:
@@ -54,5 +57,5 @@ def imshow_shifted(ax, gev_param_name, values, x, y):
         # The right blue corner will be blue (but most of the time, another display will be on top)
         masked_array[-1, -1] = value - epsilon
     # IMPORTANT: Origin for all the plots is at the bottom left corner
-    ax.imshow(masked_array, extent=(x.min(), x.max(), y.min(), y.max()), cmap=shifted_cmap, origin='lower')
+    ax.imshow(masked_array, extent=visualization_extend, cmap=shifted_cmap, origin='lower')
 
-- 
GitLab