From e9edd532b1ecf77c685dbea652dffd7f63689cf9 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Tue, 19 Feb 2019 10:40:02 +0100
Subject: [PATCH] [SCM] add shifted visualization for both crocus variables

---
 .../meteo_france_SCM_study/abstract_study.py  | 28 ++++++++--------
 .../meteo_france_SCM_study/crocus/crocus.py   | 27 ++++++++++++----
 .../meteo_france_SCM_study/main_visualize.py  | 30 ++++++++++-------
 experiment/meteo_france_SCM_study/massif.py   | 16 +++++++---
 .../safran/fit_safran.py                      |  2 +-
 .../meteo_france_SCM_study/safran/safran.py   | 22 ++-----------
 .../safran/safran_visualizer.py               | 32 +++++++++----------
 .../margin_fits/plot/create_shifted_cmap.py   | 15 +++++----
 8 files changed, 93 insertions(+), 79 deletions(-)

diff --git a/experiment/meteo_france_SCM_study/abstract_study.py b/experiment/meteo_france_SCM_study/abstract_study.py
index 60b88549..264d1fbb 100644
--- a/experiment/meteo_france_SCM_study/abstract_study.py
+++ b/experiment/meteo_france_SCM_study/abstract_study.py
@@ -1,7 +1,7 @@
 import os
 import os.path as op
 from collections import OrderedDict
-from typing import List
+from typing import List, Dict
 
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -17,29 +17,30 @@ from utils import get_full_path, cached_property
 
 
 class AbstractStudy(object):
+    ALTITUDES = [1800, 2400]
 
-    def __init__(self, safran_altitude=1800):
-        assert safran_altitude in [1800, 2400]
-        self.safran_altitude = safran_altitude
+    def __init__(self, variable_class, altitude=1800):
+        assert altitude in self.ALTITUDES
+        self.altitude = altitude
         self.model_name = None
-        self.variable_class = None
+        self.variable_class = variable_class
 
     def write_to_file(self, df):
         if not op.exists(self.result_full_path):
             os.makedirs(self.result_full_path, exist_ok=True)
-        df.to_csv(op.join(self.result_full_path, 'merged_array_{}_altitude.csv'.format(self.safran_altitude)))
+        df.to_csv(op.join(self.result_full_path, 'merged_array_{}_altitude.csv'.format(self.altitude)))
 
     """ Data """
 
     @property
-    def df_all_snowfall_concatenated(self):
+    def df_all_snowfall_concatenated(self) -> pd.DataFrame:
         df_list = [pd.DataFrame(snowfall, columns=self.safran_massif_names) for snowfall in
                    self.year_to_daily_time_serie.values()]
         df_concatenated = pd.concat(df_list)
         return df_concatenated
 
     @property
-    def observations_annual_maxima(self):
+    def observations_annual_maxima(self) -> AnnualMaxima:
         return AnnualMaxima(df_maxima_gev=pd.DataFrame(self.year_to_annual_maxima, index=self.safran_massif_names))
 
     """ Load some attributes only once """
@@ -79,11 +80,11 @@ class AbstractStudy(object):
     @property
     def safran_massif_names(self) -> List[str]:
         # Load the names of the massif as defined by SAFRAN
-        return safran_massif_names_from_datasets(self.year_to_dataset_ordered_dict.values())
+        return safran_massif_names_from_datasets(list(self.year_to_dataset_ordered_dict.values()))
 
     @property
-    def safran_massif_id_to_massif_name(self):
-        return dict(enumerate(self.safran_massif_names))
+    def safran_massif_id_to_massif_name(self) -> Dict[int, str]:
+        return {massif_id: massif_name for massif_id, massif_name in enumerate(self.safran_massif_names)}
 
     @cached_property
     def massifs_coordinates(self) -> AbstractSpatialCoordinates:
@@ -103,13 +104,14 @@ class AbstractStudy(object):
         return df_centroid
 
     @property
-    def coordinate_id_to_massif_name(self) -> dict:
+    def coordinate_id_to_massif_name(self) -> Dict[int, str]:
         df_centroid = self.load_df_centroid()
         return dict(zip(df_centroid['id'], df_centroid.index))
 
     """ Visualization methods """
 
     def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True, fill=True):
+        print("here")
         if ax is None:
             ax = plt.gca()
         df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv'))
@@ -144,7 +146,7 @@ class AbstractStudy(object):
     @property
     def safran_full_path(self) -> str:
         assert self.model_name in ['Safran', 'Crocus']
-        return op.join(self.full_path, 'safran-crocus_{}'.format(self.safran_altitude), self.model_name)
+        return op.join(self.full_path, 'safran-crocus_{}'.format(self.altitude), self.model_name)
 
     @property
     def map_full_path(self) -> str:
diff --git a/experiment/meteo_france_SCM_study/crocus/crocus.py b/experiment/meteo_france_SCM_study/crocus/crocus.py
index 5a9c2177..27331668 100644
--- a/experiment/meteo_france_SCM_study/crocus/crocus.py
+++ b/experiment/meteo_france_SCM_study/crocus/crocus.py
@@ -3,17 +3,32 @@ from experiment.meteo_france_SCM_study.crocus.crocus_variables import CrocusSweV
 
 
 class Crocus(AbstractStudy):
+    """
+    In the Crocus data, there is no 'massifsList' variable, thus we assume massifs are ordered just like Safran data
+    """
 
-    def __init__(self, safran_altitude=1800, crocus_variable_class=CrocusSweVariable):
-        super().__init__(safran_altitude)
+    def __init__(self, variable_class, altitude=1800):
+        assert variable_class in [CrocusSweVariable, CrocusDepthVariable]
+        super().__init__(variable_class, altitude)
         self.model_name = 'Crocus'
-        assert crocus_variable_class in [CrocusSweVariable, CrocusDepthVariable]
-        self.variable_class = crocus_variable_class
+
+
+class CrocusSwe(Crocus):
+
+    def __init__(self, altitude=1800):
+        super().__init__(CrocusSweVariable, altitude)
+
+
+class CrocusDepth(Crocus):
+
+    def __init__(self, altitude=1800):
+        super().__init__(CrocusDepthVariable, altitude)
+
 
 if __name__ == '__main__':
     for variable_class in [CrocusSweVariable, CrocusDepthVariable]:
-        study = Crocus(crocus_variable_class=variable_class)
+        study = Crocus(variable_class=variable_class)
         # d = study.year_to_dataset_ordered_dict[1960]
         # print(d)
         a = study.year_to_daily_time_serie[1960]
-        print(a.shape)
\ No newline at end of file
+        print(a.shape)
diff --git a/experiment/meteo_france_SCM_study/main_visualize.py b/experiment/meteo_france_SCM_study/main_visualize.py
index f20d24f0..dc249ce2 100644
--- a/experiment/meteo_france_SCM_study/main_visualize.py
+++ b/experiment/meteo_france_SCM_study/main_visualize.py
@@ -1,22 +1,28 @@
+from experiment.meteo_france_SCM_study.abstract_study import AbstractStudy
+from experiment.meteo_france_SCM_study.crocus.crocus import CrocusDepth, CrocusSwe
 from experiment.meteo_france_SCM_study.safran.safran import Safran
 from itertools import product
 
-from experiment.meteo_france_SCM_study.safran.safran_visualizer import SafranVisualizer
+from experiment.meteo_france_SCM_study.safran.safran_visualizer import StudyVisualizer
 
 
-def load_all_safran(only_first_one=False):
-    all_safran = []
-    for safran_alti, nb_day in product([1800, 2400], [1, 3, 7]):
-        print('alti: {}, nb_day: {}'.format(safran_alti, nb_day))
-        all_safran.append(Safran(safran_alti, nb_day))
+def load_all_studies(study_class, only_first_one=False):
+    all_studies = []
+    is_safran_study = study_class == Safran
+    nb_days = [1, 5] if is_safran_study else [1]
+    for alti, nb_day in product(AbstractStudy.ALTITUDES, nb_days):
+        print('alti: {}, nb_day: {}'.format(alti, nb_day))
+        study = Safran(alti, nb_day) if is_safran_study else study_class(alti)
+        all_studies.append(study)
         if only_first_one:
             break
-    return all_safran
+    return all_studies
 
 
 if __name__ == '__main__':
-    for safran in load_all_safran(only_first_one=True):
-        safran_visualizer = SafranVisualizer(safran)
-        # safran_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
-        safran_visualizer.visualize_smooth_margin_fit()
-        # safran_visualizer.visualize_full_fit()
+    for study_class in [Safran, CrocusSwe, CrocusDepth][:]:
+        for study in load_all_studies(study_class, only_first_one=True):
+            study_visualizer = StudyVisualizer(study)
+            # study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
+            study_visualizer.visualize_smooth_margin_fit()
+            # study_visualizer.visualize_full_fit()
diff --git a/experiment/meteo_france_SCM_study/massif.py b/experiment/meteo_france_SCM_study/massif.py
index fe535b14..eb65d7ba 100644
--- a/experiment/meteo_france_SCM_study/massif.py
+++ b/experiment/meteo_france_SCM_study/massif.py
@@ -1,5 +1,9 @@
 from utils import first
 
+MASSIF_NAMES = ['Pelvoux', 'Queyras', 'Mont-Blanc', 'Aravis', 'Haute-Tarentaise', 'Vercors', 'Alpes-Azur', 'Oisans',
+                'Mercantour', 'Chartreuse', 'Haute-Maurienne', 'Belledonne', 'Thabor', 'Parpaillon', 'Bauges',
+                'Chablais', 'Ubaye', 'Grandes-Rousses', 'Devoluy', 'Champsaur', 'Vanoise', 'Beaufortain', 'Maurienne']
+
 
 class Massif(object):
 
@@ -16,7 +20,11 @@ class Massif(object):
 
 
 def safran_massif_names_from_datasets(datasets):
-    # Assert the all the datasets have the same indexing for the massif
-    assert len(set([dataset.massifsList for dataset in datasets])) == 1
-    # List of the name of the massif used by all the SAFRAN datasets
-    return [Massif.from_str(massif_str).name for massif_str in first(datasets).massifsList.split('/')]
\ No newline at end of file
+    # Massif names come from MASSIF_NAMES; a dataset exposing 'massifsList' is checked against it
+    if hasattr(datasets[0], 'massifsList'):
+        # Assert that all the datasets have the same indexing for the massifs
+        assert len(set([dataset.massifsList for dataset in datasets])) == 1
+        # List of the massif names used by all the SAFRAN datasets
+        safran_names = [Massif.from_str(massif_str).name for massif_str in first(datasets).massifsList.split('/')]
+        assert MASSIF_NAMES == safran_names
+    return MASSIF_NAMES
diff --git a/experiment/meteo_france_SCM_study/safran/fit_safran.py b/experiment/meteo_france_SCM_study/safran/fit_safran.py
index 0e93a164..47fe572c 100644
--- a/experiment/meteo_france_SCM_study/safran/fit_safran.py
+++ b/experiment/meteo_france_SCM_study/safran/fit_safran.py
@@ -13,7 +13,7 @@ def fit_mle_gev_for_all_safran_and_different_days():
             # safran = Safran(safran_alti, nb_day)
             safran = ExtendedSafran(safran_alti, nb_day)
             df = safran.df_gev_mle_each_massif
