From 2db1da7cf09d2361d63ad660f76c3d6273c716f2 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Sun, 19 Apr 2020 19:24:15 +0200
Subject: [PATCH] [contrasting] implement aic and bic

---
 .../margin_estimator/abstract_margin_estimator.py  |  8 ++++++++
 .../linear_margin_model/linear_margin_model.py     |  4 ++++
 .../test_gev/test_gev_temporal_extremes_mle.py     | 14 ++++++--------
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
index 0c140c6e..bf20379c 100644
--- a/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
+++ b/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
@@ -60,3 +60,11 @@ class LinearMarginEstimator(AbstractMarginEstimator):
             nllh -= np.log(p)
             assert not np.isinf(nllh)
         return nllh
+
+    def aic(self, split=Split.all):
+        return 2 * self.margin_model.nb_params + 2 * self.nllh(split=split)
+
+    def bic(self, split=Split.all):
+        n = len(self.dataset.maxima_gev(split=split))
+        return np.log(n) * self.margin_model.nb_params + 2 * self.nllh(split=split)
+
diff --git a/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py b/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py
index 07d6e8cf..538fa816 100644
--- a/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py
+++ b/extreme_fit/model/margin_model/linear_margin_model/linear_margin_model.py
@@ -36,6 +36,10 @@ class LinearMarginModel(ParametricMarginModel):
                 param_name_and_dim_to_coef[(param_name, dim)] = default_slope
         return param_name_and_dim_to_coef
 
+    @property
+    def nb_params(self):
+        return len(self.margin_function.coef_dict)
+
     def param_name_to_linear_coef(self, param_name_and_dim_to_coef):
         param_name_to_linear_coef = {}
         param_names = list(set([e[0] for e in param_name_and_dim_to_coef.keys()]))
diff --git a/test/test_extreme_fit/test_distribution/test_gev/test_gev_temporal_extremes_mle.py b/test/test_extreme_fit/test_distribution/test_gev/test_gev_temporal_extremes_mle.py
index 425d4110..0a2796f0 100644
--- a/test/test_extreme_fit/test_distribution/test_gev/test_gev_temporal_extremes_mle.py
+++ b/test/test_extreme_fit/test_distribution/test_gev/test_gev_temporal_extremes_mle.py
@@ -47,6 +47,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase):
             for key in ref.keys():
                 self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
             self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh())
+            self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
+            self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
 
     def test_gev_temporal_margin_fit_non_stationary_location(self):
         # Create estimator
@@ -58,6 +60,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase):
         mle_params_estimated_year3 = estimator.function_from_fit.get_params(np.array([3])).to_dict()
         self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
         self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh())
+        self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
+        self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
 
     def test_gev_temporal_margin_fit_non_stationary_location_and_scale(self):
         # Create estimator
@@ -70,14 +74,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase):
         mle_params_estimated_year3 = estimator.function_from_fit.get_params(np.array([3])).to_dict()
         self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
         self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh())
-        # self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
-        # self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
-        # print(estimator.result_from_model_fit.summary_name_to_value)
-        # for k, v in estimator.result_from_model_fit.results.items():
-        #     print(k, np.array(v)[0])
-        self.assertAlmostEqual(estimator.result_from_model_fit.aic, 215.59675857481045)
-        self.assertAlmostEqual(estimator.result_from_model_fit.bic, 225.1568736019512)
-
+        self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
+        self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
 
 
 if __name__ == '__main__':
-- 
GitLab