Commit 7833bcb4 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SCM] add visualization/comparison of multiple max_stable models for SCM

parent 07b7c7d1
No related merge requests found
Showing with 48 additions and 24 deletions
+48 -24
......@@ -20,11 +20,22 @@ def load_all_studies(study_class, only_first_one=False):
return all_studies
if __name__ == '__main__':
for study_class in [ExtendedSafran, ExtendedCrocusSwe, ExtendedCrocusDepth][:]:
def extended_visualization():
for study_class in [ExtendedSafran, ExtendedCrocusSwe, ExtendedCrocusDepth][:1]:
for study in load_all_studies(study_class, only_first_one=True):
study_visualizer = StudyVisualizer(study)
study_visualizer.visualize_all_kde_graphs()
def normal_visualization():
for study_class in [Safran, CrocusSwe, CrocusDepth][:1]:
for study in load_all_studies(study_class, only_first_one=True):
study_visualizer = StudyVisualizer(study)
# study_visualizer.visualize_independent_margin_fits(threshold=[None, 20, 40, 60][0])
# study_visualizer.visualize_smooth_margin_fit()
study_visualizer.visualize_all_kde_graphs()
# study_visualizer.visualize_full_fit()
study_visualizer.visualize_full_fit()
if __name__ == '__main__':
normal_visualization()
# extended_visualization()
\ No newline at end of file
......@@ -16,6 +16,8 @@ from extreme_estimator.margin_fits.gpd.gpd_params import GpdParams
from extreme_estimator.margin_fits.gpd.gpdmle_fit import GpdMleFit
from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
from test.test_utils import load_test_max_stable_models
from utils import get_display_name_from_object_type
class StudyVisualizer(object):
......@@ -89,14 +91,14 @@ class StudyVisualizer(object):
assert len(x) == len(y)
return x, y
def fit_and_visualize_estimator(self, estimator):
def fit_and_visualize_estimator(self, estimator, axes=None, show=True, title=None):
estimator.fit()
axes = estimator.margin_function_fitted.visualize(show=False)
axes = estimator.margin_function_fitted.visualize_function(show=False, axes=axes, title=title)
for ax in axes:
self.study.visualize(ax, fill=False, show=False)
plt.show()
if show:
plt.suptitle(self.study.title)
plt.show()
def visualize_smooth_margin_fit(self):
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
......@@ -104,10 +106,19 @@ class StudyVisualizer(object):
self.fit_and_visualize_estimator(estimator)
def visualize_full_fit(self):
max_stable_model = Smith()
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
self.fit_and_visualize_estimator(estimator)
max_stable_models = load_test_max_stable_models()
fig, axes = plt.subplots(len(max_stable_models), len(GevParams.SUMMARY_NAMES))
fig.subplots_adjust(hspace=1.0, wspace=1.0)
for i, max_stable_model in enumerate(max_stable_models):
margin_model = LinearAllParametersAllDimsMarginModel(coordinates=self.coordinates)
estimator = FullEstimatorInASingleStepWithSmoothMargin(self.dataset, margin_model, max_stable_model)
title = get_display_name_from_object_type(type(max_stable_model))
print(title)
self.fit_and_visualize_estimator(estimator, axes[i], show=False, title=title)
title = self.study.title
title += '\nMethod: Full Likelihood with Linear marginals and max stable dependency structure'
plt.suptitle(title)
plt.show()
def visualize_independent_margin_fits(self, threshold=None, axes=None):
if threshold is None:
......@@ -150,8 +161,9 @@ class StudyVisualizer(object):
@property
def df_gev_mle_each_massif(self):
# Fit a margin_fits on each massif
massif_to_gev_mle = {massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
for massif_name in self.study.safran_massif_names}
massif_to_gev_mle = {
massif_name: GevMleFit(self.study.observations_annual_maxima.loc[massif_name]).gev_params.summary_serie
for massif_name in self.study.safran_massif_names}
return pd.DataFrame(massif_to_gev_mle, columns=self.study.safran_massif_names)
def df_gpd_mle_each_massif(self, threshold):
......
......@@ -50,14 +50,14 @@ for i in range(nb_estimator):
plt.show()
margin_function_sample = dataset.margin_model.margin_function_sample # type: LinearMarginFunction
margin_function_sample.visualize(show=False, axes=axes, dot_display=True)
margin_function_sample.visualize_function(show=False, axes=axes, dot_display=True)
axes = margin_function_sample.visualization_axes
# Estimation part
margin_model_for_estimator = margin_model_for_estimator_class(coordinates)
full_estimator = FullEstimatorInASingleStepWithSmoothMargin(dataset, margin_model_for_estimator, max_stable_model)
full_estimator.fit()
full_estimator.margin_function_fitted.visualize(axes=axes, show=False)
full_estimator.margin_function_fitted.visualize_function(axes=axes, show=False)
plt.show()
# Display all the margin on the same graph for comparison
......
......@@ -57,7 +57,7 @@ class AbstractMarginFunction(object):
self.filter = filter
self.color = color
def visualize(self, axes=None, show=True, dot_display=False):
def visualize_function(self, axes=None, show=True, dot_display=False, title=None):
self.datapoint_display = dot_display
if axes is None:
fig, axes = plt.subplots(1, len(GevParams.SUMMARY_NAMES))
......@@ -66,7 +66,7 @@ class AbstractMarginFunction(object):
for i, gev_value_name in enumerate(GevParams.SUMMARY_NAMES):
ax = axes[i]
self.visualize_single_param(gev_value_name, ax, show=False)
title_str = gev_value_name
title_str = gev_value_name if title is None else title
ax.set_title(title_str)
if show:
plt.show()
......
......@@ -26,8 +26,8 @@ class TestSmoothMarginEstimator(unittest.TestCase):
estimator.fit()
# Plot
if self.DISPLAY:
margin_model.margin_function_sample.visualize(show=True)
estimator.margin_function_fitted.visualize(show=True)
margin_model.margin_function_sample.visualize_function(show=True)
estimator.margin_function_fitted.visualize_function(show=True)
self.assertTrue(True)
......
......@@ -16,12 +16,12 @@ class VisualizationMarginModel(unittest.TestCase):
spatial_coordinates = CircleSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
margin_model = self.margin_model(coordinates=spatial_coordinates)
if self.DISPLAY:
margin_model.margin_function_sample.visualize()
margin_model.margin_function_sample.visualize_function()
def test_example_visualization_1D(self):
coordinates = LinSpaceSpatialCoordinates.from_nb_points(nb_points=self.nb_points)
margin_model = self.margin_model(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
margin_model.margin_function_sample.visualize(show=self.DISPLAY)
margin_model.margin_function_sample.visualize_function(show=self.DISPLAY)
self.assertTrue(True)
......
......@@ -3,6 +3,7 @@ import os.path as op
VERSION = datetime.datetime.now()
def get_root_path() -> str:
return op.dirname(op.abspath(__file__))
......@@ -13,7 +14,7 @@ def get_full_path(relative_path: str) -> str:
def get_display_name_from_object_type(object_type):
# assert isinstance(object_type, type), object_type
return str(object_type).split('.')[-1].split("'")[0]
return str(object_type).split('.')[-1].split("'")[0].split(' ')[0]
def first(s):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment