Commit 9ff7812a authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SAFRAN] add map visualization on top of margin visualization

parent ba77edf4
No related merge requests found
Showing with 15 additions and 9 deletions
+15 -9
...@@ -108,7 +108,7 @@ class Safran(object): ...@@ -108,7 +108,7 @@ class Safran(object):
""" Visualization methods """ """ Visualization methods """
def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True): def visualize(self, ax=None, massif_name_to_fill_kwargs=None, show=True, fill=True):
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv')) df_massif = pd.read_csv(op.join(self.map_full_path, 'massifsalpes.csv'))
...@@ -120,9 +120,10 @@ class Safran(object): ...@@ -120,9 +120,10 @@ class Safran(object):
l = [coords for idx, *coords in coord_tuples if idx == coordinate_id] l = [coords for idx, *coords in coord_tuples if idx == coordinate_id]
l = list(zip(*l)) l = list(zip(*l))
ax.plot(*l, color='black') ax.plot(*l, color='black')
massif_name = self.coordinate_id_to_massif_name[coordinate_id] if fill:
fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {} massif_name = self.coordinate_id_to_massif_name[coordinate_id]
ax.fill(*l, **fill_kwargs) fill_kwargs = massif_name_to_fill_kwargs[massif_name] if massif_name_to_fill_kwargs is not None else {}
ax.fill(*l, **fill_kwargs)
ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates) ax.scatter(self.massifs_coordinates.x_coordinates, self.massifs_coordinates.y_coordinates)
ax.axis('off') ax.axis('off')
......
...@@ -38,18 +38,23 @@ class SafranVisualizer(object): ...@@ -38,18 +38,23 @@ class SafranVisualizer(object):
def dataset(self): def dataset(self):
return AbstractDataset(self.observations, self.coordinates) return AbstractDataset(self.observations, self.coordinates)
def fit_and_visualize_estimator(self, estimator):
estimator.fit()
axes = estimator.margin_function_fitted.visualize(show=False)
for ax in axes:
self.safran.visualize(ax, fill=False, show=False)
plt.show()
def visualize_smooth_margin_fit(self): def visualize_smooth_margin_fit(self):
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates) margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model) estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model)
estimator.fit() self.fit_and_visualize_estimator(estimator)
estimator.margin_function_fitted.visualize(show=self.show)
def visualize_full_fit(self): def visualize_full_fit(self):
max_stable_model = Smith() max_stable_model = Smith()
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates) margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model) estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
estimator.fit() self.fit_and_visualize_estimator(estimator)
estimator.margin_function_fitted.visualize(show=self.show)
def visualize_independent_margin_fits(self, threshold=None, axes=None): def visualize_independent_margin_fits(self, threshold=None, axes=None):
if threshold is None: if threshold is None:
...@@ -78,7 +83,6 @@ class SafranVisualizer(object): ...@@ -78,7 +83,6 @@ class SafranVisualizer(object):
self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False) self.safran.visualize(ax=ax, massif_name_to_fill_kwargs=massif_name_to_fill_kwargs, show=False)
if self.show: if self.show:
plt.show() plt.show()
......
...@@ -69,6 +69,7 @@ class AbstractMarginFunction(object): ...@@ -69,6 +69,7 @@ class AbstractMarginFunction(object):
ax.set_title(title_str) ax.set_title(title_str)
if show: if show:
plt.show() plt.show()
return axes
def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True): def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True):
assert gev_value_name in GevParams.SUMMARY_NAMES assert gev_value_name in GevParams.SUMMARY_NAMES
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment