Commit 8f64edae authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] refactor one fold fit, enable to get sorted estimators with other...

[refactor] refactor one fold fit, enable to get sorted estimators with other selection criterion without more computation
parent e1713e88
No related merge requests found
Showing with 25 additions and 16 deletions
+25 -16
...@@ -181,7 +181,24 @@ class OneFoldFit(object): ...@@ -181,7 +181,24 @@ class OneFoldFit(object):
# Minimizing the AIC and some properties # Minimizing the AIC and some properties
@cached_property @cached_property
def sorted_estimators(self): def sorted_estimators_with_aic(self):
return self._sorted_estimators_with_method_name(method_name='aic')
def method_name_to_best_estimator(self, method_names):
return {self._sorted_estimators_with_method_name(method_name) for method_name in method_names}
def _sorted_estimators_with_method_name(self, method_name):
estimators = self.estimators_quality_checked
try:
sorted_estimators = sorted([estimator for estimator in estimators],
key=lambda e: e.__getattribute__(method_name))
except AssertionError as e:
print('Error for:\n', self.massif_name, self.altitude_group)
raise e
return sorted_estimators
@cached_property
def estimators_quality_checked(self):
estimators = list(self.model_class_to_estimator.values()) estimators = list(self.model_class_to_estimator.values())
if self.remove_physically_implausible_models: if self.remove_physically_implausible_models:
# Remove wrong shape # Remove wrong shape
...@@ -211,20 +228,12 @@ class OneFoldFit(object): ...@@ -211,20 +228,12 @@ class OneFoldFit(object):
if len(estimators) == 0: if len(estimators) == 0:
print(self.massif_name, " has only implausible models") print(self.massif_name, " has only implausible models")
try:
sorted_estimators = sorted([estimator for estimator in estimators], key=lambda e: e.aic)
except AssertionError as e:
print('Error for')
print(self.massif_name, self.altitude_group)
raise
# Apply the goodness of fit # Apply the goodness of fit
if self.only_models_that_pass_goodness_of_fit_test: if self.only_models_that_pass_goodness_of_fit_test:
return [e for e in sorted_estimators if self.goodness_of_fit_test(e)] estimators = [e for e in estimators if self.goodness_of_fit_test(e)]
if not (self.remove_physically_implausible_models or self.only_models_that_pass_goodness_of_fit_test): if not (self.remove_physically_implausible_models or self.only_models_that_pass_goodness_of_fit_test):
assert len(sorted_estimators) == len(self.models_classes) assert len(estimators) == len(self.models_classes)
return sorted_estimators return estimators
def get_coordinate(self, altitude, year): def get_coordinate(self, altitude, year):
if isinstance(self.altitude_group, DefaultAltitudeGroup): if isinstance(self.altitude_group, DefaultAltitudeGroup):
...@@ -244,16 +253,16 @@ class OneFoldFit(object): ...@@ -244,16 +253,16 @@ class OneFoldFit(object):
@property @property
def has_at_least_one_valid_model(self): def has_at_least_one_valid_model(self):
return len(self.sorted_estimators) > 0 return len(self.sorted_estimators_with_aic) > 0
@property @property
def model_class_to_estimator_with_finite_aic(self): def model_class_to_estimator_with_finite_aic(self):
return {type(estimator.margin_model): estimator for estimator in self.sorted_estimators} return {type(estimator.margin_model): estimator for estimator in self.sorted_estimators_with_aic}
@property @property
def best_estimator(self): def best_estimator(self):
if self.has_at_least_one_valid_model: if self.has_at_least_one_valid_model:
best_estimator = self.sorted_estimators[0] best_estimator = self.sorted_estimators_with_aic[0]
return best_estimator return best_estimator
else: else:
raise ValueError('This object should not have been called because ' raise ValueError('This object should not have been called because '
...@@ -289,7 +298,7 @@ class OneFoldFit(object): ...@@ -289,7 +298,7 @@ class OneFoldFit(object):
@property @property
def model_names(self): def model_names(self):
return [e.margin_model.name_str for e in self.sorted_estimators] return [e.margin_model.name_str for e in self.sorted_estimators_with_aic]
@property @property
def best_name(self): def best_name(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment