From d93bab182fa22aeafaf59178c5044e534b39dfd0 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Tue, 16 Mar 2021 09:42:41 +0100
Subject: [PATCH] [projection snowfall] improve weight computation for
 knutti_weight_solver.py

---
 .../scm_models_data/abstract_study.py         |  4 +
 .../weight_solver/abtract_weight_solver.py    | 47 ++++++-----
 .../projected_swe/weight_solver/indicator.py  | 37 +++++++--
 .../weight_solver/knutti_weight_solver.py     | 77 ++++++++++++++-----
 .../test_projected_swe/__init__.py            |  0
 .../test_projected_swe/test_model_as_truth.py | 48 ++++++++++++
 6 files changed, 169 insertions(+), 44 deletions(-)
 create mode 100644 test/test_projects/test_projected_swe/__init__.py
 create mode 100644 test/test_projects/test_projected_swe/test_model_as_truth.py

diff --git a/extreme_data/meteo_france_data/scm_models_data/abstract_study.py b/extreme_data/meteo_france_data/scm_models_data/abstract_study.py
index 3cd4d53b..09aaa8f5 100644
--- a/extreme_data/meteo_france_data/scm_models_data/abstract_study.py
+++ b/extreme_data/meteo_france_data/scm_models_data/abstract_study.py
@@ -873,6 +873,10 @@ class AbstractStudy(object):
             mask_french_alps += mask_massif
         return ~np.array(mask_french_alps, dtype=bool)
 
+    def massif_name_to_return_level(self, return_period):
+        return {massif_name: gev_params.return_level(return_period=return_period)
+                for massif_name, gev_params in self.massif_name_to_stationary_gev_params.items()}
+
     @cached_property
     def massif_name_to_stationary_gev_params(self):
         d, _ = self._massif_name_to_stationary_gev_params_and_confidence(quantile_level=None)
diff --git a/projects/projected_swe/weight_solver/abtract_weight_solver.py b/projects/projected_swe/weight_solver/abtract_weight_solver.py
index 70bc47e3..fdde441a 100644
--- a/projects/projected_swe/weight_solver/abtract_weight_solver.py
+++ b/projects/projected_swe/weight_solver/abtract_weight_solver.py
@@ -1,20 +1,39 @@
+from typing import Dict, Tuple
+
 from scipy.special import softmax
 import numpy as np
 
+from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy
 from projects.projected_swe.weight_solver.indicator import AbstractIndicator
 
 
 class AbstractWeightSolver(object):
 
-    def __init__(self, observation_study, couple_to_study, indicator_class: type, add_interdependence_weight=False):
+    def __init__(self, observation_study: AbstractStudy,
+                 couple_to_study: Dict[Tuple[str, str], AbstractStudy],
+                 indicator_class: type,
+                 massif_names=None,
+                 add_interdependence_weight=False):
         self.observation_study = observation_study
         self.couple_to_study = couple_to_study
         self.indicator_class = indicator_class
         self.add_interdependence_weight = add_interdependence_weight
+        # Compute intersection massif names
+        sets = [set(study.study_massif_names) for study in self.study_list]
+        intersection_massif_names = sets[0].intersection(*sets[1:])
+        if massif_names is None:
+            self.massif_names = list(intersection_massif_names)
+        else:
+            assert set(massif_names).issubset(intersection_massif_names)
+            self.massif_names = massif_names
+
+    @property
+    def study_list(self):
+        return [self.observation_study] + list(self.couple_to_study.values())
 
     @property
     def couple_to_weight(self):
-        nllh_list, couple_list = zip(*list(self.couple_to_nllh.items()))
+        couple_list, nllh_list = zip(*list(self.couple_to_nllh.items()))
         weights = softmax(-np.array(nllh_list))
         return dict(zip(couple_list, weights))
 
@@ -28,24 +47,16 @@ class AbstractWeightSolver(object):
 
     @property
     def couple_to_nllh_skill(self):
-        couple_to_nllh_skill = {}
-        for couple, couple_study in self.couple_to_study.items():
-            skill = self.compute_skill(couple_study=couple_study)
-            nllh_skill = -np.log(skill)
-            couple_to_nllh_skill[couple] = nllh_skill
-        return couple_to_nllh_skill
-
-    def compute_skill(self, couple_study):
+        return {couple: self.compute_skill_nllh(couple_study=couple_study)
+                for couple, couple_study in self.couple_to_study.items()}
+
+    def compute_skill_nllh(self, couple_study):
         raise NotImplementedError
 
     @property
     def couple_to_nllh_interdependence(self):
-        couple_to_nllh_interdependence = {}
-        for couple, couple_study in self.couple_to_study.items():
-            interdependence = self.compute_interdependence(couple_study=couple_study)
-            nllh_interdependence = -np.log(interdependence)
-            couple_to_nllh_interdependence[couple] = nllh_interdependence
-        return couple_to_nllh_interdependence
-
-    def compute_interdependence(self, couple_study):
+        return {couple: self.compute_interdependence_nllh(couple_study=couple_study)
+                for couple, couple_study in self.couple_to_study.items()}
+
+    def compute_interdependence_nllh(self, couple_study):
         raise NotImplementedError
diff --git a/projects/projected_swe/weight_solver/indicator.py b/projects/projected_swe/weight_solver/indicator.py
index 1507f8a2..70311f18 100644
--- a/projects/projected_swe/weight_solver/indicator.py
+++ b/projects/projected_swe/weight_solver/indicator.py
@@ -1,19 +1,44 @@
+from extreme_data.meteo_france_data.scm_models_data.abstract_study import AbstractStudy
+
+
+class WeightComputationException(Exception):
+    pass
+
+
+class ReturnLevelComputationException(WeightComputationException):
+    pass
+
+
+class NllhComputationException(WeightComputationException):
+    pass
+
+
 class AbstractIndicator(object):
 
     @classmethod
-    def get_indicator(cls, study, bootstrap=False):
+    def get_indicator(cls, study: AbstractStudy, massif_name, bootstrap=False):
         raise NotImplementedError
 
 
 class AnnualMaximaMeanIndicator(AbstractIndicator):
 
     @classmethod
-    def get_indicator(cls, study, bootstrap=False):
-        pass
+    def get_indicator(cls, study: AbstractStudy, massif_name, bootstrap=False):
+        if bootstrap:
+            raise NotImplementedError
+        else:
+            return study.massif_name_to_annual_maxima[massif_name].mean()
 
 
-class ReturnLevelIndicator(AbstractIndicator):
+class ReturnLevel30YearsIndicator(AbstractIndicator):
 
     @classmethod
-    def get_indicator(cls, study, bootstrap=False):
-        pass
+    def get_indicator(cls, study: AbstractStudy, massif_name, bootstrap=False):
+        if bootstrap:
+            print(study.massif_name_to_return_level_list(return_period=30)[massif_name])
+            raise NotImplementedError
+        else:
+            try:
+                return study.massif_name_to_return_level(return_period=30)[massif_name]
+            except KeyError:
+                raise ReturnLevelComputationException
diff --git a/projects/projected_swe/weight_solver/knutti_weight_solver.py b/projects/projected_swe/weight_solver/knutti_weight_solver.py
index b1b2a651..6c1674c2 100644
--- a/projects/projected_swe/weight_solver/knutti_weight_solver.py
+++ b/projects/projected_swe/weight_solver/knutti_weight_solver.py
@@ -1,6 +1,9 @@
 import numpy as np
+from scipy.stats import norm
+
 from projects.projected_swe.weight_solver.abtract_weight_solver import AbstractWeightSolver
-from projects.projected_swe.weight_solver.indicator import AbstractIndicator
+from projects.projected_swe.weight_solver.indicator import AbstractIndicator, NllhComputationException, \
+    WeightComputationException
 
 
 class KnuttiWeightSolver(AbstractWeightSolver):
@@ -9,41 +12,75 @@ class KnuttiWeightSolver(AbstractWeightSolver):
         super().__init__(*args, **kwargs)
         self.sigma_skill = sigma_skill
         self.sigma_interdependence = sigma_interdependence
+        # Compute the subset of massif_names used for the computation
+        self.massif_names_for_computation = []
+        for massif_name in self.massif_names:
+            try:
+                [self.compute_skill_one_massif(couple_study, massif_name) for couple_study in self.study_list]
+            except WeightComputationException:
+                continue
+            self.massif_names_for_computation.append(massif_name)
+        assert len(self.massif_names_for_computation) > 0, 'Sigma values should be increased'
+
+    @property
+    def nb_massifs_for_computation(self):
+        return len(self.massif_names)
+
+    def compute_skill_nllh(self, couple_study):
+        return sum([self.compute_skill_one_massif(couple_study, massif_name)
+                    for massif_name in self.massif_names_for_computation])
+
+    def compute_interdependence_nllh(self, couple_study):
+        return sum([self.compute_interdependence_nllh_one_massif(couple_study, massif_name)
+                    for massif_name in self.massif_names_for_computation])
 
-    def compute_skill(self, couple_study):
-        raise self.compute_distance_between_two_study(self.observation_study, self.couple_to_study, self.sigma_skill)
+    def compute_skill_one_massif(self, couple_study, massif_name):
+        return self.compute_nllh_from_two_study(self.observation_study, couple_study, self.sigma_skill, massif_name)
 
-    def compute_interdependence(self, couple_study):
-        sum = 0
+    def compute_interdependence_nllh_one_massif(self, couple_study, massif_name):
+        sum_proba = 0
         for other_couple_study in self.couple_to_study.values():
             if other_couple_study is not couple_study:
-                sum += self.compute_distance_between_two_study(couple_study, other_couple_study, self.sigma_interdependence)
-        return 1 / (1 + sum)
+                nllh = self.compute_nllh_from_two_study(couple_study, other_couple_study,
+                                                        self.sigma_interdependence, massif_name)
+                proba = np.exp(-nllh)
+                sum_proba += proba
+        proba = 1 / (1 + sum_proba)
+        nllh = -np.log(proba)
+        return nllh
 
-    def compute_distance_between_two_study(self, study_1, study_2, sigma):
-        difference = self.sum_of_differences(study_1, study_2)
-        return np.exp(-np.power(difference, 2 * sigma))
+    def compute_nllh_from_two_study(self, study_1, study_2, sigma, massif_name):
+        differences = self.sum_of_differences(study_1, study_2, massif_name)
+        scale = np.sqrt(np.power(sigma, 2) * self.nb_massifs_for_computation / 2)
+        proba = norm.pdf(differences, 0, scale)
+        if not(0 < proba <= 1):
+            raise NllhComputationException
+        nllh = -np.log(proba)
+        return nllh.sum()
 
-    def sum_of_differences(self, study_1, study_2):
+    def sum_of_differences(self, study_1, study_2, massif_name):
         assert issubclass(self.indicator_class, AbstractIndicator)
-        return self.indicator_class.get_indicator(study_1) - self.indicator_class.get_indicator(study_2)
+        return np.array([self.indicator_class.get_indicator(study_1, massif_name)
+                         - self.indicator_class.get_indicator(study_2, massif_name)])
 
 
 class KnuttiWeightSolverWithBootstrapVersion1(KnuttiWeightSolver):
 
-    def sum_of_differences(self, study_1, study_2):
+    def sum_of_differences(self, study_1, study_2, massif_name):
         assert issubclass(self.indicator_class, AbstractIndicator)
-        bootstrap_study_1 = self.indicator_class.get_indicator(study_1, bootstrap=True)
-        bootstrap_study_2 = self.indicator_class.get_indicator(study_2, bootstrap=True)
+        bootstrap_study_1 = self.indicator_class.get_indicator(study_1, massif_name, bootstrap=True)
+        bootstrap_study_2 = self.indicator_class.get_indicator(study_2, massif_name, bootstrap=True)
         differences = bootstrap_study_1 - bootstrap_study_2
-        return differences.sum()
+        squared_difference = np.power(differences, 2)
+        return squared_difference.sum()
 
 
 class KnuttiWeightSolverWithBootstrapVersion2(KnuttiWeightSolver):
 
-    def sum_of_differences(self, study_1, study_2):
+    def sum_of_differences(self, study_1, study_2, massif_name):
         assert issubclass(self.indicator_class, AbstractIndicator)
-        bootstrap_study_1 = self.indicator_class.get_indicator(study_1, bootstrap=True)
-        bootstrap_study_2 = self.indicator_class.get_indicator(study_2, bootstrap=True)
+        bootstrap_study_1 = self.indicator_class.get_indicator(study_1, massif_name, bootstrap=True)
+        bootstrap_study_2 = self.indicator_class.get_indicator(study_2, massif_name, bootstrap=True)
         differences = np.subtract.outer(bootstrap_study_1, bootstrap_study_2)
-        return differences.sum()
+        squared_difference = np.power(differences, 2)
+        return squared_difference.sum()
diff --git a/test/test_projects/test_projected_swe/__init__.py b/test/test_projects/test_projected_swe/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_projects/test_projected_swe/test_model_as_truth.py b/test/test_projects/test_projected_swe/test_model_as_truth.py
new file mode 100644
index 00000000..02aec491
--- /dev/null
+++ b/test/test_projects/test_projected_swe/test_model_as_truth.py
@@ -0,0 +1,48 @@
+import unittest
+import numpy as np
+
+from extreme_data.meteo_france_data.adamont_data.adamont.adamont_safran import AdamontSnowfall
+from extreme_data.meteo_france_data.adamont_data.adamont_scenario import AdamontScenario, get_gcm_rcm_couples
+from extreme_data.meteo_france_data.scm_models_data.safran.safran import SafranSnowfall1Day
+from extreme_data.meteo_france_data.scm_models_data.safran.safran_max_snowf import SafranSnowfall2020
+from extreme_fit.model.result_from_model_fit.result_from_extremes.abstract_extract_eurocode_return_level import \
+    AbstractExtractEurocodeReturnLevel
+from projects.projected_swe.weight_solver.indicator import AnnualMaximaMeanIndicator, ReturnLevel30YearsIndicator
+from projects.projected_swe.weight_solver.knutti_weight_solver import KnuttiWeightSolver, \
+    KnuttiWeightSolverWithBootstrapVersion1, KnuttiWeightSolverWithBootstrapVersion2
+
+
+class TestModelAsTruth(unittest.TestCase):
+
+    def test_knutti_weight_solver(self):
+        altitude = 900
+        year_min = 1982
+        year_max = 2011
+        scenario = AdamontScenario.rcp85_extended
+        observation_study = SafranSnowfall2020(altitude=altitude, year_min=year_min, year_max=year_max)
+        couple_to_study = {c: AdamontSnowfall(altitude=altitude, scenario=scenario,
+                                              year_min=year_min, year_max=year_max,
+                                              gcm_rcm_couple=c) for c in get_gcm_rcm_couples(adamont_scenario=scenario)}
+        massif_names = None
+        AbstractExtractEurocodeReturnLevel.NB_BOOTSTRAP = 10
+        for knutti_weight_solver_class in [KnuttiWeightSolver,
+                                           KnuttiWeightSolverWithBootstrapVersion1,
+                                           KnuttiWeightSolverWithBootstrapVersion2][:1]:
+            for indicator_class in [AnnualMaximaMeanIndicator, ReturnLevel30YearsIndicator][:1]:
+                for add_interdependence_weight in [False, True]:
+                    knutti_weight = knutti_weight_solver_class(sigma_skill=10.0, sigma_interdependence=10.0,
+                                                               massif_names=massif_names,
+                                                               observation_study=observation_study,
+                                                               couple_to_study=couple_to_study,
+                                                               indicator_class=indicator_class,
+                                                               add_interdependence_weight=add_interdependence_weight
+                                                               )
+                    # print(knutti_weight.couple_to_weight)
+                    weight = knutti_weight.couple_to_weight[('CNRM-CM5', 'CCLM4-8-17')]
+                    self.assertFalse(np.isnan(weight))
+
+        self.assertTrue(True)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab