From 30f2543ff1bb5873a54c002c688545c258a35a81 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 15 Nov 2018 17:22:05 +0100
Subject: [PATCH] [RETURN LEVEL PLOT] add return level plot and modify
 test_margin_estimators accordingly

---
 .../estimator/margin_estimator.py             | 11 ++++---
 .../abstract_margin_function.py               | 19 ++++++++----
 .../margin_model/smooth_margin_model.py       |  4 +--
 .../return_level_plot/spatial_2D_plot.py      | 30 +++++++++++++++++++
 .../test_estimator/test_full_estimators.py    |  4 +--
 .../test_estimator/test_margin_estimators.py  | 27 ++++++++++-------
 6 files changed, 70 insertions(+), 25 deletions(-)
 create mode 100644 extreme_estimator/return_level_plot/spatial_2D_plot.py

diff --git a/extreme_estimator/estimator/margin_estimator.py b/extreme_estimator/estimator/margin_estimator.py
index 30bb6836..82fdc39f 100644
--- a/extreme_estimator/estimator/margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator.py
@@ -1,5 +1,8 @@
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
 from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearMarginModel
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
@@ -11,8 +14,8 @@ class AbstractMarginEstimator(AbstractEstimator):
         self._margin_function_fitted = None
 
     @property
-    def margin_function_fitted(self):
-        assert self._margin_function_fitted is not None, 'Error: estimator has not been not fitted yet'
+    def margin_function_fitted(self) -> AbstractMarginFunction:
+        assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted'
         return self._margin_function_fitted
 
 
@@ -23,9 +26,9 @@ class PointWiseMarginEstimator(AbstractMarginEstimator):
 class SmoothMarginEstimator(AbstractMarginEstimator):
     """# with different type of marginals: cosntant, linear...."""
 
-    def __init__(self, dataset: AbstractDataset, margin_model: AbstractMarginModel):
+    def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel):
         super().__init__(dataset)
-        assert isinstance(margin_model, AbstractMarginModel)
+        assert isinstance(margin_model, LinearMarginModel)
         self.margin_model = margin_model
 
     def _fit(self):
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 6db45309..e5a593b4 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -18,16 +18,23 @@ class AbstractMarginFunction(object):
     def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
         pass
 
-    def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, show=False):
+    def visualize_2D(self, gev_param_name=GevParams.GEV_LOC, ax=None, show=False):
         x = self.spatial_coordinates.x_coordinates
         y = self.spatial_coordinates.y_coordinates
+        grid = self.get_grid_2D(x, y)
+        gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
+        if ax is None:
+            ax = plt.gca()
+        imshow_method = ax.imshow
+        imshow_method(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
+                      interpolation='nearest', cmap=cm.gist_rainbow)
+        if show:
+            plt.show()
+
+    def get_grid_2D(self, x, y):
         resolution = 100
         grid = np.zeros([resolution, resolution, 3])
         for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)):
             for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)):
                 grid[i, j] = self.get_gev_params(np.array([xi, yj])).to_array()
-        gev_param_idx = GevParams.GEV_PARAM_NAMES.index(gev_param_name)
-        plt.imshow(grid[..., gev_param_idx], extent=(x.min(), x.max(), y.min(), y.max()),
-                   interpolation='nearest', cmap=cm.gist_rainbow)
-        if show:
-            plt.show()
+        return grid
diff --git a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
index 8d9febab..6a33d939 100644
--- a/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
+++ b/extreme_estimator/extreme_models/margin_model/smooth_margin_model.py
@@ -33,8 +33,8 @@ class LinearShapeAxis0MarginModel(LinearMarginModel):
     def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_linear_axis=None):
         super().load_margin_functions({GevParams.GEV_SHAPE: 0})
 
-    # def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
-    #     return self.margin_function_start_fit
+    def fitmargin_from_maxima_gev(self, maxima_gev: np.ndarray, coordinates: np.ndarray) -> AbstractMarginFunction:
+        return self.margin_function_start_fit
 
 
 if __name__ == '__main__':
diff --git a/extreme_estimator/return_level_plot/spatial_2D_plot.py b/extreme_estimator/return_level_plot/spatial_2D_plot.py
new file mode 100644
index 00000000..8347b0e9
--- /dev/null
+++ b/extreme_estimator/return_level_plot/spatial_2D_plot.py
@@ -0,0 +1,30 @@
+from itertools import product
+from typing import List, Dict
+
+import matplotlib.pyplot as plt
+
+from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
+    AbstractMarginFunction
+from extreme_estimator.gev_params import GevParams
+
+plt.style.use('seaborn-white')
+
+
+class Spatial2DPlot(object):
+
+    def __init__(self, name_to_margin_function: Dict[str, AbstractMarginFunction]):
+        self.name_to_margin_function = name_to_margin_function # type: Dict[str, AbstractMarginFunction]
+        self.grid_columns = GevParams.GEV_PARAM_NAMES
+
+    def plot(self):
+        nb_grid_rows, nb_grid_columns = len(self.name_to_margin_function), len(self.grid_columns)
+        fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row')
+        fig.subplots_adjust(hspace=0.4, wspace=0.4)
+        margin_function: AbstractMarginFunction
+        for i, (name, margin_function) in enumerate(self.name_to_margin_function.items()):
+            for j, param_name in enumerate(self.grid_columns):
+                ax = axes[i, j] if nb_grid_rows > 1 else axes[j]
+                margin_function.visualize_2D(gev_param_name=param_name, ax=ax)
+                ax.set_title("{} for {}".format(param_name, name))
+        fig.suptitle('Spatial2DPlot')
+        plt.show()
diff --git a/test/test_extreme_estimator/test_estimator/test_full_estimators.py b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
index 5e03f8fd..49df744f 100644
--- a/test/test_extreme_estimator/test_estimator/test_full_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_full_estimators.py
@@ -4,7 +4,7 @@ from itertools import product
 from extreme_estimator.estimator.full_estimator import SmoothMarginalsThenUnitaryMsp
 from spatio_temporal_dataset.dataset.simulation_dataset import FullSimulatedDataset
 from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
-from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestMarginEstimators
+from test.test_extreme_estimator.test_estimator.test_margin_estimators import TestSmoothMarginEstimator
 from test.test_extreme_estimator.test_estimator.test_max_stable_estimators import TestMaxStableEstimators
 
 
@@ -16,7 +16,7 @@ class TestFullEstimators(unittest.TestCase):
         super().setUp()
         self.spatial_coordinates = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
         self.max_stable_models = TestMaxStableEstimators.load_max_stable_models()
-        self.margin_models = TestMarginEstimators.load_margin_models(spatial_coordinates=self.spatial_coordinates)
+        self.margin_models = TestSmoothMarginEstimator.load_margin_models(spatial_coordinates=self.spatial_coordinates)
 
     def test_full_estimators(self):
         for margin_model, max_stable_model in product(self.margin_models, self.max_stable_models):
diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
index 480ebba2..9dcea3dc 100644
--- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
@@ -1,15 +1,17 @@
 import unittest
 
 from extreme_estimator.extreme_models.margin_model.abstract_margin_model import AbstractMarginModel
-from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel
+from extreme_estimator.extreme_models.margin_model.smooth_margin_model import ConstantMarginModel, \
+    LinearShapeAxis0MarginModel
 from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
+from extreme_estimator.return_level_plot.spatial_2D_plot import Spatial2DPlot
 from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
 from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
 
 
-class TestMarginEstimators(unittest.TestCase):
+class TestSmoothMarginEstimator(unittest.TestCase):
     DISPLAY = False
-    MARGIN_TYPES = [ConstantMarginModel]
+    MARGIN_TYPES = [ConstantMarginModel, LinearShapeAxis0MarginModel][1:]
     MARGIN_ESTIMATORS = [SmoothMarginEstimator]
 
     def setUp(self):
@@ -25,14 +27,17 @@ class TestMarginEstimators(unittest.TestCase):
         for margin_model in self.margin_models:
             dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model,
                                                   spatial_coordinates=self.spatial_coordinates)
-
-            for estimator_class in self.MARGIN_ESTIMATORS:
-                estimator = estimator_class(dataset=dataset, margin_model=margin_model)
-                estimator.fit()
-                if self.DISPLAY:
-                    print(type(margin_model))
-                    print(dataset.df_dataset.head())
-                    print(estimator.additional_information)
+            # Fit estimator
+            estimator = SmoothMarginEstimator(dataset=dataset, margin_model=margin_model)
+            estimator.fit()
+            # Map name to their margin functions
+            name_to_margin_function = {
+                'Ground truth margin function': dataset.margin_model.margin_function_sample,
+                # 'Estimated margin function': estimator.margin_function_fitted,
+            }
+            # Spatial Plot
+            if self.DISPLAY:
+                Spatial2DPlot(name_to_margin_function=name_to_margin_function).plot()
             self.assertTrue(True)
 
 
-- 
GitLab