diff --git a/extreme_trend/one_fold_fit/one_fold_fit.py b/extreme_trend/one_fold_fit/one_fold_fit.py index b061dd407db8ce462298135af694ba68c571615c..0865828a5ae75cf6b1cb85ab1a4d477691ba95ac 100644 --- a/extreme_trend/one_fold_fit/one_fold_fit.py +++ b/extreme_trend/one_fold_fit/one_fold_fit.py @@ -181,7 +181,24 @@ class OneFoldFit(object): # Minimizing the AIC and some properties @cached_property - def sorted_estimators(self): + def sorted_estimators_with_aic(self): + return self._sorted_estimators_with_method_name(method_name='aic') + + def method_name_to_best_estimator(self, method_names): + return {self._sorted_estimators_with_method_name(method_name) for method_name in method_names} + + def _sorted_estimators_with_method_name(self, method_name): + estimators = self.estimators_quality_checked + try: + sorted_estimators = sorted([estimator for estimator in estimators], + key=lambda e: e.__getattribute__(method_name)) + except AssertionError as e: + print('Error for:\n', self.massif_name, self.altitude_group) + raise e + return sorted_estimators + + @cached_property + def estimators_quality_checked(self): estimators = list(self.model_class_to_estimator.values()) if self.remove_physically_implausible_models: # Remove wrong shape @@ -211,20 +228,12 @@ class OneFoldFit(object): if len(estimators) == 0: print(self.massif_name, " has only implausible models") - - try: - sorted_estimators = sorted([estimator for estimator in estimators], key=lambda e: e.aic) - except AssertionError as e: - print('Error for') - print(self.massif_name, self.altitude_group) - raise - # Apply the goodness of fit if self.only_models_that_pass_goodness_of_fit_test: - return [e for e in sorted_estimators if self.goodness_of_fit_test(e)] + estimators = [e for e in estimators if self.goodness_of_fit_test(e)] if not (self.remove_physically_implausible_models or self.only_models_that_pass_goodness_of_fit_test): - assert len(sorted_estimators) == len(self.models_classes) - return sorted_estimators + assert len(estimators) == len(self.models_classes) + return estimators def get_coordinate(self, altitude, year): if isinstance(self.altitude_group, DefaultAltitudeGroup): @@ -244,16 +253,16 @@ class OneFoldFit(object): @property def has_at_least_one_valid_model(self): - return len(self.sorted_estimators) > 0 + return len(self.sorted_estimators_with_aic) > 0 @property def model_class_to_estimator_with_finite_aic(self): - return {type(estimator.margin_model): estimator for estimator in self.sorted_estimators} + return {type(estimator.margin_model): estimator for estimator in self.sorted_estimators_with_aic} @property def best_estimator(self): if self.has_at_least_one_valid_model: - best_estimator = self.sorted_estimators[0] + best_estimator = self.sorted_estimators_with_aic[0] return best_estimator else: raise ValueError('This object should not have been called because ' @@ -289,7 +298,7 @@ class OneFoldFit(object): @property def model_names(self): - return [e.margin_model.name_str for e in self.sorted_estimators] + return [e.margin_model.name_str for e in self.sorted_estimators_with_aic] @property def best_name(self):