Commit d449bbe3 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[SCM] add mask for french alps vizu

parent bee5048c
No related merge requests found
Showing with 32 additions and 12 deletions
+32 -12
import os
import numpy as np
from PIL import Image, ImageDraw
import os.path as op
from collections import OrderedDict
from typing import List, Dict
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from netCDF4 import Dataset
from experiment.meteo_france_SCM_study.abstract_variable import AbstractVariable
from experiment.meteo_france_SCM_study.massif import safran_massif_names_from_datasets
from experiment.meteo_france_SCM_study.visualization.utils import get_km_formatter
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
AbstractMarginFunction
from extreme_estimator.margin_fits.plot.create_shifted_cmap import get_color_rbga_shifted
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
......@@ -220,6 +225,24 @@ class AbstractStudy(object):
def visualization_y_limits(self):
return min(self.all_coords_list[1]), max(self.all_coords_list[1])
@cached_property
def mask_french_alps(self):
resolution = AbstractMarginFunction.VISUALIZATION_RESOLUTION
mask_french_alps = np.zeros([resolution, resolution])
for polygon in self.idx_to_coords_list.values():
xy_values = list(zip(*polygon))
normalized_polygon = []
for values, (minlim, max_lim) in zip(xy_values, [self.visualization_x_limits, self.visualization_y_limits]):
values -= minlim
values *= resolution / (max_lim - minlim)
normalized_polygon.append(values)
normalized_polygon = list(zip(*normalized_polygon))
img = Image.new('L', (resolution, resolution), 0)
ImageDraw.Draw(img).polygon(normalized_polygon, outline=1, fill=1)
mask_massif = np.array(img)
mask_french_alps += mask_massif
return ~np.array(mask_french_alps, dtype=bool)
""" Some properties """
@property
......
......@@ -243,17 +243,12 @@ class StudyVisualizer(object):
def fit_and_visualize_estimator(self, estimator, axes=None, title=None):
estimator.fit()
# Set visualization attributes for margin_fct
margin_fct = estimator.margin_function_fitted
# margin_fct.visualization_x_limits = self.study.
margin_fct._visualization_x_limits = self.study.visualization_x_limits
margin_fct._visualization_y_limits = self.study.visualization_y_limits
# Example of mask 2D
mask_2D = np.zeros([margin_fct.resolution, margin_fct.resolution], dtype=bool)
lim = 5
mask_2D[lim:-lim, lim:-lim] = True
margin_fct.mask_2D = self.study.mask_french_alps
margin_fct.mask_2D = mask_2D
axes = margin_fct.visualize_function(show=False, axes=axes, title='')
self.visualize_contour_and_move_axes_limits(axes)
......
......@@ -14,10 +14,10 @@ from utils import cached_property
class AbstractMarginFunction(object):
""" Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)"""
VISUALIZATION_RESOLUTION = 100
def __init__(self, coordinates: AbstractCoordinates):
self.coordinates = coordinates
self.resolution = 100
self.mask_2D = None
# Visualization parameters
......@@ -185,10 +185,10 @@ class AbstractMarginFunction(object):
@cached_property
def grid_2D(self):
grid = []
for xi in np.linspace(*self.visualization_x_limits, self.resolution):
for yj in np.linspace(*self.visualization_y_limits, self.resolution):
for xi in np.linspace(*self.visualization_x_limits, self.VISUALIZATION_RESOLUTION):
for yj in np.linspace(*self.visualization_y_limits, self.VISUALIZATION_RESOLUTION):
grid.append(self.get_gev_params(np.array([xi, yj])).summary_dict)
grid = {value_name: np.array([g[value_name] for g in grid]).reshape([self.resolution, self.resolution])
grid = {value_name: np.array([g[value_name] for g in grid]).reshape([self.VISUALIZATION_RESOLUTION, self.VISUALIZATION_RESOLUTION])
for value_name in GevParams.SUMMARY_NAMES}
return grid
......
import numpy as np
import unittest
from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
AbstractMarginFunction
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearShapeDim1MarginModel, \
......@@ -29,7 +31,7 @@ class TestVisualizationMarginModel(unittest.TestCase):
spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
# Assert that the grid correspond to what we expect in a simple case
self.margin_model.margin_function_sample.resolution = 2
AbstractMarginFunction.VISUALIZATION_RESOLUTION = 2
grid = self.margin_model.margin_function_sample.grid_2D['loc']
true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment