Commit d9c8629c authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[MARGIN PLOT] fix plot issue for the margin. add test. add spatial coordinates 2D

parent 6d112d0f
No related merge requests found
Showing with 61 additions and 22 deletions
+61 -22
...@@ -9,13 +9,15 @@ from extreme_estimator.margin_fits.gev.gev_params import GevParams ...@@ -9,13 +9,15 @@ from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, imshow_shifted from extreme_estimator.margin_fits.plot.create_shifted_cmap import plot_extreme_param, imshow_shifted
from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
from spatio_temporal_dataset.slicer.split import Split from spatio_temporal_dataset.slicer.split import Split
from utils import cached_property
class AbstractMarginFunction(object): class AbstractMarginFunction(object):
""" Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)""" """ Class of function mapping points from a space S (could be 1D, 2D,...) to R^3 (the 3 parameters of the GEV)"""
def __init__(self, coordinates: AbstractCoordinates): def __init__(self, coordinates: AbstractCoordinates, resolution=100):
self.coordinates = coordinates self.coordinates = coordinates
self.resolution = resolution
# Visualization parameters # Visualization parameters
self.visualization_axes = None self.visualization_axes = None
...@@ -29,6 +31,14 @@ class AbstractMarginFunction(object): ...@@ -29,6 +31,14 @@ class AbstractMarginFunction(object):
self._grid_2D = None self._grid_2D = None
self._grid_1D = None self._grid_1D = None
@property
def x(self):
return self.coordinates.x_coordinates
@property
def y(self):
return self.coordinates.y_coordinates
def get_gev_params(self, coordinate: np.ndarray) -> GevParams: def get_gev_params(self, coordinate: np.ndarray) -> GevParams:
"""Main method that maps each coordinate to its GEV parameters""" """Main method that maps each coordinate to its GEV parameters"""
pass pass
...@@ -131,14 +141,11 @@ class AbstractMarginFunction(object): ...@@ -131,14 +141,11 @@ class AbstractMarginFunction(object):
# Visualization 2D # Visualization 2D
def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True): def visualize_2D(self, gev_param_name=GevParams.LOC, ax=None, show=True):
x = self.coordinates.x_coordinates
y = self.coordinates.y_coordinates
grid = self.grid_2D(x, y)
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
# Special display # Special display
imshow_shifted(ax, gev_param_name, grid[gev_param_name], x, y) imshow_shifted(ax, gev_param_name, self.grid_2D[gev_param_name], self.x, self.y)
# X axis # X axis
ax.set_xlabel('coordinate X') ax.set_xlabel('coordinate X')
...@@ -152,13 +159,15 @@ class AbstractMarginFunction(object): ...@@ -152,13 +159,15 @@ class AbstractMarginFunction(object):
if show: if show:
plt.show() plt.show()
def grid_2D(self, x, y): @cached_property
resolution = 100 def grid_2D(self):
x = self.x
y = self.y
grid = [] grid = []
for i, xi in enumerate(np.linspace(x.min(), x.max(), resolution)): for i, xi in enumerate(np.linspace(x.min(), x.max(), self.resolution)):
for j, yj in enumerate(np.linspace(y.min(), y.max(), resolution)): for j, yj in enumerate(np.linspace(y.min(), y.max(), self.resolution)):
grid.append(self.get_gev_params(np.array([xi, yj])).summary_dict) grid.append(self.get_gev_params(np.array([xi, yj])).summary_dict)
grid = {value_name: np.array([g[value_name] for g in grid]).reshape([resolution, resolution]) grid = {value_name: np.array([g[value_name] for g in grid]).reshape([self.resolution, self.resolution])
for value_name in GevParams.SUMMARY_NAMES} for value_name in GevParams.SUMMARY_NAMES}
return grid return grid
......
...@@ -53,5 +53,6 @@ def imshow_shifted(ax, gev_param_name, values, x, y): ...@@ -53,5 +53,6 @@ def imshow_shifted(ax, gev_param_name, values, x, y):
value = np.min(values) value = np.min(values)
# The right blue corner will be blue (but most of the time, another display will be on top) # The right blue corner will be blue (but most of the time, another display will be on top)
masked_array[-1, -1] = value - epsilon masked_array[-1, -1] = value - epsilon
ax.imshow(masked_array, extent=(x.min(), x.max(), y.min(), y.max()), cmap=shifted_cmap) # IMPORTANT: Origin for all the plots is at the bottom left corner
ax.imshow(masked_array, extent=(x.min(), x.max(), y.min(), y.max()), cmap=shifted_cmap, origin='lower')
import numpy as np
import pandas as pd
from rpy2.robjects import r
from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_coordinates import \
AbstractSpatialCoordinates
from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates
class AbstractBiDimensionalSpatialCoordinates(AbstractSpatialCoordinates):
pass
class LinSpaceSpatial2DCoordinates(AbstractBiDimensionalSpatialCoordinates):
@classmethod
def from_nb_points(cls, nb_points, train_split_ratio: float = None, start=-1.0, end=1.0):
axis_coordinates = np.linspace(start, end, nb_points)
df = pd.DataFrame.from_dict({cls.COORDINATE_X: axis_coordinates,
cls.COORDINATE_Y: axis_coordinates})
return cls.from_df(df, train_split_ratio)
...@@ -12,17 +12,17 @@ from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_co ...@@ -12,17 +12,17 @@ from spatio_temporal_dataset.coordinates.spatial_coordinates.abstract_spatial_co
class CircleSpatialCoordinates(AbstractSpatialCoordinates): class CircleSpatialCoordinates(AbstractSpatialCoordinates):
@classmethod @classmethod
def df_spatial(cls, nb_points, max_radius=1.0): def df_spatial(cls, nb_points, max_radius=1.0, random=True):
# Sample uniformly inside the circle # Sample uniformly inside the circle
angles = np.array(r.runif(nb_points, max=2 * math.pi)) angles = np.array(r.runif(nb_points, max=2 * math.pi)) if random else np.linspace(0.0, 2 * math.pi, nb_points+1)[:-1]
radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius))) radius = np.sqrt(np.array(r.runif(nb_points, max=max_radius))) if random else np.ones(nb_points) * max_radius
df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles), df = pd.DataFrame.from_dict({cls.COORDINATE_X: radius * np.cos(angles),
cls.COORDINATE_Y: radius * np.sin(angles)}) cls.COORDINATE_Y: radius * np.sin(angles)})
return df return df
@classmethod @classmethod
def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0): def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0, random=True):
return cls.from_df(cls.df_spatial(nb_points, max_radius), train_split_ratio) return cls.from_df(cls.df_spatial(nb_points, max_radius, random), train_split_ratio)
def visualization_2D(self): def visualization_2D(self):
radius = 1.0 radius = 1.0
...@@ -36,5 +36,5 @@ class CircleSpatialCoordinates(AbstractSpatialCoordinates): ...@@ -36,5 +36,5 @@ class CircleSpatialCoordinates(AbstractSpatialCoordinates):
class CircleSpatialCoordinatesRadius2(CircleSpatialCoordinates): class CircleSpatialCoordinatesRadius2(CircleSpatialCoordinates):
@classmethod @classmethod
def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0): def from_nb_points(cls, nb_points, train_split_ratio: float = None, max_radius=1.0, random=True):
return 2 * super().from_nb_points(nb_points, train_split_ratio, max_radius) return 2 * super().from_nb_points(nb_points, train_split_ratio, max_radius, random)
import numpy as np
import unittest import unittest
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from extreme_estimator.margin_fits.gev.gev_params import GevParams from extreme_estimator.margin_fits.gev.gev_params import GevParams
from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearShapeDim1MarginModel, \ from extreme_estimator.extreme_models.margin_model.smooth_margin_model import LinearShapeDim1MarginModel, \
LinearAllParametersAllDimsMarginModel LinearAllParametersAllDimsMarginModel
from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_2D import LinSpaceSpatial2DCoordinates
from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import \ from spatio_temporal_dataset.coordinates.spatial_coordinates.generated_spatial_coordinates import \
CircleSpatialCoordinates CircleSpatialCoordinates
from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates from spatio_temporal_dataset.coordinates.spatial_coordinates.coordinates_1D import LinSpaceSpatialCoordinates
from test.test_utils import load_test_spatiotemporal_coordinates from test.test_utils import load_test_spatiotemporal_coordinates
class VisualizationMarginModel(unittest.TestCase): class TestVisualizationMarginModel(unittest.TestCase):
DISPLAY = False DISPLAY = False
nb_points = 2 nb_points = 2
margin_model_class = [LinearShapeDim1MarginModel, LinearAllParametersAllDimsMarginModel][-1] margin_model_class = [LinearShapeDim1MarginModel, LinearAllParametersAllDimsMarginModel][-1]
...@@ -24,8 +26,13 @@ class VisualizationMarginModel(unittest.TestCase): ...@@ -24,8 +26,13 @@ class VisualizationMarginModel(unittest.TestCase):
self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02}) self.margin_model = self.margin_model_class(coordinates=coordinates, params_sample={(GevParams.SHAPE, 1): 0.02})
def test_example_visualization_2D_spatial(self): def test_example_visualization_2D_spatial(self):
spatial_coordinates = CircleSpatialCoordinates.from_nb_points(nb_points=self.nb_points) spatial_coordinates = LinSpaceSpatial2DCoordinates.from_nb_points(nb_points=self.nb_points)
self.margin_model = self.margin_model_class(coordinates=spatial_coordinates) self.margin_model = self.margin_model_class(coordinates=spatial_coordinates)
# Assert that the grid correspond to what we expect in a simple case
self.margin_model.margin_function_sample.resolution = 2
grid = self.margin_model.margin_function_sample.grid_2D['loc']
true_grid = np.array([[0.98, 1.0], [1.0, 1.02]])
self.assertTrue((grid == true_grid).all(), msg="\nexpected:\n{}, \nfound:\n{}".format(true_grid, grid))
# def test_example_visualization_2D_spatio_temporal(self): # def test_example_visualization_2D_spatio_temporal(self):
# self.nb_steps = 2 # self.nb_steps = 2
...@@ -49,5 +56,6 @@ class VisualizationMarginModel(unittest.TestCase): ...@@ -49,5 +56,6 @@ class VisualizationMarginModel(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# v = VisualizationMarginModel() # v = TestVisualizationMarginModel()
# v.test_example_visualization_2D_spatio_temporal() # v.test_example_visualization_2D_spatial()
# v.tearDown()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment