diff --git a/experiment/safran_study/safran.py b/experiment/safran_study/safran.py
index 63185388e2ddd231eed05e60ee1b5fe6c5cb714b..374ebb460281eee0e9c87193a3487bef05ca0d59 100644
--- a/experiment/safran_study/safran.py
+++ b/experiment/safran_study/safran.py
@@ -108,7 +108,7 @@ class Safran(object):
 
     """ Visualization methods """
 
-    def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True):
+    def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True, fill=True):
         if ax is None:
             ax = plt.gca()
         df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv'))
@@ -120,9 +120,10 @@ class Safran(object):
             l = [coords for idx, *coords in coord_tuples if idx == coordinate_id]
             l = list(zip(*l))
             ax.plot(*l, color='black')
-            massif_name = self.coordinate_id_to_massif_name[coordinate_id]
-            fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
-            ax.fill(*l, **fill_kwargs)
+            if fill:
+                massif_name = self.coordinate_id_to_massif_name[coordinate_id]
+                fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
+                ax.fill(*l, **fill_kwargs)
         ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates)
         ax.axis('off')
 
diff --git a/experiment/safran_study/safran_visualizer.py b/experiment/safran_study/safran_visualizer.py
index 31e4824e6cc5cebd42e364bbe3263fc0d0841e95..9abdbf52d2696040e4bc54f1ce1da71374e11e52 100644
--- a/experiment/safran_study/safran_visualizer.py
+++ b/experiment/safran_study/safran_visualizer.py
@@ -38,18 +38,23 @@ class SafranVisualizer(object):
     def dataset(self):
         return AbstractDataset(self.observations, self.coordinates)
 
+    def fit_and_visualize_estimator(self, estimator):
+        estimator.fit()
+        axes = estimator.margin_function_fitted.visualize(show=False)
+        for ax in axes:
+            self.safran.visualize(ax, fill=False, show=False)
+        plt.show()
+
     def visualize_smooth_margin_fit(self):
         margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
         estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model)
-        estimator.fit()
-        estimator.margin_function_fitted.visualize(show=self.show)
+        self.fit_and_visualize_estimator(estimator)
 
     def visualize_full_fit(self):
         max_stable_model = Smith()
         margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
         estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
-        estimator.fit()
-        estimator.margin_function_fitted.visualize(show=self.show)
+        self.fit_and_visualize_estimator(estimator)
 
     def visualize_independent_margin_fits(self, threshold=None, axes=None):
         if threshold is None:
@@ -78,7 +83,6 @@ class SafranVisualizer(object):
 
             self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
 
-
         if self.show:
             plt.show()
 
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 8ca34294bc6cff51bf0b7d77afbbf93e5c63480f..14e0ca22037eae84145d42e02d8c93cbbc245029 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -69,6 +69,7 @@ class AbstractMarginFunction(object):
             ax.set_title(title_str)
         if show:
             plt.show()
+        return axes
 
     def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True):
         assert gev_value_name in GevParams.SUMMARY_NAMES