From 7833bcb46c6376c64882e50f5d0c799e44dab3d1 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 20 Feb 2019 15:53:49 +0100
Subject: [PATCH] [SCM] add visualization/comparison of multiple max_stable
 models for SCM

---
 .../meteo_france_SCM_study/main_visualize.py  | 19 ++++++++---
 .../safran/safran_visualizer.py               | 34 +++++++++++++------
 .../regression_margin/regression_margin.py    |  4 +--
 .../abstract_margin_function.py               |  4 +--
 .../test_estimator/test_margin_estimators.py  |  4 +--
 .../test_extreme_models/test_margin_model.py  |  4 +--
 utils.py                                      |  3 +-
 7 files changed, 48 insertions(+), 24 deletions(-)

diff --git a/experiment/meteo_france_SCM_study/main_visualize.py b/experiment/meteo_france_SCM_study/main_visualize.py
index 48eb5f40..07d8a010 100644
--- a/experiment/meteo_france_SCM_study/main_visualize.py
+++ b/experiment/meteo_france_SCM_study/main_visualize.py
@@ -20,11 +20,22 @@ def load_all_studies(study_class, only_first_one=False):
     return all_studies
 
 
-if __name__ == '__main__':
-    for study_class in [ExtendedSafran, ExtendedCrocusSwe, ExtendedCrocusDepth][:]:
+def extended_visualization():
+    for study_class in [ExtendedSafran, ExtendedCrocusSwe, ExtendedCrocusDepth][:1]:
+        for study in load_all_studies(study_class, only_first_one=True):
+            study_visualizer = StudyVisualizer(study)
+            study_visualizer.visualize_all_kde_graphs()
+
+
+def normal_visualization():
+    for study_class in [Safran, CrocusSwe, CrocusDepth][:1]:
         for study in load_all_studies(study_class, only_first_one=True):
             study_visualizer = StudyVisualizer(study)
             # study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
             # study_visualizer.visualize_smooth_margin_fit()
-            study_visualizer.visualize_all_kde_graphs()
-            # study_visualizer.visualize_full_fit()
+            study_visualizer.visualize_full_fit()
+
+
+if __name__ == '__main__':
+    normal_visualization()
+    # extended_visualization()
\ No newline at end of file
diff --git a/experiment/meteo_france_SCM_study/safran/safran_visualizer.py b/experiment/meteo_france_SCM_study/safran/safran_visualizer.py
index cbc03594..d211fb38 100644
--- a/experiment/meteo_france_SCM_study/safran/safran_visualizer.py
+++ b/experiment/meteo_france_SCM_study/safran/safran_visualizer.py
@@ -16,6 +16,8 @@ from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams
 from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit
 from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from test.test_utils import load_test_max_stable_models
+from utils import get_display_name_from_object_type
 
 
 class StudyVisualizer(object):
@@ -89,14 +91,14 @@ class StudyVisualizer(object):
         assert len(x) == len(y)
         return x, y
 
-
-
-    def fit_and_visualize_estimator(self, estimator):
+    def fit_and_visualize_estimator(self, estimator, axes=None, show=True, title=None):
         estimator.fit()
-        axes = estimator.margin_function_fitted.visualize(show=False)
+        axes = estimator.margin_function_fitted.visualize_function(show=False, axes=axes, title=title)
         for ax in axes:
             self.study.visualize(ax, fill=False, show=False)
-        plt.show()
+        if show:
+            plt.suptitle(self.study.title)
+            plt.show()
 
     def visualize_smooth_margin_fit(self):
         margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
@@ -104,10 +106,19 @@ class StudyVisualizer(object):
         self.fit_and_visualize_estimator(estimator)
 
     def visualize_full_fit(self):
-        max_stable_model = Smith()
-        margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
-        estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
-        self.fit_and_visualize_estimator(estimator)
+        max_stable_models = load_test_max_stable_models()
+        fig, axes = plt.subplots(len(max_stable_models), len(GevParams.SUMMARY_NAMES))
+        fig.subplots_adjust(hspace=1.0, wspace=1.0)
+        for i, max_stable_model in enumerate(max_stable_models):
+            margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
+            estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
+            title = get_display_name_from_object_type(type(max_stable_model))
+            print(title)
+            self.fit_and_visualize_estimator(estimator, axes[i], show=False, title=title)
+        title = self.study.title
+        title += '\nMethod: Full Likelihood with Linear marginals and max stable dependency structure'
+        plt.suptitle(title)
+        plt.show()
 
     def visualize_independent_margin_fits(self, threshold=None, axes=None):
         if threshold is None:
