abstract_estimator.py 2.74 KiB
import time

from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
    AbstractMarginFunction
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset


class AbstractEstimator(object):
    DURATION = 'Average duration'
    MAE_ERROR = 'Mean Average Error'

    # For each estimator, we shall define:
    # - The loss
    # - The optimization method for each part of the process

    def __init__(self, dataset: AbstractDataset):
        self.dataset = dataset  # type: AbstractDataset
        self.additional_information = dict()
        self._params_fitted = None
        self._margin_function_fitted = None
        self._max_stable_model_fitted = None

    @classmethod
    def from_dataset(cls, dataset: AbstractDataset):
        # raise NotImplementedError('from_dataset class constructor has not been implemented for this class')
        pass

    def fit(self):
        ts = time.time()
        self._fit()
        te = time.time()
        self.additional_information[self.DURATION] = int((te - ts) * 1000)

    @property
    def params_fitted(self):
        assert self.is_fitted
        return self._params_fitted

    # @property
    # def max_stable_fitted(self) -> AbstractMarginFunction:
    #     assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted'
    #     return self._margin_function_fitted

    @property
    def margin_function_fitted(self) -> AbstractMarginFunction:
        assert self.is_fitted
        assert self._margin_function_fitted is not None, 'No margin function has been fitted'
        return self._margin_function_fitted

    def extract_fitted_models_from_fitted_params(self, margin_function_to_fit, full_params_fitted):
        coef_dict = {k: v for k, v in full_params_fitted.items() if 'Coeff' in k}
        self._margin_function_fitted = LinearMarginFunction.from_coef_dict(coordinates=self.dataset.coordinates,
                                                                           gev_param_name_to_linear_dims=margin_function_to_fit.gev_param_name_to_linear_dims,
                                                                           coef_dict=coef_dict)

    @property
    def is_fitted(self):
        return self._params_fitted is not None

    @property
    def train_split(self):
        return self.dataset.train_split

    def scalars(self, true_max_stable_params: dict):
        error = self._error(true_max_stable_params)
        return {**error, **self.additional_information}

    # Methods to override in the child class

    def _fit(self):
        pass

    def _error(self, true_max_stable_params: dict):
        pass