From b37ea204e1fbc19ceaa20d438c91e5a127c39706 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 18 Mar 2021 11:43:07 +0100
Subject: [PATCH] [projection swe] add crpss and train train test. improve the
 plot. fix bug for selection of plausible models.

---
 .../scm_models_data/utils_function.py         | 13 +++-
 .../main_model_as_truth.py                    | 73 +++++++++++--------
 .../model_as_truth.py                         | 61 +++++++++++-----
 .../weight_solver/abtract_weight_solver.py    | 18 +++--
 .../projected_swe/weight_solver/indicator.py  | 12 +++
 .../weight_solver/knutti_weight_solver.py     | 10 +--
 .../test_projected_swe/test_model_as_truth.py |  8 +-
 7 files changed, 130 insertions(+), 65 deletions(-)

diff --git a/extreme_data/meteo_france_data/scm_models_data/utils_function.py b/extreme_data/meteo_france_data/scm_models_data/utils_function.py
index d6a93399..ce25c874 100644
--- a/extreme_data/meteo_france_data/scm_models_data/utils_function.py
+++ b/extreme_data/meteo_france_data/scm_models_data/utils_function.py
@@ -70,19 +70,24 @@ class ReturnLevelBootstrap(object):
 
     def compute_all_return_level(self):
         idxs = list(range(self.nb_bootstrap))
+        multiprocess = self.multiprocess  # local copy: self.multiprocess itself is never modified here
+        if AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP <= 10:
+            multiprocess = False  # 10 or fewer bootstrap samples: force the sequential branch, whatever self.multiprocess says
 
-        if self.multiprocess is None:
+        if multiprocess is None:  # None selects the batched pool branch below
 
             with Pool(NB_CORES) as p:
                 batchsize = math.ceil(AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP / NB_CORES)
                 list_return_level = p.map(self.compute_return_level_batch, batch(idxs, batchsize=batchsize))
                 return_level_list = list(chain.from_iterable(list_return_level))
 
-        elif self.multiprocess:
+        elif multiprocess:  # truthy: one pool task per bootstrap index
             f = self.compute_return_level_physically_plausible if self.only_physically_plausible_fits else self.compute_return_level
             with Pool(NB_CORES) as p:
-                return_level_list = p.map(self.compute_return_level, idxs)
+                return_level_list = p.map(f, idxs)  # f was chosen above from only_physically_plausible_fits
         else:
-            return_level_list = [self.compute_return_level(idx) for idx in idxs]
+            f = self.compute_return_level_physically_plausible if self.only_physically_plausible_fits else self.compute_return_level
+            return_level_list = [f(idx) for idx in idxs]  # sequential evaluation, same function choice as the pool branch
 
         return return_level_list
 
diff --git a/projects/projected_swe/model_as_truth_visualizer/main_model_as_truth.py b/projects/projected_swe/model_as_truth_visualizer/main_model_as_truth.py
index 86688952..b11d47a3 100644
--- a/projects/projected_swe/model_as_truth_visualizer/main_model_as_truth.py
+++ b/projects/projected_swe/model_as_truth_visualizer/main_model_as_truth.py
@@ -1,6 +1,7 @@
 from extreme_data.meteo_france_data.adamont_data.adamont.adamont_safran import AdamontSnowfall
 from extreme_data.meteo_france_data.adamont_data.adamont_scenario import AdamontScenario, get_gcm_rcm_couples
 from extreme_data.meteo_france_data.scm_models_data.safran.safran_max_snowf import SafranSnowfall2020
+from extreme_data.meteo_france_data.scm_models_data.utils_function import ReturnLevelBootstrap
 from extreme_fit.model.result_from_model_fit.result_from_extremes.abstract_extract_eurocode_return_level import \
     AbstractExtractEurocodeReturnLevel
 from projects.projected_swe.model_as_truth_visualizer.model_as_truth import ModelAsTruth
@@ -12,56 +13,64 @@ from projects.projected_swe.weight_solver.knutti_weight_solver_with_bootstrap im
 
 
 def main():
-    altitude = 900
+    # Set some parameters for the bootstrap
+    ReturnLevelBootstrap.only_physically_plausible_fits = True  # class-level flag, set once here for the whole run
     year_min_histo = 1982
     year_max_histo = 2011
-    year_min_projected = 2070
-    year_max_projected = 2099
     scenario = AdamontScenario.rcp85_extended
-    fast = None
+    fast = False  # picks one of the three configurations below: None, True, or False (the else branch)
     gcm_rcm_couples = get_gcm_rcm_couples(adamont_scenario=scenario)
+    indicator_class = AnnualMaximaMeanIndicator  # shared by all three configurations
 
     if fast is None:
         AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
+        year_couples = [(1982, 2011), (2012, 2041), (2042, 2071)]
+        altitudes = [900, 1800]
+
         massif_names = None
-        knutti_weight_solver_classes = [EqualWeight,
-                                        KnuttiWeightSolver,
+        knutti_weight_solver_classes = [KnuttiWeightSolver,
                                         KnuttiWeightSolverWithBootstrapVersion1,
-                                        KnuttiWeightSolverWithBootstrapVersion2]
-        indicator_class = ReturnLevel30YearsIndicator
-        gcm_rcm_couples = gcm_rcm_couples[:3]
-        sigma_list = [10, 100, 1000, 10000]
+                                        KnuttiWeightSolverWithBootstrapVersion2][:1]  # [:1] keeps only KnuttiWeightSolver
+        gcm_rcm_couples = gcm_rcm_couples[:8]
+        sigma_list = [6, 7, 8]
 
     elif fast:
+        altitudes = [900]
+        year_couples = [(1982, 2011)]  # only the historical period itself
         AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
-        massif_names = ['Vercors']
-
-        knutti_weight_solver_classes = [EqualWeight, KnuttiWeightSolver,
-                                        KnuttiWeightSolverWithBootstrapVersion1]
-        indicator_class = ReturnLevel30YearsIndicator
+        massif_names = ['Chartreuse']
+        knutti_weight_solver_classes = [KnuttiWeightSolver]
         gcm_rcm_couples = gcm_rcm_couples[:3]
-        sigma_list = [100, 1000]
+        sigma_list = [10]
 
     else:
+        altitudes = [900, 1800, 2700, 3600][:2]  # [:2] keeps 900 and 1800
+        year_couples = [(1982, 2011), (2012, 2041), (2042, 2071), (2070, 2099)][:-1]  # [:-1] drops (2070, 2099)
         massif_names = None
-        indicator_class = AnnualMaximaMeanIndicator
-        knutti_weight_solver_classes = [KnuttiWeightSolver,
-                                        KnuttiWeightSolverWithBootstrapVersion1,
-                                        KnuttiWeightSolverWithBootstrapVersion2]
+        AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
+        knutti_weight_solver_classes = [KnuttiWeightSolver][:]  # [:] is a full copy, nothing is dropped
+        gcm_rcm_couples = gcm_rcm_couples[:]  # full copy: every couple is used
+        sigma_list = [i + 1 for i in range(10)]  # sigma values 1 to 10
 
-    observation_study = SafranSnowfall2020(altitude=altitude, year_min=year_min_histo, year_max=year_max_histo)
-    couple_to_historical_study = {c: AdamontSnowfall(altitude=altitude, scenario=scenario,
-                                                     year_min=year_min_histo, year_max=year_max_histo,
-                                                     gcm_rcm_couple=c) for c in gcm_rcm_couples}
-    couple_to_projected_study = {c: AdamontSnowfall(altitude=altitude, scenario=scenario,
-                                                    year_min=year_min_projected, year_max=year_max_projected,
-                                                    gcm_rcm_couple=c) for c in gcm_rcm_couples
-                                 }
+    for altitude in altitudes:  # one ModelAsTruth run per (altitude, projected period) pair
+        for year_couple in year_couples:
+            year_min_projected, year_max_projected = year_couple
+            observation_study = SafranSnowfall2020(altitude=altitude, year_min=year_min_histo, year_max=year_max_histo)  # NOTE(review): arguments do not depend on year_couple; could presumably be built once per altitude
+            couple_to_historical_study = {c: AdamontSnowfall(altitude=altitude, scenario=scenario,
+                                                             year_min=year_min_histo, year_max=year_max_histo,
+                                                             gcm_rcm_couple=c) for c in gcm_rcm_couples}
+            if (year_min_projected, year_max_projected) == (year_min_histo, year_max_histo):
+                couple_to_projected_study = couple_to_historical_study  # same period: reuse the very same study objects
+            else:
+                couple_to_projected_study = {c: AdamontSnowfall(altitude=altitude, scenario=scenario,
+                                                                year_min=year_min_projected, year_max=year_max_projected,
+                                                                gcm_rcm_couple=c) for c in gcm_rcm_couples
+                                             }
 
-    model_as_truth = ModelAsTruth(observation_study, couple_to_projected_study, couple_to_historical_study,
-                                  indicator_class, knutti_weight_solver_classes, massif_names,
-                                  add_interdependence_weight=False)
-    model_as_truth.plot_against_sigma(sigma_list)
+            model_as_truth = ModelAsTruth(observation_study, couple_to_projected_study, couple_to_historical_study,
+                                          indicator_class, knutti_weight_solver_classes, massif_names,
+                                          add_interdependence_weight=False)
+            model_as_truth.plot_against_sigma(sigma_list)
 
 
 if __name__ == '__main__':
diff --git a/projects/projected_swe/model_as_truth_visualizer/model_as_truth.py b/projects/projected_swe/model_as_truth_visualizer/model_as_truth.py
index 4eb1267d..4acb56a6 100644
--- a/projects/projected_swe/model_as_truth_visualizer/model_as_truth.py
+++ b/projects/projected_swe/model_as_truth_visualizer/model_as_truth.py
@@ -5,10 +5,13 @@ from matplotlib.lines import Line2D
 from scipy.special import softmax
 import numpy as np
 
+from extreme_data.meteo_france_data.adamont_data.adamont_scenario import scenario_to_str
 from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy
+from extreme_data.meteo_france_data.scm_models_data.visualization.study_visualizer import StudyVisualizer
+from projects.projected_swe.old_weight_computer.utils import save_to_filepath
 from projects.projected_swe.weight_solver.abtract_weight_solver import AbstractWeightSolver
 from projects.projected_swe.weight_solver.default_weight_solver import EqualWeight
-from projects.projected_swe.weight_solver.indicator import AbstractIndicator
+from projects.projected_swe.weight_solver.indicator import AbstractIndicator, WeightComputationException
 from projects.projected_swe.weight_solver.knutti_weight_solver import KnuttiWeightSolver
 from projects.projected_swe.weight_solver.knutti_weight_solver_with_bootstrap import \
     KnuttiWeightSolverWithBootstrapVersion2, KnuttiWeightSolverWithBootstrapVersion1
@@ -61,6 +64,7 @@ class ModelAsTruth(object):
             assert len(x_list) == len(sigma_list)
             label = get_display_name_from_object_type(solver_class)
             color = self.solver_class_to_color[solver_class]
+            print(solver_class, score_list, np.array(score_list).mean(axis=1), np.median(np.array(score_list), axis=1))
             bplot = ax.boxplot(score_list, positions=x_list, widths=self.width, patch_artist=True, showmeans=True,
                                labels=[str(sigma) for sigma in sigma_list])
             for patch in bplot['boxes']:
@@ -71,10 +75,27 @@ class ModelAsTruth(object):
         custom_lines = [Line2D([0], [0], color=color, lw=4) for color in colors]
         ax.legend(custom_lines, labels, prop={'size': 8}, loc='upper left')
         ax.set_xlim(min(all_x) - self.width, max(all_x) + self.width)
-        _, max_y = ax.get_ylim()
-        ax.set_ylim((0, max_y * 1.1))
+        study_projected = list(self.couple_to_study_projected.values())[0]
+        title = 'crpss between a weighted forecast and an unweighted forecast \n' \
+                      'at {} m for {} of snowfall for {}-{} (%)'.format(self.observation_study.altitude,
+                                                                        self.indicator_class.str_indicator(),
+                                                                        study_projected.year_min,
+                                                                        study_projected.year_max)
+        ax2 = ax.twiny()
+        ax2.set_xlabel('{} for {} GCM/RCM couples'.format(scenario_to_str(study_projected.scenario), len(self.couple_to_study_projected)))
+        ax.set_xlabel('sigma skill parameter')
+        ax.set_ylabel(title)
 