@@ -150,8 +161,9 @@ class StudyVisualizer(object):
     @property
     def df_gev_mle_each_massif(self):
         # Fit a margin_fits on each massif
-        massif_to_gev_mle = {massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
-                             for massif_name in self.study.safran_massif_names}
+        massif_to_gev_mle = {
+        massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
+        for massif_name in self.study.safran_massif_names}
         return pd.DataFrame(massif_to_gev_mle, columns=self.study.safran_massif_names)
 
     def df_gpd_mle_each_massif(self, threshold):
diff --git a/experiment/regression_margin/regression_margin.py b/experiment/regression_margin/regression_margin.py
index e4f0e6ea..b38308a7 100644
--- a/experiment/regression_margin/regression_margin.py
+++ b/experiment/regression_margin/regression_margin.py
@@ -50,14 +50,14 @@ for i in range(nb_estimator):
         plt.show()
 
     margin_function_sample = dataset.margin_model.margin_function_sample # type: LinearMarginFunction
-    margin_function_sample.visualize(show=False, axes=axes, dot_display=True)
+    margin_function_sample.visualize_function(show=False, axes=axes, dot_display=True)
     axes = margin_function_sample.visualization_axes
 
     # Estimation part
     margin_model_for_estimator = margin_model_for_estimator_class(coordinates)
     full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model)
     full_estimator.fit()
-    full_estimator.margin_function_fitted.visualize(axes=axes, show=False)
+    full_estimator.margin_function_fitted.visualize_function(axes=axes, show=False)
 plt.show()
 
 # Display all the margin on the same graph for comparison
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 9fec7d59..09866a6a 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -57,7 +57,7 @@ class AbstractMarginFunction(object):
         self.filter = filter
         self.color = color
 
-    def visualize(self, axes=None, show=True, dot_display=False):
+    def visualize_function(self, axes=None, show=True, dot_display=False, title=None):
         self.datapoint_display = dot_display
         if axes is None:
             fig, axes = plt.subplots(1, len(GevParams.SUMMARY_NAMES))
@@ -66,7 +66,7 @@ class AbstractMarginFunction(object):
         for i, gev_value_name in enumerate(GevParams.SUMMARY_NAMES):
             ax = axes[i]
             self.visualize_single_param(gev_value_name, ax, show=False)
-            title_str = gev_value_name
+            title_str = gev_value_name if title is None else title
             ax.set_title(title_str)
         if show:
             plt.show()
diff --git a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
index aa19eb3f..e3e1593a 100644
--- a/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
+++ b/test/test_extreme_estimator/test_estimator/test_margin_estimators.py
@@ -26,8 +26,8 @@ class TestSmoothMarginEstimator(unittest.TestCase):
                 estimator.fit()
                 # Plot
                 if self.DISPLAY:
-                    margin_model.margin_function_sample.visualize(show=True)
-                    estimator.margin_function_fitted.visualize(show=True)
+                    margin_model.margin_function_sample.visualize_function(show=True)
+                    estimator.margin_function_fitted.visualize_function(show=True)
         self.assertTrue(True)
 
 
diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
index 638e321f..7611f024 100644
--- a/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_model.py
@@ -16,12 +16,12 @@ class VisualizationMarginModel(unittest.TestCase):
         spatial_coordinates = CircleSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
         margin_model = self.margin_model(coordinates=spatial_coordinates)
         if self.DISPLAY:
-            margin_model.margin_function_sample.visualize()
+            margin_model.margin_function_sample.visualize_function()
 
     def test_example_visualization_1D(self):
         coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
         margin_model = self.margin_model(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
-        margin_model.margin_function_sample.visualize(show=self.DISPLAY)
+        margin_model.margin_function_sample.visualize_function(show=self.DISPLAY)
         self.assertTrue(True)
 
 
diff --git a/utils.py b/utils.py
index 44b396f6..d3e9071e 100644
--- a/utils.py
+++ b/utils.py
@@ -3,6 +3,7 @@ import os.path as op
 
 VERSION = datetime.datetime.now()
 
+
 def get_root_path() -> str:
     return op.dirname(op.abspath(__file__))
 
@@ -13,7 +14,7 @@ def get_full_path(relative_path: str) -> str:
 
 def get_display_name_from_object_type(object_type):
     # assert isinstance(object_type, type), object_type
-    return str(object_type).split('.')[-1].split("'")[0]
+    return str(object_type).split('.')[-1].split("'")[0].split(' ')[0]
 
 
 def first(s):
-- 
GitLab