From 9ff7812a3904c2d0cb9e5b44425cb91e041946e3 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 13 Feb 2019 16:53:29 +0100
Subject: [PATCH] [SAFRAN] add map visualization on top of margin visualization

---
 experiment/safran_study/safran.py                  |  9 +++++----
 experiment/safran_study/safran_visualizer.py       | 14 +++++++++-----
 .../margin_function/abstract_margin_function.py    |  1 +
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/experiment/safran_study/safran.py b/experiment/safran_study/safran.py
index 63185388..374ebb46 100644
--- a/experiment/safran_study/safran.py
+++ b/experiment/safran_study/safran.py
@@ -108,7 +108,7 @@ class Safran(object):
 
     """ Visualization methods """
 
-    def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True):
+    def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True, fill=True):
         if ax is None:
             ax = plt.gca()
         df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv'))
@@ -120,9 +120,10 @@ class Safran(object):
             l = [coords for idx, *coords in coord_tuples if idx == coordinate_id]
             l = list(zip(*l))
             ax.plot(*l, color='black')
-            massif_name = self.coordinate_id_to_massif_name[coordinate_id]
-            fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
-            ax.fill(*l, **fill_kwargs)
+            if fill:
+                massif_name = self.coordinate_id_to_massif_name[coordinate_id]
+                fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
+                ax.fill(*l, **fill_kwargs)
         ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates)
         ax.axis('off')
 
diff --git a/experiment/safran_study/safran_visualizer.py b/experiment/safran_study/safran_visualizer.py
index 31e4824e..9abdbf52 100644
--- a/experiment/safran_study/safran_visualizer.py
+++ b/experiment/safran_study/safran_visualizer.py
@@ -38,18 +38,23 @@ class SafranVisualizer(object):
     def dataset(self):
         return AbstractDataset(self.observations, self.coordinates)
 
+    def fit_and_visualize_estimator(self, estimator):
+        estimator.fit()
+        axes = estimator.margin_function_fitted.visualize(show=False)
+        for ax in axes:
+            self.safran.visualize(ax, fill=False, show=False)
+        plt.show()
+
     def visualize_smooth_margin_fit(self):
         margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
         estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model)
-        estimator.fit()
-        estimator.margin_function_fitted.visualize(show=self.show)
+        self.fit_and_visualize_estimator(estimator)
 
     def visualize_full_fit(self):
         max_stable_model = Smith()
         margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
         estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
-        estimator.fit()
-        estimator.margin_function_fitted.visualize(show=self.show)
+        self.fit_and_visualize_estimator(estimator)
 
     def visualize_independent_margin_fits(self, threshold=None, axes=None):
         if threshold is None:
@@ -78,7 +83,6 @@ class SafranVisualizer(object):
 
             self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
 
-
         if self.show:
             plt.show()
 
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 8ca34294..14e0ca22 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -69,6 +69,7 @@ class AbstractMarginFunction(object):
             ax.set_title(title_str)
         if show:
             plt.show()
+        return axes
 
     def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True):
         assert gev_value_name in GevParams.SUMMARY_NAMES
-- 
GitLab