-        plt.show()
+        # Plot a zero horizontal line
+        lim_left, lim_right = ax.get_xlim()
+        ax.hlines(0, xmin=lim_left, xmax=lim_right, linestyles='dashed')
+
+        # Save or show file
+        visualizer = StudyVisualizer(self.observation_study, show=False, save_to_file=True)
+        visualizer.plot_name = title.split('\n')[1]
+        visualizer.show_or_save_to_file(no_title=True)
+
+        plt.close()
 
     def get_x_list(self, j, sigma_list):
         shift = len(self.knutti_weight_solver_classes) + 1
@@ -95,19 +116,25 @@ class ModelAsTruth(object):
                                           c != gcm_rcm_couple}
             couple_to_study_projected = {c: s for c, s in self.couple_to_study_projected.items() if c != gcm_rcm_couple}
 
-            if issubclass(solver_class, KnuttiWeightSolver):
-                weight_solver = solver_class(sigma, None, historical_observation_study, couple_to_study_historical,
-                                             self.indicator_class, self.massif_names, self.add_interdependence_weight,
-                                             )  # type: AbstractWeightSolver
-            else:
-                weight_solver = solver_class(historical_observation_study, couple_to_study_historical,
-                                             self.indicator_class, self.massif_names, self.add_interdependence_weight,
-                                             )  # type: AbstractWeightSolver
-
-            print(solver_class, sigma, weight_solver.couple_to_weight.values())
-            mean_score = weight_solver.mean_prediction_score(self.massif_names, couple_to_study_projected,
-                                                             projected_observation_study)
-            score_list.append(mean_score)
+            try:
+                if issubclass(solver_class, KnuttiWeightSolver):
+                    weight_solver = solver_class(sigma, None, historical_observation_study, couple_to_study_historical,
+                                                 self.indicator_class, self.massif_names, self.add_interdependence_weight,
+                                                 )  # type: AbstractWeightSolver
+                else:
+                    weight_solver = solver_class(historical_observation_study, couple_to_study_historical,
+                                                 self.indicator_class, self.massif_names, self.add_interdependence_weight,
+                                                 )  # type: AbstractWeightSolver
+
+                print(solver_class, sigma, weight_solver.couple_to_weight.values())
+                mean_score = weight_solver.mean_prediction_score(self.massif_names, couple_to_study_projected,
+                                                                 projected_observation_study)
+                print(mean_score)
+                if mean_score < 1e4:
+                    score_list.append(mean_score)
+            except WeightComputationException:
+                pass
+        # print(solver_class, sigma, score_list)
         return np.array(score_list)
 
     def get_massif_names_subset_from_study_list(self, study_list: List[AbstractStudy]):
diff --git a/projects/projected_swe/weight_solver/abtract_weight_solver.py b/projects/projected_swe/weight_solver/abtract_weight_solver.py
index 68844411..85909226 100644
--- a/projects/projected_swe/weight_solver/abtract_weight_solver.py
+++ b/projects/projected_swe/weight_solver/abtract_weight_solver.py
@@ -5,7 +5,8 @@ from scipy.special import softmax
 import numpy as np
 
 from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy
-from projects.projected_swe.weight_solver.indicator import AbstractIndicator, ReturnLevelComputationException
+from projects.projected_swe.weight_solver.indicator import AbstractIndicator, ReturnLevelComputationException, \
+    ReturnLevel30YearsIndicator
 
 
 class AbstractWeightSolver(object):
@@ -47,7 +48,12 @@ class AbstractWeightSolver(object):
             couples, ensemble = zip(*list(couple_to_projected_indicator.items()))
             couple_to_weight = self.couple_to_weight
             weights = [couple_to_weight[c] for c in couples]
-            return ps.crps_ensemble(target, ensemble, weights=weights)
+            crps_weighted = ps.crps_ensemble(target, ensemble, weights=weights)
+            nb_weights = len(weights)
+            weights_unweighted = [1 / nb_weights for _ in range(nb_weights)]
+            crps_unweighted = ps.crps_ensemble(target, ensemble, weights=weights_unweighted)
+            crpss = 100 * (crps_weighted - crps_unweighted) / crps_unweighted
+            return crpss
         except ReturnLevelComputationException:
             return np.nan
 
@@ -56,14 +62,14 @@ class AbstractWeightSolver(object):
                   massif_name in massif_names]
         scores_filtered = [s for s in scores if not np.isnan(s)]
         assert len(scores_filtered) > 0
-        nb_massif_names_removed = len(scores) - len(scores_filtered)
-        if nb_massif_names_removed > 0:
-            print('{} massifs removed'.format(nb_massif_names_removed))
         return np.mean(scores_filtered)
 
     def target(self, massif_name, projected_observation_study):
         assert issubclass(self.indicator_class, AbstractIndicator)
-        return self.indicator_class.get_indicator(projected_observation_study, massif_name, bootstrap=True).mean()
+        if self.indicator_class is ReturnLevel30YearsIndicator:  # identity check: a subclass of ReturnLevel30YearsIndicator takes the else branch
+            return self.indicator_class.get_indicator(projected_observation_study, massif_name, bootstrap=True).mean()
+        else:  # any other indicator: get_indicator is called without bootstrap and its result is returned as is
+            return self.indicator_class.get_indicator(projected_observation_study, massif_name)
 
     # Weight computation on the historical period
 
diff --git a/projects/projected_swe/weight_solver/indicator.py b/projects/projected_swe/weight_solver/indicator.py
index cc38a3ea..905a2079 100644
--- a/projects/projected_swe/weight_solver/indicator.py
+++ b/projects/projected_swe/weight_solver/indicator.py
@@ -19,6 +19,10 @@ class AbstractIndicator(object):
     def get_indicator(cls, study: AbstractStudy, massif_name, bootstrap=False):
         raise NotImplementedError
 
+    @classmethod
+    def str_indicator(cls):  # abstract hook: each indicator subclass returns its human-readable name as a string
+        raise NotImplementedError
+
 
 class AnnualMaximaMeanIndicator(AbstractIndicator):
 
@@ -29,6 +33,10 @@ class AnnualMaximaMeanIndicator(AbstractIndicator):
         else:
             return study.massif_name_to_annual_maxima[massif_name].mean()
 
+    @classmethod
+    def str_indicator(cls):  # human-readable name of this indicator
+        return 'Mean annual maxima'
+
 
 class ReturnLevel30YearsIndicator(AbstractIndicator):
 
@@ -41,3 +49,7 @@ class ReturnLevel30YearsIndicator(AbstractIndicator):
                 return study.massif_name_to_return_level(return_period=30)[massif_name]
         except KeyError:
             raise ReturnLevelComputationException
+
+    @classmethod
+    def str_indicator(cls, bootstrap=False):  # human-readable name; 'bootstrap' is unused and defaulted so the no-argument call declared by AbstractIndicator.str_indicator works
+        return '30-year return level'
diff --git a/projects/projected_swe/weight_solver/knutti_weight_solver.py b/projects/projected_swe/weight_solver/knutti_weight_solver.py
index 9a6ae225..c15e7094 100644
--- a/projects/projected_swe/weight_solver/knutti_weight_solver.py
+++ b/projects/projected_swe/weight_solver/knutti_weight_solver.py
@@ -15,8 +15,6 @@ class KnuttiWeightSolver(AbstractWeightSolver):
         self.sigma_interdependence = sigma_interdependence
         if self.add_interdependence_weight:
             assert self.sigma_interdependence is not None
-        # Set some parameters for the bootstrap
-        ReturnLevelBootstrap.only_physically_plausible_fits = True
         # Set some parameters to speed up results (by caching some results)
         study_list = [self.observation_study] + list(self.couple_to_historical_study.values())
         for study in study_list:
@@ -29,10 +27,12 @@ class KnuttiWeightSolver(AbstractWeightSolver):
                 [self.compute_skill_one_massif(couple_study, massif_name) for couple_study in self.study_list]
                 if self.add_interdependence_weight:
                     [self.compute_interdependence_nllh_one_massif(couple_study, massif_name) for couple_study in self.study_list]
-            except WeightComputationException:
+            except WeightComputationException as e:
                 continue
             self.massif_names_for_computation.append(massif_name)
-        assert len(self.massif_names_for_computation) > 0, 'Sigma values should be increased'
+        if len(self.massif_names_for_computation) == 0:
+            print('Sigma values should be increased')
+            raise WeightComputationException
 
     @property
     def nb_massifs_for_computation(self):
@@ -64,7 +64,7 @@ class KnuttiWeightSolver(AbstractWeightSolver):
     def compute_nllh_from_two_study(self, study_1, study_2, sigma, massif_name):
         differences = self.differences(study_1, study_2, massif_name)
         scale = np.sqrt(np.power(sigma, 2) * self.nb_massifs_for_computation / 2)
-        proba = norm.pdf(differences, 0, scale)
+        proba = norm.pdf(x=differences, loc=0, scale=scale)
         proba_positive = (proba > 0).all()
         proba_lower_than_one = (proba <= 1).all()
         if not (proba_positive and proba_lower_than_one):
diff --git a/test/test_projects/test_projected_swe/test_model_as_truth.py b/test/test_projects/test_projected_swe/test_model_as_truth.py
index 95f0a251..c19e45c2 100644
--- a/test/test_projects/test_projected_swe/test_model_as_truth.py
+++ b/test/test_projects/test_projected_swe/test_model_as_truth.py
@@ -5,8 +5,10 @@ from extreme_data.meteo_france_data.adamont_data.adamont.adamont_safran import A
 from extreme_data.meteo_france_data.adamont_data.adamont_scenario import AdamontScenario, get_gcm_rcm_couples
 from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
 from extreme_data.meteo_france_data.scm_models_data.safran.safran_max_snowf import SafranSnowfall2020
+from extreme_data.meteo_france_data.scm_models_data.utils_function import ReturnLevelBootstrap
 from extreme_fit.model.result_from_model_fit.result_from_extremes.abstract_extract_eurocode_return_level import \
     AbstractExtractEurocodeReturnLevel
+from extreme_fit.model.utils import set_seed_for_test
 from projects.projected_swe.weight_solver.indicator import AnnualMaximaMeanIndicator, ReturnLevel30YearsIndicator
 from projects.projected_swe.weight_solver.knutti_weight_solver import KnuttiWeightSolver
 from projects.projected_swe.weight_solver.knutti_weight_solver_with_bootstrap import \
@@ -16,6 +18,8 @@ from projects.projected_swe.weight_solver.knutti_weight_solver_with_bootstrap im
 class TestModelAsTruth(unittest.TestCase):
 
     def test_knutti_weight_solver(self):
+        set_seed_for_test()
+        ReturnLevelBootstrap.only_physically_plausible_fits = True
         altitude = 900
         year_min = 1982
         year_max = 2011
@@ -31,11 +35,13 @@ class TestModelAsTruth(unittest.TestCase):
                                            KnuttiWeightSolverWithBootstrapVersion2][:]:
             if knutti_weight_solver_class in [KnuttiWeightSolverWithBootstrapVersion1, KnuttiWeightSolverWithBootstrapVersion2]:
                 idx = 1
+                sigma = 1000
             else:
+                sigma = 10
                 idx = 0
             for indicator_class in [AnnualMaximaMeanIndicator, ReturnLevel30YearsIndicator][idx:]:
                 for add_interdependence_weight in [False, True]:
-                    knutti_weight = knutti_weight_solver_class(sigma_skill=100.0, sigma_interdependence=100.0,
+                    knutti_weight = knutti_weight_solver_class(sigma_skill=sigma, sigma_interdependence=sigma,
                                                                massif_names=massif_names,
                                                                observation_study=observation_study,
                                                                couple_to_historical_study=couple_to_study,
-- 
GitLab