-            df.index += ' Safran{} with {} days'.format(safran.safran_altitude, safran.nb_days_of_snowfall)
+            df.index += ' Safran{} with {} days'.format(safran.altitude, safran.nb_days_of_snowfall)
             dfs.append(df)
     df = pd.concat(dfs)
     path = r'/home/erwan/Documents/projects/spatiotemporalextremes/local/spatio_temporal_datasets/results/fit_mle_massif/fit_mle_gev_{}.csv'
diff --git a/experiment/meteo_france_SCM_study/safran/safran.py b/experiment/meteo_france_SCM_study/safran/safran.py
index a647d239..81b46344 100644
--- a/experiment/meteo_france_SCM_study/safran/safran.py
+++ b/experiment/meteo_france_SCM_study/safran/safran.py
@@ -1,32 +1,14 @@
-from typing import List
-
-import os.path as op
-from collections import OrderedDict
-
-import matplotlib.pyplot as plt
-import pandas as pd
-
 from experiment.meteo_france_SCM_study.abstract_study import AbstractStudy
 from experiment.meteo_france_SCM_study.abstract_variable import AbstractVariable
-from experiment.meteo_france_SCM_study.massif import safran_massif_names_from_datasets
 from experiment.meteo_france_SCM_study.safran.safran_snowfall_variable import SafranSnowfallVariable
-from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
-from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
-    AbstractSpatialCoordinates
-from spatio_temporal_dataset.spatio_temporal_observations.annual_maxima_observations import AnnualMaxima
-from utils import cached_property
 
 
 class Safran(AbstractStudy):
 
-    def __init__(self, safran_altitude=1800, nb_days_of_snowfall=1):
-        super().__init__(safran_altitude)
+    def __init__(self, altitude=1800, nb_days_of_snowfall=1):
+        super().__init__(SafranSnowfallVariable, altitude)
         self.nb_days_of_snowfall = nb_days_of_snowfall
         self.model_name = 'Safran'
-        self.variable_class = SafranSnowfallVariable
-
 
     def instantiate_variable_object(self, dataset) -> AbstractVariable:
         return self.variable_class(dataset, self.nb_days_of_snowfall)
-
-
diff --git a/experiment/meteo_france_SCM_study/safran/safran_visualizer.py b/experiment/meteo_france_SCM_study/safran/safran_visualizer.py
index 8a8c79db..517ea10a 100644
--- a/experiment/meteo_france_SCM_study/safran/safran_visualizer.py
+++ b/experiment/meteo_france_SCM_study/safran/safran_visualizer.py
@@ -1,6 +1,7 @@
 import matplotlib.pyplot as plt
 import pandas as pd
 
+from experiment.meteo_france_SCM_study.abstract_study import AbstractStudy
 from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
     FullEstimatorInASingleStepWithSmoothMargin
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
@@ -15,19 +16,19 @@ from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbg
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
-class SafranVisualizer(object):
+class StudyVisualizer(object):
 
-    def __init__(self, safran: Safran, show=True):
-        self.safran = safran
+    def __init__(self, study: AbstractStudy, show=True):
+        self.study = study
         self.show = show
 
     @property
     def observations(self):
-        return self.safran.observations_annual_maxima
+        return self.study.observations_annual_maxima
 
     @property
     def coordinates(self):
-        return self.safran.massifs_coordinates
+        return self.study.massifs_coordinates
 
     @property
     def dataset(self):
@@ -37,11 +38,10 @@ class SafranVisualizer(object):
         estimator.fit()
         axes = estimator.margin_function_fitted.visualize(show=False)
         for ax in axes:
-            self.safran.visualize(ax, fill=False, show=False)
+            self.study.visualize(ax, fill=False, show=False)
         plt.show()
 
     def visualize_smooth_margin_fit(self):
-        # todo: fix some blue points in the corner when we display the margin
         margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
         estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model)
         self.fit_and_visualize_estimator(estimator)
@@ -73,8 +73,8 @@ class SafranVisualizer(object):
             values = list(massif_name_to_value.values())
             colors = get_color_rbga_shifted(ax, gev_param_name, values)
             massif_name_to_fill_kwargs = {massif_name: {'color': color} for massif_name, color in
-                                          zip(self.safran.safran_massif_names, colors)}
-            self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
+                                          zip(self.study.safran_massif_names, colors)}
+            self.study.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
 
         if self.show:
             plt.show()
@@ -86,20 +86,20 @@ class SafranVisualizer(object):
         massif_name_to_fill_kwargs = {massif_name: {'color': orig_cmap(value)} for massif_name, value in
                                       massif_name_to_value.items()}
 
-        self.safran.visualize(massif_name_to_fill_kwargs=massif_name_to_fill_kwargs)
+        self.study.visualize(massif_name_to_fill_kwargs=massif_name_to_fill_kwargs)
 
     """ Statistics methods """
 
     @property
     def df_gev_mle_each_massif(self):
         # Fit a margin_fits on each massif
-        massif_to_gev_mle = {massif_name: GevMleFit(self.safran.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
-                             for massif_name in self.safran.safran_massif_names}
-        return pd.DataFrame(massif_to_gev_mle, columns=self.safran.safran_massif_names)
+        massif_to_gev_mle = {massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
+                             for massif_name in self.study.safran_massif_names}
+        return pd.DataFrame(massif_to_gev_mle, columns=self.study.safran_massif_names)
 
     def df_gpd_mle_each_massif(self, threshold):
         # Fit a margin fit on each massif
-        massif_to_gev_mle = {massif_name: GpdMleFit(self.safran.df_all_snowfall_concatenated[massif_name],
+        massif_to_gev_mle = {massif_name: GpdMleFit(self.study.df_all_snowfall_concatenated[massif_name],
                                                     threshold=threshold).gpd_params.summary_serie
-                             for massif_name in self.safran.safran_massif_names}
-        return pd.DataFrame(massif_to_gev_mle, columns=self.safran.safran_massif_names)
+                             for massif_name in self.study.safran_massif_names}
+        return pd.DataFrame(massif_to_gev_mle, columns=self.study.safran_massif_names)
diff --git a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
index 9e83b278..b0fe67fc 100644
--- a/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
+++ b/extreme_estimator/margin_fits/plot/create_shifted_cmap.py
@@ -18,12 +18,14 @@ from spatio_temporal_dataset.slicer.split import Split
 def plot_extreme_param(ax, gev_param_name, values):
     # Load the shifted cmap to center on a middle point
     vmin, vmax = np.min(values), np.max(values)
-    cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
-    if gev_param_name == ExtremeParams.SHAPE and vmin < 0:
+    if vmin < 0 < vmax:
         midpoint = 1 - vmax / (vmax + abs(vmin))
-        shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
-    else:
-        shifted_cmap = shiftedColorMap(cmap, midpoint=0.0, name='shifted')
+    elif vmin < 0 and vmax < 0:
+        midpoint = 1.0
+    elif vmin > 0 and vmax > 0:
+        midpoint = 0.0
+    cmap = [plt.cm.coolwarm, plt.cm.bwr, plt.cm.seismic][1]
+    shifted_cmap = shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
     divider = make_axes_locatable(ax)
     cax = divider.append_axes('right', size='5%', pad=0.05)
@@ -50,10 +52,9 @@ def imshow_shifted(ax, gev_param_name, values, x, y):
 
     masked_array = values
     if gev_param_name != ExtremeParams.SHAPE:
-        epsilon = 1.0
+        epsilon = 1e-2 * (np.max(values) - np.min(values))
         value = np.min(values)
         # The right blue corner will be blue (but most of the time, another display will be on top)
         masked_array[-1, -1] = value - epsilon
-
     ax.imshow(masked_array, extent=(x.min(), x.max(), y.min(), y.max()), cmap=shifted_cmap)
 
-- 
GitLab