import unittest

import numpy as np

from extreme_estimator.R_fit.gev_fit.abstract_margin_model import ConstantMarginModel
from extreme_estimator.R_fit.gev_fit.gev_mle_fit import GevMleFit
from extreme_estimator.R_fit.utils import get_loaded_r
from extreme_estimator.estimator.margin_estimator import SmoothMarginEstimator
from spatio_temporal_dataset.dataset.simulation_dataset import MarginDataset
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1


class TestMarginEstimators(unittest.TestCase):
    """Tests for the GEV maximum-likelihood fit and for the margin estimators.

    ``test_unitary_mle_gev_fit`` checks ``GevMleFit`` against hard-coded
    reference parameters on a seeded R sample. ``test_dependency_estimators``
    is a smoke test that fits every estimator class on a dataset sampled from
    every margin model.
    """

    # When True, test_dependency_estimators prints debugging information.
    DISPLAY = False
    # Margin model classes; each one is instantiated without arguments.
    MARGIN_TYPES = [ConstantMarginModel]
    # Estimator classes fitted against each margin model.
    MARGIN_ESTIMATORS = [SmoothMarginEstimator]

    def setUp(self):
        """Build the spatial coordinates and margin models used by the tests."""
        super().setUp()
        self.spatial_coord = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
        self.margin_models = self.load_margin_models()

    @classmethod
    def load_margin_models(cls):
        """Return one default-constructed instance per class in ``MARGIN_TYPES``."""
        return [margin_class() for margin_class in cls.MARGIN_TYPES]

    def test_unitary_mle_gev_fit(self):
        """Fit a GEV by MLE on a seeded R sample and compare to reference values."""
        r = get_loaded_r()
        # Draw a reproducible (seeded) GEV sample of size 50 on the R side and
        # define the starting values handed to the optimizer.
        r("""
        set.seed(42)
        N <- 50
        loc = 0; scale = 1; shape <- 1
        x_gev <- rgev(N, loc = loc, scale = scale, shape = shape)
        start_loc = 0; start_scale = 1; start_shape = 1
        """)
        # Get the MLE estimator.
        # The builtin float is used instead of np.float: the latter was only an
        # alias of the builtin, deprecated in NumPy 1.20 and removed in 1.24.
        estimator = GevMleFit(x_gev=np.array(r['x_gev']),
                              start_loc=float(r['start_loc'][0]),
                              start_scale=float(r['start_scale'][0]),
                              start_shape=float(r['start_shape'][0]))
        # Compare the MLE estimated parameters to the reference
        # NOTE(review): the reference values are hard-coded; presumably they
        # come from an earlier trusted run with the same seed - confirm.
        mle_params_estimated = estimator.mle_params
        mle_params_ref = {'loc': 0.0219, 'scale': 1.0347, 'shape': 0.8290}
        for param_name, expected_value in mle_params_ref.items():
            # subTest lets every mismatching parameter be reported, instead of
            # stopping at the first failure.
            with self.subTest(param=param_name):
                self.assertAlmostEqual(expected_value, mle_params_estimated[param_name], places=3)

    def test_dependency_estimators(self):
        """Smoke test: every estimator must fit on data sampled from each margin model."""
        for margin_model in self.margin_models:
            dataset = MarginDataset.from_sampling(nb_obs=10, margin_model=margin_model,
                                                  spatial_coordinates=self.spatial_coord)

            for estimator_class in self.MARGIN_ESTIMATORS:
                estimator = estimator_class(dataset=dataset, margin_model=margin_model)
                estimator.fit()
                if self.DISPLAY:
                    print(type(margin_model))
                    print(dataset.df_dataset.head())
                    print(estimator.additional_information)
            # No value is checked here: reaching this line without an exception
            # from sampling or fitting is the success criterion.
            self.assertTrue(True)


# Run the tests of this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()