From ce5ac52b75635561df90f067efeb986775a3004c Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Wed, 18 Dec 2019 15:35:55 +0100
Subject: [PATCH] [Temporal fit] add gmle fit method in for
 temporal_linear_margin_model

---
 .../abstract_temporal_linear_margin_model.py  | 15 +++++++++----
 .../test_model/test_confidence_interval.py    | 21 ++++++++++++++++---
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
index 08700dfa..94011d95 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/abstract_temporal_linear_margin_model.py
@@ -18,6 +18,7 @@ class TemporalMarginFitMethod(Enum):
     is_mev_gev_fit = 0
     extremes_fevd_bayesian = 1
     extremes_fevd_mle = 2
+    extremes_fevd_gmle = 3
 
 
 class AbstractTemporalLinearMarginModel(LinearMarginModel):
@@ -42,8 +43,8 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel):
             return self.ismev_gev_fit(x, df_coordinates_temp)
         if self.fit_method == TemporalMarginFitMethod.extremes_fevd_bayesian:
             return self.extremes_fevd_bayesian_fit(x, df_coordinates_temp)
-        if self.fit_method == TemporalMarginFitMethod.extremes_fevd_mle:
-            return self.extremes_fevd_mle_fit(x, df_coordinates_temp)
+        if self.fit_method in [TemporalMarginFitMethod.extremes_fevd_mle, TemporalMarginFitMethod.extremes_fevd_gmle]:
+            return self.extremes_fevd_mle_related_fit(x, df_coordinates_temp)
 
     # Gev Fit with isMev package
 
@@ -56,13 +57,19 @@ class AbstractTemporalLinearMarginModel(LinearMarginModel):
 
     # Gev fit with extRemes package
 
-    def extremes_fevd_mle_fit(self, x, df_coordinates_temp) -> AbstractResultFromExtremes:
+    def extremes_fevd_mle_related_fit(self, x, df_coordinates_temp) -> AbstractResultFromExtremes:
         r_type_argument_kwargs, y = self.extreme_arguments(df_coordinates_temp)
+        if self.fit_method == TemporalMarginFitMethod.extremes_fevd_mle:
+            method = "MLE"
+        elif self.fit_method == TemporalMarginFitMethod.extremes_fevd_gmle:
+            method = "GMLE"
+        else:
+            raise ValueError('wrong method')
         res = safe_run_r_estimator(function=r('fevd_fixed'),
                                    x=x,
                                    data=y,
                                    type=self.type_for_mle,
-                                   method="MLE",
+                                   method=method,
                                    **r_type_argument_kwargs
                                    )
         return ResultFromMleExtremes(res, self.margin_function_start_fit.gev_param_name_to_dims,
diff --git a/test/test_extreme_fit/test_model/test_confidence_interval.py b/test/test_extreme_fit/test_model/test_confidence_interval.py
index 6922b113..0f52b44d 100644
--- a/test/test_extreme_fit/test_model/test_confidence_interval.py
+++ b/test/test_extreme_fit/test_model/test_confidence_interval.py
@@ -78,7 +78,7 @@ class TestConfidenceInterval(unittest.TestCase):
         self.ci_method = ConfidenceIntervalMethodFromExtremes.ci_bayes
         self.model_class_to_triplet = self.bayesian_ci
 
-    def test_ci_normal(self):
+    def test_ci_normal_mle(self):
         self.fit_method = TemporalMarginFitMethod.extremes_fevd_mle
         self.ci_method = ConfidenceIntervalMethodFromExtremes.ci_mle
         self.model_class_to_triplet = {
@@ -89,6 +89,17 @@ class TestConfidenceInterval(unittest.TestCase):
             NonStationaryLocationAndScaleGumbelModel: (6.0605675256893, 10.512751341145462, 14.964935156601623),
         }
 
+    def test_ci_normal_gmle(self):
+        self.fit_method = TemporalMarginFitMethod.extremes_fevd_gmle
+        self.ci_method = ConfidenceIntervalMethodFromExtremes.ci_mle
+        self.model_class_to_triplet = {
+            # Test only for the GEV cases (for the Gumbel cases results are just the same, since there is no shape parameter)
+            StationaryTemporalModel: (4.178088363735904, 15.27540259902303, 26.372716834310154),
+            NonStationaryLocationTemporalModel: (-6.716723409668982, 4.168288167650933, 15.053299744970847),
+            NonStationaryLocationAndScaleTemporalModel: (-12.226312466874123, 5.680769391219823, 23.58785124931377),
+        }
+
+
     def test_ci_boot(self):
         self.fit_method = TemporalMarginFitMethod.extremes_fevd_mle
         self.ci_method = ConfidenceIntervalMethodFromExtremes.ci_boot
@@ -113,7 +124,7 @@ class TestConfidenceInterval(unittest.TestCase):
             eurocode_ci = self.compute_eurocode_ci(model_class)
             found_triplet = eurocode_ci.triplet
             for a, b in zip(expected_triplet, found_triplet):
-                self.assertAlmostEqual(a, b, msg="{} \n{}".format(model_class, found_triplet))
+                self.assertAlmostEqual(a, b, msg="\n{} \nfound_triplet: {}".format(model_class, found_triplet))
 
 
 class TestConfidenceIntervalModifiedCoordinates(TestConfidenceInterval):
@@ -136,7 +147,11 @@ class TestConfidenceIntervalModifiedCoordinates(TestConfidenceInterval):
     def test_ci_bayes(self):
         super().test_ci_bayes()
 
-    def test_ci_normal(self):
+    def test_ci_normal_mle(self):
+        self.model_class_to_triplet = {}
+        self.assertTrue(True)
+
+    def test_ci_normal_gmle(self):
         self.model_class_to_triplet = {}
         self.assertTrue(True)
 
-- 
GitLab