Commit 2db1da7c authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[contrasting] implement aic and bic

parent 3914f93e
No related merge requests found
Showing with 18 additions and 8 deletions
+18 -8
...@@ -60,3 +60,11 @@ class LinearMarginEstimator(AbstractMarginEstimator): ...@@ -60,3 +60,11 @@ class LinearMarginEstimator(AbstractMarginEstimator):
nllh -= np.log(p) nllh -= np.log(p)
assert not np.isinf(nllh) assert not np.isinf(nllh)
return nllh return nllh
def aic(self, split=Split.all):
return 2 * self.margin_model.nb_params + 2 * self.nllh(split=split)
def bic(self, split=Split.all):
n = len(self.dataset.maxima_gev(split=split))
return np.log(n) * self.margin_model.nb_params + 2 * self.nllh(split=split)
...@@ -36,6 +36,10 @@ class LinearMarginModel(ParametricMarginModel): ...@@ -36,6 +36,10 @@ class LinearMarginModel(ParametricMarginModel):
param_name_and_dim_to_coef[(param_name, dim)] = default_slope param_name_and_dim_to_coef[(param_name, dim)] = default_slope
return param_name_and_dim_to_coef return param_name_and_dim_to_coef
@property
def nb_params(self):
return len(self.margin_function.coef_dict)
def param_name_to_linear_coef(self, param_name_and_dim_to_coef): def param_name_to_linear_coef(self, param_name_and_dim_to_coef):
param_name_to_linear_coef = {} param_name_to_linear_coef = {}
param_names = list(set([e[0] for e in param_name_and_dim_to_coef.keys()])) param_names = list(set([e[0] for e in param_name_and_dim_to_coef.keys()]))
......
...@@ -47,6 +47,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase): ...@@ -47,6 +47,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase):
for key in ref.keys(): for key in ref.keys():
self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3) self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh()) self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh())
self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
def test_gev_temporal_margin_fit_non_stationary_location(self): def test_gev_temporal_margin_fit_non_stationary_location(self):
# Create estimator # Create estimator
...@@ -58,6 +60,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase): ...@@ -58,6 +60,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase):
mle_params_estimated_year3 = estimator.function_from_fit.get_params(np.array([3])).to_dict() mle_params_estimated_year3 = estimator.function_from_fit.get_params(np.array([3])).to_dict()
self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3) self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh()) self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh())
self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
def test_gev_temporal_margin_fit_non_stationary_location_and_scale(self): def test_gev_temporal_margin_fit_non_stationary_location_and_scale(self):
# Create estimator # Create estimator
...@@ -70,14 +74,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase): ...@@ -70,14 +74,8 @@ class TestGevTemporalExtremesMle(unittest.TestCase):
mle_params_estimated_year3 = estimator.function_from_fit.get_params(np.array([3])).to_dict() mle_params_estimated_year3 = estimator.function_from_fit.get_params(np.array([3])).to_dict()
self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3) self.assertNotEqual(mle_params_estimated_year1, mle_params_estimated_year3)
self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh()) self.assertAlmostEqual(estimator.result_from_model_fit.nllh, estimator.nllh())
# self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic()) self.assertAlmostEqual(estimator.result_from_model_fit.aic, estimator.aic())
# self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic()) self.assertAlmostEqual(estimator.result_from_model_fit.bic, estimator.bic())
# print(estimator.result_from_model_fit.summary_name_to_value)
# for k, v in estimator.result_from_model_fit.results.items():
# print(k, np.array(v)[0])
self.assertAlmostEqual(estimator.result_from_model_fit.aic, 215.59675857481045)
self.assertAlmostEqual(estimator.result_from_model_fit.bic, 225.1568736019512)
if __name__ == '__main__': if __name__ == '__main__':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment