Commit 7833bcb4 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SCM] add visualization/comparison of multiple max_stable models for SCM

parent 07b7c7d1
No related merge requests found
Showing with 48 additions and 24 deletions
+48 -24
...@@ -20,11 +20,22 @@ def load_all_studies(study_class, only_first_one=False): ...@@ -20,11 +20,22 @@ def load_all_studies(study_class, only_first_one=False):
return all_studies return all_studies
if __name__ == '__main__': def extended_visualization():
for study_class in [ExtendedSafran, ExtendedCrocusSwe, ExtendedCrocusDepth][:]: for study_class in [ExtendedSafran, ExtendedCrocusSwe, ExtendedCrocusDepth][:1]:
for study in load_all_studies(study_class, only_first_one=True):
study_visualizer = StudyVisualizer(study)
study_visualizer.visualize_all_kde_graphs()
def normal_visualization():
for study_class in [Safran, CrocusSwe, CrocusDepth][:1]:
for study in load_all_studies(study_class, only_first_one=True): for study in load_all_studies(study_class, only_first_one=True):
study_visualizer = StudyVisualizer(study) study_visualizer = StudyVisualizer(study)
# study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0]) # study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
# study_visualizer.visualize_smooth_margin_fit() # study_visualizer.visualize_smooth_margin_fit()
study_visualizer.visualize_all_kde_graphs() study_visualizer.visualize_full_fit()
# study_visualizer.visualize_full_fit()
if __name__ == '__main__':
normal_visualization()
# extended_visualization()
\ No newline at end of file
...@@ -16,6 +16,8 @@ from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams ...@@ -16,6 +16,8 @@ from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams
from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit
from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
from test.test_utils import load_test_max_stable_models
from utils import get_display_name_from_object_type
class StudyVisualizer(object): class StudyVisualizer(object):
...@@ -89,14 +91,14 @@ class StudyVisualizer(object): ...@@ -89,14 +91,14 @@ class StudyVisualizer(object):
assert len(x) == len(y) assert len(x) == len(y)
return x, y return x, y
def fit_and_visualize_estimator(self, estimator, axes=None, show=True, title=None):
def fit_and_visualize_estimator(self, estimator):
estimator.fit() estimator.fit()
axes = estimator.margin_function_fitted.visualize(show=False) axes = estimator.margin_function_fitted.visualize_function(show=False, axes=axes, title=title)
for ax in axes: for ax in axes:
self.study.visualize(ax, fill=False, show=False) self.study.visualize(ax, fill=False, show=False)
plt.show() if show:
plt.suptitle(self.study.title)
plt.show()
def visualize_smooth_margin_fit(self): def visualize_smooth_margin_fit(self):
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates) margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
...@@ -104,10 +106,19 @@ class StudyVisualizer(object): ...@@ -104,10 +106,19 @@ class StudyVisualizer(object):
self.fit_and_visualize_estimator(estimator) self.fit_and_visualize_estimator(estimator)
def visualize_full_fit(self): def visualize_full_fit(self):
max_stable_model = Smith() max_stable_models = load_test_max_stable_models()
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates) fig, axes = plt.subplots(len(max_stable_models), len(GevParams.SUMMARY_NAMES))
estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model) fig.subplots_adjust(hspace=1.0, wspace=1.0)
self.fit_and_visualize_estimator(estimator) for i, max_stable_model in enumerate(max_stable_models):
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
title = get_display_name_from_object_type(type(max_stable_model))
print(title)
self.fit_and_visualize_estimator(estimator, axes[i], show=False, title=title)
title = self.study.title
title += '\nMethod: Full Likelihood with Linear marginals and max stable dependency structure'
plt.suptitle(title)
plt.show()
def visualize_independent_margin_fits(self, threshold=None, axes=None): def visualize_independent_margin_fits(self, threshold=None, axes=None):
if threshold is None: if threshold is None:
...@@ -150,8 +161,9 @@ class StudyVisualizer(object): ...@@ -150,8 +161,9 @@ class StudyVisualizer(object):
@property @property
def df_gev_mle_each_massif(self): def df_gev_mle_each_massif(self):
# Fit a margin_fits on each massif # Fit a margin_fits on each massif
massif_to_gev_mle = {massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie massif_to_gev_mle = {
for massif_name in self.study.safran_massif_names} massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
for massif_name in self.study.safran_massif_names}
return pd.DataFrame(massif_to_gev_mle, columns=self.study.safran_massif_names) return pd.DataFrame(massif_to_gev_mle, columns=self.study.safran_massif_names)
def df_gpd_mle_each_massif(self, threshold): def df_gpd_mle_each_massif(self, threshold):
......
...@@ -50,14 +50,14 @@ for i in range(nb_estimator): ...@@ -50,14 +50,14 @@ for i in range(nb_estimator):
plt.show() plt.show()
margin_function_sample = dataset.margin_model.margin_function_sample # type: LinearMarginFunction margin_function_sample = dataset.margin_model.margin_function_sample # type: LinearMarginFunction
margin_function_sample.visualize(show=False, axes=axes, dot_display=True) margin_function_sample.visualize_function(show=False, axes=axes, dot_display=True)
axes = margin_function_sample.visualization_axes axes = margin_function_sample.visualization_axes
# Estimation part # Estimation part
margin_model_for_estimator = margin_model_for_estimator_class(coordinates) margin_model_for_estimator = margin_model_for_estimator_class(coordinates)
full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model) full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model)
full_estimator.fit() full_estimator.fit()
full_estimator.margin_function_fitted.visualize(axes=axes, show=False) full_estimator.margin_function_fitted.visualize_function(axes=axes, show=False)
plt.show() plt.show()
# Display all the margin on the same graph for comparison # Display all the margin on the same graph for comparison
......
...@@ -57,7 +57,7 @@ class AbstractMarginFunction(object): ...@@ -57,7 +57,7 @@ class AbstractMarginFunction(object):
self.filter = filter self.filter = filter
self.color = color self.color = color
def visualize(self, axes=None, show=True, dot_display=False): def visualize_function(self, axes=None, show=True, dot_display=False, title=None):
self.datapoint_display = dot_display self.datapoint_display = dot_display
if axes is None: if axes is None:
fig, axes = plt.subplots(1, len(GevParams.SUMMARY_NAMES)) fig, axes = plt.subplots(1, len(GevParams.SUMMARY_NAMES))
...@@ -66,7 +66,7 @@ class AbstractMarginFunction(object): ...@@ -66,7 +66,7 @@ class AbstractMarginFunction(object):
for i, gev_value_name in enumerate(GevParams.SUMMARY_NAMES): for i, gev_value_name in enumerate(GevParams.SUMMARY_NAMES):
ax = axes[i] ax = axes[i]
self.visualize_single_param(gev_value_name, ax, show=False) self.visualize_single_param(gev_value_name, ax, show=False)
title_str = gev_value_name title_str = gev_value_name if title is None else title
ax.set_title(title_str) ax.set_title(title_str)
if show: if show:
plt.show() plt.show()
......
...@@ -26,8 +26,8 @@ class TestSmoothMarginEstimator(unittest.TestCase): ...@@ -26,8 +26,8 @@ class TestSmoothMarginEstimator(unittest.TestCase):
estimator.fit() estimator.fit()
# Plot # Plot
if self.DISPLAY: if self.DISPLAY:
margin_model.margin_function_sample.visualize(show=True) margin_model.margin_function_sample.visualize_function(show=True)
estimator.margin_function_fitted.visualize(show=True) estimator.margin_function_fitted.visualize_function(show=True)
self.assertTrue(True) self.assertTrue(True)
......
...@@ -16,12 +16,12 @@ class VisualizationMarginModel(unittest.TestCase): ...@@ -16,12 +16,12 @@ class VisualizationMarginModel(unittest.TestCase):
spatial_coordinates = CircleSpatialCoordinates.from_nb_points(nb_points=self.nb_points) spatial_coordinates = CircleSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
margin_model = self.margin_model(coordinates=spatial_coordinates) margin_model = self.margin_model(coordinates=spatial_coordinates)
if self.DISPLAY: if self.DISPLAY:
margin_model.margin_function_sample.visualize() margin_model.margin_function_sample.visualize_function()
def test_example_visualization_1D(self): def test_example_visualization_1D(self):
coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points) coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
margin_model = self.margin_model(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02}) margin_model = self.margin_model(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
margin_model.margin_function_sample.visualize(show=self.DISPLAY) margin_model.margin_function_sample.visualize_function(show=self.DISPLAY)
self.assertTrue(True) self.assertTrue(True)
......
...@@ -3,6 +3,7 @@ import os.path as op ...@@ -3,6 +3,7 @@ import os.path as op
VERSION = datetime.datetime.now() VERSION = datetime.datetime.now()
def get_root_path() -> str: def get_root_path() -> str:
return op.dirname(op.abspath(__file__)) return op.dirname(op.abspath(__file__))
...@@ -13,7 +14,7 @@ def get_full_path(relative_path: str) -> str: ...@@ -13,7 +14,7 @@ def get_full_path(relative_path: str) -> str:
def get_display_name_from_object_type(object_type): def get_display_name_from_object_type(object_type):
# assert isinstance(object_type, type), object_type # assert isinstance(object_type, type), object_type
return str(object_type).split('.')[-1].split("'")[0] return str(object_type).split('.')[-1].split("'")[0].split(' ')[0]
def first(s): def first(s):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment