From 98f7d0d4a54c99eb1fdc2bfc149777e9c5063b37 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 4 Mar 2021 12:28:52 +0100
Subject: [PATCH] [projection snowfall] add together fit, and its usage into
 visualizer_for_sensitivity.py

---
 .../ensemble_fit/abstract_ensemble_fit.py     |  1 +
 .../visualizer_non_stationary_ensemble.py     | 28 ++++++++--
 .../visualizer_for_projection_ensemble.py     |  2 +
 .../visualizer_for_sensitivity.py             | 51 +++++++++++--------
 .../main_projections_ensemble.py              |  7 ++-
 .../main_sensitivity.py                       |  7 +--
 .../coordinates/abstract_coordinates.py       |  5 +-
 7 files changed, 68 insertions(+), 33 deletions(-)

diff --git a/extreme_trend/ensemble_fit/abstract_ensemble_fit.py b/extreme_trend/ensemble_fit/abstract_ensemble_fit.py
index 5d352ea8..115b0700 100644
--- a/extreme_trend/ensemble_fit/abstract_ensemble_fit.py
+++ b/extreme_trend/ensemble_fit/abstract_ensemble_fit.py
@@ -8,6 +8,7 @@ from extreme_trend.one_fold_fit.one_fold_fit import OneFoldFit
 class AbstractEnsembleFit(object):
     Median_merge = 'Median'
     Mean_merge = 'Mean'
+    Together_merge = 'Together'
 
     def __init__(self, massif_names, gcm_rcm_couple_to_altitude_studies: Dict[Tuple[str, str], AltitudesStudies],
                  models_classes,
diff --git a/extreme_trend/ensemble_fit/together_ensemble_fit/visualizer_non_stationary_ensemble.py b/extreme_trend/ensemble_fit/together_ensemble_fit/visualizer_non_stationary_ensemble.py
index 8781b583..4d8b1999 100644
--- a/extreme_trend/ensemble_fit/together_ensemble_fit/visualizer_non_stationary_ensemble.py
+++ b/extreme_trend/ensemble_fit/together_ensemble_fit/visualizer_non_stationary_ensemble.py
@@ -1,4 +1,5 @@
 from typing import List, Dict, Tuple
+import pandas as pd
 
 from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy
 from extreme_data.meteo_france_data.scm_models_data.altitudes_studies import AltitudesStudies
@@ -9,6 +10,10 @@ from extreme_trend.one_fold_fit.altitudes_studies_visualizer_for_non_stationary_
     AltitudesStudiesVisualizerForNonStationaryModels
 from extreme_trend.trend_test.visualizers.study_visualizer_for_non_stationary_trends import \
     StudyVisualizerForNonStationaryTrends
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
+from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
+from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_temporal_observations import \
+    AbstractSpatioTemporalObservations
 
 
 class VisualizerNonStationaryEnsemble(AltitudesStudiesVisualizerForNonStationaryModels):
@@ -19,10 +24,23 @@ class VisualizerNonStationaryEnsemble(AltitudesStudiesVisualizerForNonStationary
         super().__init__(studies, *args, **kwargs)
 
     def get_massif_altitudes(self, massif_name):
-        # return self._get_massif_altitudes(massif_name, self.studies)
-        raise NotImplementedError
+        altitudes_before_intersection = []
+        for studies in self.gcm_rcm_couple_to_studies.values():
+            massif_altitudes = self._get_massif_altitudes(massif_name, studies)
+            altitudes_before_intersection.append(set(massif_altitudes))
+        altitudes_after_intersection = altitudes_before_intersection[0].intersection(*altitudes_before_intersection[1:])
+        altitudes_after_intersection = sorted(list(altitudes_after_intersection))
+        return altitudes_after_intersection
 
     def get_dataset(self, massif_altitudes, massif_name):
-        raise NotImplementedError
-        # dataset = self.studies.spatio_temporal_dataset(massif_name=massif_name, massif_altitudes=massif_altitudes)
-        # return dataset
+        df_coordinates_list = []
+        df_maxima_gev_list = []
+        for studies in self.gcm_rcm_couple_to_studies.values():
+            dataset = studies.spatio_temporal_dataset(massif_name=massif_name, massif_altitudes=massif_altitudes)
+            df_coordinates_list.append(dataset.coordinates.df_coordinates(add_climate_informations=True))
+            df_maxima_gev_list.append(dataset.observations.df_maxima_gev)
+        observations = AbstractSpatioTemporalObservations(df_maxima_gev=pd.concat(df_maxima_gev_list, axis=0))
+        coordinates = AbstractCoordinates(df=pd.concat(df_coordinates_list, axis=0),
+                                          slicer_class=type(dataset.slicer))
+        dataset = AbstractDataset(observations=observations, coordinates=coordinates)
+        return dataset
diff --git a/extreme_trend/ensemble_fit/visualizer_for_projection_ensemble.py b/extreme_trend/ensemble_fit/visualizer_for_projection_ensemble.py
index 665edb89..42b38f05 100644
--- a/extreme_trend/ensemble_fit/visualizer_for_projection_ensemble.py
+++ b/extreme_trend/ensemble_fit/visualizer_for_projection_ensemble.py
@@ -120,6 +120,8 @@ class VisualizerForProjectionEnsemble(object):
     def plot_together(self):
         visualizer_list = [together_ensemble_fit.visualizer
                            for together_ensemble_fit in self.ensemble_fits(TogetherEnsembleFit)]
+        for v in visualizer_list:
+            v.studies.study.gcm_rcm_couple = ("together", "merge")
         self.plot_for_visualizer_list(visualizer_list)
 
     def ensemble_fits(self, ensemble_class):
diff --git a/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py b/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py
index eeb69783..bd696188 100644
--- a/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py
+++ b/extreme_trend/ensemble_fit/visualizer_for_sensitivity.py
@@ -10,6 +10,7 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
 from extreme_fit.model.margin_model.utils import MarginFitMethod
 from extreme_trend.ensemble_fit.abstract_ensemble_fit import AbstractEnsembleFit
 from extreme_trend.ensemble_fit.independent_ensemble_fit.independent_ensemble_fit import IndependentEnsembleFit
+from extreme_trend.ensemble_fit.together_ensemble_fit.together_ensemble_fit import TogetherEnsembleFit
 from extreme_trend.ensemble_fit.visualizer_for_projection_ensemble import VisualizerForProjectionEnsemble
 from extreme_trend.one_fold_fit.altitude_group import get_altitude_class_from_altitudes, \
     get_linestyle_for_altitude_class
@@ -31,16 +32,16 @@ class VisualizerForSensivity(object):
                  is_temperature_interval=False,
                  is_shift_interval=False,
                  ):
+        self.ensemble_fit_classes = ensemble_fit_classes
         self.is_shift_interval = is_shift_interval
         self.temporal_covariate_for_fit = temporal_covariate_for_fit
         self.is_temperature_interval = is_temperature_interval
-        self.merge_visualizer_str_list = (AbstractEnsembleFit.Median_merge, AbstractEnsembleFit.Mean_merge)
         self.altitudes_list = altitudes_list
         self.massif_names = massif_names
         self.left_limits, self.right_limits = get_interval_limits(self.is_temperature_interval,
-                                                                   self.is_shift_interval)
+                                                                  self.is_shift_interval)
         self.left_limit_to_right_limit = OrderedDict(zip(self.left_limits, self.right_limits))
-        self.right_limit_to_visualizer = {} # type: Dict[float, VisualizerForProjectionEnsemble]
+        self.right_limit_to_visualizer = {}  # type: Dict[float, VisualizerForProjectionEnsemble]
 
         for left_limit, right_limit in zip(self.left_limits, self.right_limits):
             print("Interval is", left_limit, right_limit)
@@ -52,9 +53,9 @@ class VisualizerForSensivity(object):
                                                                   self.is_temperature_interval)
                 if year_min_and_year_max[0] is not None:
                     gcm_to_year_min_and_year_max[gcm] = year_min_and_year_max
-                
+
             visualizer = VisualizerForProjectionEnsemble(
-                altitudes_list, gcm_rcm_couples, study_class, Season.annual, scenario,
+                altitudes_list, gcm_rcm_couples, study_class, season, scenario,
                 model_classes=model_classes,
                 fit_method=fit_method,
                 ensemble_fit_classes=ensemble_fit_classes,
@@ -73,7 +74,12 @@ class VisualizerForSensivity(object):
         # , and not just the plots for the last t_min
         # for visualizer in self.temp_min_to_visualizer.values():
         #     visualizer.plot()
-        for merge_visualizer_str in self.merge_visualizer_str_list:
+        merge_visualizer_str_list = []
+        if IndependentEnsembleFit in self.ensemble_fit_classes:
+            merge_visualizer_str_list.extend([AbstractEnsembleFit.Median_merge, AbstractEnsembleFit.Mean_merge])
+        if TogetherEnsembleFit in self.ensemble_fit_classes:
+            merge_visualizer_str_list.append(AbstractEnsembleFit.Together_merge)
+        for merge_visualizer_str in merge_visualizer_str_list:
             self.sensitivity_plot(merge_visualizer_str)
 
     def sensitivity_plot(self, merge_visualizer_str):
@@ -89,7 +95,7 @@ class VisualizerForSensivity(object):
         ax.set_xticklabels(ticks_labels)
         ax.legend(prop={'size': 7}, loc='upper center', ncol=2)
         ax.set_ylim((0, 122))
-        ax.set_yticks([i*10 for i in range(11)])
+        ax.set_yticks([i * 10 for i in range(11)])
         merge_visualizer = self.first_merge_visualizer(merge_visualizer_str)
         temp_cov = self.temporal_covariate_for_fit is AnomalyTemperatureWithSplineTemporalCovariate
         merge_visualizer.plot_name = 'Sensitivity plot with ' \
@@ -99,19 +105,6 @@ class VisualizerForSensivity(object):
         merge_visualizer.show_or_save_to_file(no_title=True)
         plt.close()
 
-    def first_merge_visualizer(self, merge_visualizer_str):
-        altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0])
-        visualizer_projection = list(self.right_limit_to_visualizer.values())[0]
-        return self.get_merge_visualizer(altitude_class, visualizer_projection, merge_visualizer_str)
-
-    def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble,
-                             merge_visualizer_str):
-        independent_ensemble_fit = visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][
-            IndependentEnsembleFit]
-        merge_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[merge_visualizer_str]
-        merge_visualizer.studies.study.gcm_rcm_couple = (merge_visualizer_str, "merge")
-        return merge_visualizer
-
     def interval_plot(self, ax, altitude_class, merge_visualizer_str):
         linestyle = get_linestyle_for_altitude_class(altitude_class)
         increasing_key = 'increasing'
@@ -137,3 +130,21 @@ class VisualizerForSensivity(object):
             color = label_to_color[label]
             ax.plot(self.right_limits, l, label=label_improved, color=color, linestyle=linestyle)
 
+    def get_merge_visualizer(self, altitude_class, visualizer_projection: VisualizerForProjectionEnsemble,
+                             merge_visualizer_str):
+        if merge_visualizer_str in [AbstractEnsembleFit.Median_merge, AbstractEnsembleFit.Mean_merge]:
+            independent_ensemble_fit = \
+            visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][
+                IndependentEnsembleFit]
+            merge_visualizer = independent_ensemble_fit.merge_function_name_to_visualizer[merge_visualizer_str]
+        else:
+            together_ensemble_fit = \
+            visualizer_projection.altitude_class_to_ensemble_class_to_ensemble_fit[altitude_class][TogetherEnsembleFit]
+            merge_visualizer = together_ensemble_fit.visualizer
+        merge_visualizer.studies.study.gcm_rcm_couple = (merge_visualizer_str, "merge")
+        return merge_visualizer
+
+    def first_merge_visualizer(self, merge_visualizer_str):
+        altitude_class = get_altitude_class_from_altitudes(self.altitudes_list[0])
+        visualizer_projection = list(self.right_limit_to_visualizer.values())[0]
+        return self.get_merge_visualizer(altitude_class, visualizer_projection, merge_visualizer_str)
diff --git a/projects/projected_extreme_snowfall/main_projections_ensemble.py b/projects/projected_extreme_snowfall/main_projections_ensemble.py
index 9cb4e279..3de38f49 100644
--- a/projects/projected_extreme_snowfall/main_projections_ensemble.py
+++ b/projects/projected_extreme_snowfall/main_projections_ensemble.py
@@ -45,7 +45,7 @@ def main():
     set_seed_for_test()
     AbstractExtractEurocodeReturnLevel.ALPHA_CONFIDENCE_INTERVAL_UNCERTAINTY = 0.2
 
-    fast = True
+    fast = False
     scenarios = rcp_scenarios[::-1] if fast is False else [AdamontScenario.rcp85]
 
     for scenario in scenarios:
@@ -57,7 +57,7 @@ def main():
             altitudes_list = altitudes_for_groups[3:]
         elif fast:
             massif_names = ['Vanoise', 'Haute-Maurienne']
-            gcm_rcm_couples = gcm_rcm_couples[4:6]
+            gcm_rcm_couples = gcm_rcm_couples[:]
             AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
             altitudes_list = altitudes_for_groups[:1]
         else:
@@ -73,7 +73,6 @@ def main():
 
         model_classes = ALTITUDINAL_GEV_MODELS_BASED_ON_POINTWISE_ANALYSIS
         assert scenario in rcp_scenarios
-        remove_physically_implausible_models = True
 
         visualizer = VisualizerForProjectionEnsemble(
             altitudes_list, gcm_rcm_couples, study_class, Season.annual, scenario,
@@ -81,7 +80,7 @@ def main():
             ensemble_fit_classes=ensemble_fit_classes,
             massif_names=massif_names,
             temporal_covariate_for_fit=temporal_covariate_for_fit,
-            remove_physically_implausible_models=remove_physically_implausible_models,
+            remove_physically_implausible_models=True,
             gcm_to_year_min_and_year_max=None,
         )
         visualizer.plot()
diff --git a/projects/projected_extreme_snowfall/main_sensitivity.py b/projects/projected_extreme_snowfall/main_sensitivity.py
index 8b88ffa7..39b2d243 100644
--- a/projects/projected_extreme_snowfall/main_sensitivity.py
+++ b/projects/projected_extreme_snowfall/main_sensitivity.py
@@ -4,6 +4,7 @@ from typing import List
 import matplotlib
 
 from extreme_trend.ensemble_fit.abstract_ensemble_fit import AbstractEnsembleFit
+from extreme_trend.ensemble_fit.together_ensemble_fit.together_ensemble_fit import TogetherEnsembleFit
 
 matplotlib.use('Agg')
 import matplotlib as mpl
@@ -40,7 +41,7 @@ from extreme_data.meteo_france_data.scm_models_data.utils import Season
 def main():
     start = time.time()
     study_class = AdamontSnowfall
-    ensemble_fit_classes = [IndependentEnsembleFit]
+    ensemble_fit_classes = [IndependentEnsembleFit, TogetherEnsembleFit][1:]
     temporal_covariate_for_fit = [TimeTemporalCovariate,
                                   AnomalyTemperatureWithSplineTemporalCovariate][0]
     set_seed_for_test()
@@ -53,9 +54,9 @@ def main():
         gcm_rcm_couples = get_gcm_rcm_couples(scenario)
         if fast is None:
             massif_names = None
-            gcm_rcm_couples = gcm_rcm_couples
+            gcm_rcm_couples = gcm_rcm_couples[4:6]
             AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
-            altitudes_list = altitudes_for_groups[1:2]
+            altitudes_list = altitudes_for_groups[3:]
         elif fast:
             massif_names = ['Vanoise', 'Haute-Maurienne']
             gcm_rcm_couples = gcm_rcm_couples[4:6]
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index d78bb441..83d077d6 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -154,11 +154,14 @@ class AbstractCoordinates(object):
 
     # Split
 
-    def df_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame:
+    def df_coordinates(self, split: Split = Split.all, transformed=True, add_climate_informations=False) -> pd.DataFrame:
         if transformed:
             df_transformed_coordinates = self.transformation.transform_df(self.df_all_coordinates)
         else:
             df_transformed_coordinates = self.df_all_coordinates
+        if add_climate_informations:
+            df_transformed_coordinates = pd.concat([df_transformed_coordinates,
+                                                    self.df_coordinate_climate_model], axis=1)
         return df_sliced(df=df_transformed_coordinates, split=split, slicer=self.slicer)
 
     def coordinates_values(self, split: Split = Split.all, transformed=True) -> np.ndarray:
-- 
GitLab