diff --git a/experiment/safran_study/safran.py b/experiment/safran_study/safran.py index 63185388e2ddd231eed05e60ee1b5fe6c5cb714b..374ebb460281eee0e9c87193a3487bef05ca0d59 100644 --- a/experiment/safran_study/safran.py +++ b/experiment/safran_study/safran.py @@ -108,7 +108,7 @@ class Safran(object): """ Visualization methods """ - def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True): + def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True, fill=True): if ax is None: ax = plt.gca() df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv')) @@ -120,9 +120,10 @@ class Safran(object): l = [coords for idx, *coords in coord_tuples if idx == coordinate_id] l = list(zip(*l)) ax.plot(*l, color='black') - massif_name = self.coordinate_id_to_massif_name[coordinate_id] - fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {} - ax.fill(*l, **fill_kwargs) + if fill: + massif_name = self.coordinate_id_to_massif_name[coordinate_id] + fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {} + ax.fill(*l, **fill_kwargs) ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates) ax.axis('off') diff --git a/experiment/safran_study/safran_visualizer.py b/experiment/safran_study/safran_visualizer.py index 31e4824e6cc5cebd42e364bbe3263fc0d0841e95..9abdbf52d2696040e4bc54f1ce1da71374e11e52 100644 --- a/experiment/safran_study/safran_visualizer.py +++ b/experiment/safran_study/safran_visualizer.py @@ -38,18 +38,23 @@ class SafranVisualizer(object): def dataset(self): return AbstractDataset(self.observations, self.coordinates) + def fit_and_visualize_estimator(self, estimator): + estimator.fit() + axes = estimator.margin_function_fitted.visualize(show=False) + for ax in axes: + self.safran.visualize(ax, fill=False, show=False) + plt.show() + def visualize_smooth_margin_fit(self): margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates) estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model) - estimator.fit() - estimator.margin_function_fitted.visualize(show=self.show) + self.fit_and_visualize_estimator(estimator) def visualize_full_fit(self): max_stable_model = Smith() margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates) estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model) - estimator.fit() - estimator.margin_function_fitted.visualize(show=self.show) + self.fit_and_visualize_estimator(estimator) def visualize_independent_margin_fits(self, threshold=None, axes=None): if threshold is None: @@ -78,7 +83,6 @@ class SafranVisualizer(object): self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False) - if self.show: plt.show() diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py index 8ca34294bc6cff51bf0b7d77afbbf93e5c63480f..14e0ca22037eae84145d42e02d8c93cbbc245029 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py @@ -69,6 +69,7 @@ class AbstractMarginFunction(object): ax.set_title(title_str) if show: plt.show() + return axes def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True): assert gev_value_name in GevParams.SUMMARY_NAMES