Commit 2cd265ea authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[ROBUSTNESS PLOT] refactor plot and coordinates, add magic methods to coordinates

parent 8650cc55
No related merge requests found
Showing with 141 additions and 62 deletions
+141 -62
...@@ -4,7 +4,7 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset ...@@ -4,7 +4,7 @@ from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
class AbstractEstimator(object): class AbstractEstimator(object):
DURATION = 'Duration' DURATION = 'Average duration'
MAE_ERROR = 'Mean Average Error' MAE_ERROR = 'Mean Average Error'
def __init__(self, dataset: AbstractDataset): def __init__(self, dataset: AbstractDataset):
......
from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import AbstractMaxStableModel, CovarianceFunction
from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick, Schlather, ExtremalT
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.estimator.unitary_msp_estimator import MaxStableEstimator from extreme_estimator.estimator.unitary_msp_estimator import MaxStableEstimator
from extreme_estimator.robustness_plot.multiple_plot import MultiplePlot from extreme_estimator.robustness_plot.multiple_plot import MultiplePlot
from extreme_estimator.robustness_plot.single_plot import SinglePlot from extreme_estimator.robustness_plot.single_plot import SinglePlot
from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset
from spatio_temporal_dataset.spatial_coordinates.alps_station_coordinates import AlpsStationCoordinatesBetweenZeroAndOne from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates from spatio_temporal_dataset.spatial_coordinates.alps_station_coordinates import \
AlpsStationCoordinatesBetweenZeroAndOne, AlpsStationCoordinatesBetweenZeroAndTwo
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1, \
CircleCoordinatesRadius2
from extreme_estimator.robustness_plot.display_item import DisplayItem from extreme_estimator.robustness_plot.display_item import DisplayItem
...@@ -16,10 +19,15 @@ class MaxStableDisplayItem(DisplayItem): ...@@ -16,10 +19,15 @@ class MaxStableDisplayItem(DisplayItem):
return value.cov_mod return value.cov_mod
class SpatialCoordinateDisplayItem(DisplayItem):
def display_name_from_value(self, value: AbstractSpatialCoordinates):
return str(value).split('.')[-1].split("'")[0]
class MspSpatial(object): class MspSpatial(object):
MaxStableModelItem = MaxStableDisplayItem('max_stable_model', Smith) MaxStableModelItem = MaxStableDisplayItem('max_stable_model', Smith)
SpatialCoordinateClassItem = DisplayItem('spatial_coordinate_class', CircleCoordinates) SpatialCoordinateClassItem = SpatialCoordinateDisplayItem('spatial_coordinate_class', CircleCoordinatesRadius1)
SpatialParamsItem = DisplayItem('spatial_params', {"r": 1})
NbStationItem = DisplayItem('Number of stations', 50) NbStationItem = DisplayItem('Number of stations', 50)
NbObservationItem = DisplayItem('nb_obs', 60) NbObservationItem = DisplayItem('nb_obs', 60)
...@@ -29,10 +37,9 @@ class MspSpatial(object): ...@@ -29,10 +37,9 @@ class MspSpatial(object):
**kwargs_single_point) # type: AbstractMaxStableModel **kwargs_single_point) # type: AbstractMaxStableModel
spatial_coordinate_class = self.SpatialCoordinateClassItem.value_from_kwargs(**kwargs_single_point) spatial_coordinate_class = self.SpatialCoordinateClassItem.value_from_kwargs(**kwargs_single_point)
nb_station = self.NbStationItem.value_from_kwargs(**kwargs_single_point) nb_station = self.NbStationItem.value_from_kwargs(**kwargs_single_point)
spatial_params = self.SpatialParamsItem.value_from_kwargs(**kwargs_single_point)
nb_obs = self.NbObservationItem.value_from_kwargs(**kwargs_single_point) nb_obs = self.NbObservationItem.value_from_kwargs(**kwargs_single_point)
# Run the estimation # Run the estimation
spatial_coordinate = spatial_coordinate_class.from_nb_points(nb_points=nb_station, **spatial_params) spatial_coordinate = spatial_coordinate_class.from_nb_points(nb_points=nb_station)
dataset = SimulatedDataset.from_max_stable_sampling(nb_obs=nb_obs, max_stable_model=max_stable_model, dataset = SimulatedDataset.from_max_stable_sampling(nb_obs=nb_obs, max_stable_model=max_stable_model,
spatial_coordinates=spatial_coordinate) spatial_coordinates=spatial_coordinate)
estimator = MaxStableEstimator(dataset, max_stable_model) estimator = MaxStableEstimator(dataset, max_stable_model)
...@@ -49,39 +56,54 @@ class SingleMspSpatial(SinglePlot, MspSpatial): ...@@ -49,39 +56,54 @@ class SingleMspSpatial(SinglePlot, MspSpatial):
class MultipleMspSpatial(MultiplePlot, MspSpatial): class MultipleMspSpatial(MultiplePlot, MspSpatial):
def compute_value_from_kwargs_single_point(self, **kwargs_single_point): def compute_value_from_kwargs_single_point(self, **kwargs_single_point):
print('here')
return self.msp_spatial_ordinates(**kwargs_single_point) return self.msp_spatial_ordinates(**kwargs_single_point)
def single_spatial_robustness_alps(): # def single_spatial_robustness_alps():
spatial_robustness = SingleMspSpatial(grid_row_item=SingleMspSpatial.NbObservationItem, # spatial_robustness = SingleMspSpatial(grid_row_item=SingleMspSpatial.NbObservationItem,
grid_column_item=SingleMspSpatial.SpatialCoordinateClassItem, # grid_column_item=SingleMspSpatial.SpatialCoordinateClassItem,
plot_row_item=SingleMspSpatial.NbStationItem, # plot_row_item=SingleMspSpatial.NbStationItem,
plot_label_item=SingleMspSpatial.MaxStableModelItem) # plot_label_item=SingleMspSpatial.MaxStableModelItem)
# Put only the parameter that will vary # # Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{ # spatial_robustness.robustness_grid_plot(**{
SingleMspSpatial.NbStationItem.name: [10, 30, 50, 70, 86][:], # SingleMspSpatial.NbStationItem.name: list(range(43, 87, 15)),
SingleMspSpatial.NbObservationItem.name: [10], # SingleMspSpatial.NbObservationItem.name: [10],
SingleMspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:], # SingleMspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:],
SingleMspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinates, # SingleMspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinatesRadius1,
AlpsStationCoordinatesBetweenZeroAndOne][:], # AlpsStationCoordinatesBetweenZeroAndOne][:],
}) # })
def multiple_spatial_robustness_alps(): def multiple_spatial_robustness_alps():
nb_observation = 60
nb_sample = 10
plot_name = 'fast_result'
nb_stations = list(range(43, 87, 15))
# nb_stations = [10, 20, 30]
spatial_robustness = MultipleMspSpatial( spatial_robustness = MultipleMspSpatial(
grid_column_item=MspSpatial.MaxStableModelItem, grid_column_item=MspSpatial.SpatialCoordinateClassItem,
plot_row_item=MspSpatial.NbStationItem, plot_row_item=MspSpatial.NbStationItem,
plot_label_item=MspSpatial.SpatialCoordinateClassItem, plot_label_item=MspSpatial.MaxStableModelItem,
nb_samples=10) nb_samples=nb_sample,
main_title="Max stable analysis with {} years of observations".format(nb_observation),
plot_png_filename=plot_name
)
# Load all the models
msp_models = [Smith(), BrownResnick()]
# for covariance_function in CovarianceFunction:
# msp_models.extend([ExtremalT(covariance_function=covariance_function)])
# Put only the parameter that will vary # Put only the parameter that will vary
spatial_robustness.robustness_grid_plot(**{ spatial_robustness.robustness_grid_plot(**{
SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION], SinglePlot.OrdinateItem.name: [AbstractEstimator.MAE_ERROR, AbstractEstimator.DURATION],
MspSpatial.NbStationItem.name: [10, 20, 30, 50, 70, 86][:3], MspSpatial.NbStationItem.name: nb_stations,
MspSpatial.NbObservationItem.name: 10, MspSpatial.NbObservationItem.name: nb_observation,
MspSpatial.MaxStableModelItem.name: [Smith(), BrownResnick()][:], MspSpatial.MaxStableModelItem.name: msp_models,
MspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinates, MspSpatial.SpatialCoordinateClassItem.name: [CircleCoordinatesRadius1,
AlpsStationCoordinatesBetweenZeroAndOne][:], CircleCoordinatesRadius2,
AlpsStationCoordinatesBetweenZeroAndOne,
AlpsStationCoordinatesBetweenZeroAndTwo][:],
}) })
......
...@@ -8,19 +8,18 @@ class MultiplePlot(SinglePlot): ...@@ -8,19 +8,18 @@ class MultiplePlot(SinglePlot):
Each scalar, will be display on a grid row (to ease visual comparison) Each scalar, will be display on a grid row (to ease visual comparison)
""" """
def __init__(self, grid_column_item, plot_row_item, plot_label_item, nb_samples=1): def __init__(self, grid_column_item, plot_row_item, plot_label_item, nb_samples=1, main_title='',
plot_png_filename=None):
super().__init__(grid_row_item=self.OrdinateItem, grid_column_item=grid_column_item, super().__init__(grid_row_item=self.OrdinateItem, grid_column_item=grid_column_item,
plot_row_item=plot_row_item, plot_label_item=plot_label_item, plot_row_item=plot_row_item, plot_label_item=plot_label_item,
nb_samples=nb_samples) nb_samples=nb_samples, main_title=main_title, plot_png_filename=plot_png_filename)
self.kwargs_single_point_to_errors = {} self.kwargs_single_point_to_errors = {}
def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point): def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point):
print('here1')
# Compute hash # Compute hash
hash_from_kwargs_single_point = self.hash_from_kwargs_single_point(kwargs_single_point) hash_from_kwargs_single_point = self.hash_from_kwargs_single_point(kwargs_single_point)
# Either compute the errors or Reload them from cached results # Either compute the errors or Reload them from cached results
if hash_from_kwargs_single_point in self.kwargs_single_point_to_errors: if hash_from_kwargs_single_point in self.kwargs_single_point_to_errors:
print('Load')
errors = self.kwargs_single_point_to_errors[hash_from_kwargs_single_point] errors = self.kwargs_single_point_to_errors[hash_from_kwargs_single_point]
else: else:
errors = self.compute_value_from_kwargs_single_point(**kwargs_single_point) errors = self.compute_value_from_kwargs_single_point(**kwargs_single_point)
...@@ -30,6 +29,5 @@ class MultiplePlot(SinglePlot): ...@@ -30,6 +29,5 @@ class MultiplePlot(SinglePlot):
def hash_from_kwargs_single_point(self, kwargs_single_point): def hash_from_kwargs_single_point(self, kwargs_single_point):
items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.name] items_except_error = [(k, v) for k, v in kwargs_single_point.items() if k != self.OrdinateItem.name]
ordered_dict_items_str = str(sorted(items_except_error, key=lambda x: x[0])) ordered_dict_items_str = str(sorted(items_except_error, key=lambda x: x[0]))
print(ordered_dict_items_str)
hash_from_kwargs_single_point = hash(ordered_dict_items_str) hash_from_kwargs_single_point = hash(ordered_dict_items_str)
return hash_from_kwargs_single_point return hash_from_kwargs_single_point
import os
import os.path as op
import random
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from itertools import product from itertools import product
from extreme_estimator.estimator.abstract_estimator import AbstractEstimator from extreme_estimator.estimator.abstract_estimator import AbstractEstimator
from extreme_estimator.robustness_plot.display_item import DisplayItem from extreme_estimator.robustness_plot.display_item import DisplayItem
from utils import get_full_path
plt.style.use('seaborn-white') plt.style.use('seaborn-white')
...@@ -12,12 +17,15 @@ class SinglePlot(object): ...@@ -12,12 +17,15 @@ class SinglePlot(object):
COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan'] COLORS = ['blue', 'red', 'green', 'black', 'magenta', 'cyan']
OrdinateItem = DisplayItem('ordinate', AbstractEstimator.MAE_ERROR) OrdinateItem = DisplayItem('ordinate', AbstractEstimator.MAE_ERROR)
def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item, nb_samples=1): def __init__(self, grid_row_item, grid_column_item, plot_row_item, plot_label_item, nb_samples=1, main_title='',
plot_png_filename=None):
self.grid_row_item = grid_row_item # type: DisplayItem self.grid_row_item = grid_row_item # type: DisplayItem
self.grid_column_item = grid_column_item # type: DisplayItem self.grid_column_item = grid_column_item # type: DisplayItem
self.plot_row_item = plot_row_item # type: DisplayItem self.plot_row_item = plot_row_item # type: DisplayItem
self.plot_label_item = plot_label_item # type: DisplayItem self.plot_label_item = plot_label_item # type: DisplayItem
self.nb_samples = nb_samples self.nb_samples = nb_samples
self.main_title = main_title
self.plot_png_filename = plot_png_filename
def robustness_grid_plot(self, **kwargs): def robustness_grid_plot(self, **kwargs):
# Extract Grid row and columns values # Extract Grid row and columns values
...@@ -25,20 +33,41 @@ class SinglePlot(object): ...@@ -25,20 +33,41 @@ class SinglePlot(object):
grid_column_values = self.grid_column_item.values_from_kwargs(**kwargs) grid_column_values = self.grid_column_item.values_from_kwargs(**kwargs)
nb_grid_rows, nb_grid_columns = len(grid_row_values), len(grid_column_values) nb_grid_rows, nb_grid_columns = len(grid_row_values), len(grid_column_values)
# Start the overall plot # Start the overall plot
fig = plt.figure() # fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4) fig, axes = plt.subplots(nb_grid_rows, nb_grid_columns, sharex='col', sharey='row')
for i, (grid_row_value, grid_column_value) in enumerate(product(grid_row_values, grid_column_values), 1): fig.subplots_adjust(hspace=0.4, wspace=0.4, )
for (i, grid_row_value), (j, grid_column_value) in product(enumerate(grid_row_values),
enumerate(grid_column_values)):
print('Grid plot: {}={} {}={}'.format(self.grid_row_item.name, grid_row_value, print('Grid plot: {}={} {}={}'.format(self.grid_row_item.name, grid_row_value,
self.grid_column_item.name, grid_column_value)) self.grid_column_item.name, grid_column_value))
ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i) ax = axes[i, j]
# ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i)
# Adapt the kwargs for the single plot # Adapt the kwargs for the single plot
kwargs_single_plot = kwargs.copy() kwargs_single_plot = kwargs.copy()
kwargs_single_plot[self.grid_row_item.name] = grid_row_value kwargs_single_plot[self.grid_row_item.name] = grid_row_value
kwargs_single_plot[self.grid_column_item.name] = grid_column_value kwargs_single_plot[self.grid_column_item.name] = grid_column_value
self.robustness_single_plot(ax, **kwargs_single_plot) self.robustness_single_plot(ax, **kwargs_single_plot)
self.add_title(ax, grid_column_value, grid_row_value) self.add_sub_title(ax, grid_column_value, grid_row_value)
fig.suptitle(self.main_title)
self.save_plot()
plt.show() plt.show()
def save_plot(self):
if self.plot_png_filename is None:
return
assert isinstance(self.plot_png_filename, str)
relative_path = op.join('local', 'plot')
plot_pn_dirpath = get_full_path(relative_path=relative_path)
if not op.exists(plot_pn_dirpath):
os.makedirs(plot_pn_dirpath)
plot_pn_filepath = op.join(plot_pn_dirpath, self.plot_png_filename + '.png')
i = 2
while op.exists(plot_pn_filepath):
plot_pn_filepath = op.join(plot_pn_dirpath, self.plot_png_filename + str(i) + '.png')
i += 1
# plt.savefig(plot_pn_filepath, bbox_inches='tight')
plt.savefig(plot_pn_filepath)
def robustness_single_plot(self, ax, **kwargs_single_plot): def robustness_single_plot(self, ax, **kwargs_single_plot):
plot_row_values = self.plot_row_item.values_from_kwargs(**kwargs_single_plot) plot_row_values = self.plot_row_item.values_from_kwargs(**kwargs_single_plot)
plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot) plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot)
...@@ -47,20 +76,28 @@ class SinglePlot(object): ...@@ -47,20 +76,28 @@ class SinglePlot(object):
mean_values, std_values = self.compute_mean_and_std_ordinate_values(kwargs_single_plot, ordinate_name, mean_values, std_values = self.compute_mean_and_std_ordinate_values(kwargs_single_plot, ordinate_name,
plot_label_value, plot_row_values) plot_label_value, plot_row_values)
ax.errorbar(plot_row_values, mean_values, std_values, ax.errorbar(plot_row_values, mean_values, std_values,
# linestyle='None', marker='^', # linestyle='None', marker='^',
linewidth = 0.5, linewidth=0.5,
color=self.COLORS[j % len(self.COLORS)], color=self.COLORS[j % len(self.COLORS)],
label=self.plot_label_item.display_name_from_value(plot_label_value)) label=self.plot_label_item.display_name_from_value(plot_label_value))
ax.legend() ax.legend()
# X axis
ax.set_xlabel(self.plot_row_item.name) ax.set_xlabel(self.plot_row_item.name)
ax.set_ylabel(ordinate_name) plt.setp(ax.get_xticklabels(), visible=True)
ax.xaxis.set_tick_params(labelbottom=True)
# Y axis
ax.set_ylabel(ordinate_name + ' ({} samples)'.format(self.nb_samples))
plt.setp(ax.get_yticklabels(), visible=True)
ax.yaxis.set_tick_params(labelbottom=True)
def compute_mean_and_std_ordinate_values(self, kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values): def compute_mean_and_std_ordinate_values(self, kwargs_single_plot, ordinate_name, plot_label_value,
plot_row_values):
all_ordinate_values = [] all_ordinate_values = []
for nb_sample in range(self.nb_samples): for nb_sample in range(self.nb_samples):
# Important to add the nb_sample argument, to differentiate the different experiments # Important to add the nb_sample argument, to differentiate the different experiments
kwargs_single_plot['nb_sample'] = nb_sample kwargs_single_plot['nb_sample'] = nb_sample
ordinate_values = self.compute_ordinate_values(kwargs_single_plot, ordinate_name, plot_label_value, plot_row_values) ordinate_values = self.compute_ordinate_values(kwargs_single_plot, ordinate_name, plot_label_value,
plot_row_values)
all_ordinate_values.append(ordinate_values) all_ordinate_values.append(ordinate_values)
all_ordinate_values = np.array(all_ordinate_values) all_ordinate_values = np.array(all_ordinate_values)
return np.mean(all_ordinate_values, axis=0), np.std(all_ordinate_values, axis=0) return np.mean(all_ordinate_values, axis=0), np.std(all_ordinate_values, axis=0)
...@@ -79,7 +116,6 @@ class SinglePlot(object): ...@@ -79,7 +116,6 @@ class SinglePlot(object):
assert not isinstance(v, list), '"{}" argument is a list'.format(k) assert not isinstance(v, list), '"{}" argument is a list'.format(k)
# Compute ordinate values # Compute ordinate values
ordinate_name_to_ordinate_value = self.cached_compute_value_from_kwargs_single_point(**kwargs_single_point) ordinate_name_to_ordinate_value = self.cached_compute_value_from_kwargs_single_point(**kwargs_single_point)
print(ordinate_name, plot_row_value)
plot_row_value_to_ordinate_value[plot_row_value] = ordinate_name_to_ordinate_value[ordinate_name] plot_row_value_to_ordinate_value[plot_row_value] = ordinate_name_to_ordinate_value[ordinate_name]
# Plot the figure # Plot the figure
plot_ordinate_values = [plot_row_value_to_ordinate_value[plot_row_value] for plot_row_value in plot_ordinate_values = [plot_row_value_to_ordinate_value[plot_row_value] for plot_row_value in
...@@ -92,7 +128,7 @@ class SinglePlot(object): ...@@ -92,7 +128,7 @@ class SinglePlot(object):
def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point) -> dict: def cached_compute_value_from_kwargs_single_point(self, **kwargs_single_point) -> dict:
return self.compute_value_from_kwargs_single_point(**kwargs_single_point) return self.compute_value_from_kwargs_single_point(**kwargs_single_point)
def add_title(self, ax, grid_column_value, grid_row_value): def add_sub_title(self, ax, grid_column_value, grid_row_value):
title_str = self.grid_row_item.display_name_from_value(grid_row_value) title_str = self.grid_row_item.display_name_from_value(grid_row_value)
title_str += ' ' if len(title_str) > 0 else '' title_str += ' ' if len(title_str) > 0 else ''
title_str += self.grid_column_item.display_name_from_value(grid_column_value) title_str += self.grid_column_item.display_name_from_value(grid_column_value)
......
...@@ -3,7 +3,7 @@ import pandas as pd ...@@ -3,7 +3,7 @@ import pandas as pd
from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
from spatio_temporal_dataset.temporal_maxima.temporal_maxima import TemporalMaxima from spatio_temporal_dataset.temporal_maxima.temporal_maxima import TemporalMaxima
from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import AbstractSpatialCoordinates
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
class SimulatedDataset(AbstractDataset): class SimulatedDataset(AbstractDataset):
......
...@@ -73,10 +73,19 @@ class AbstractSpatialCoordinates(object): ...@@ -73,10 +73,19 @@ class AbstractSpatialCoordinates(object):
def index(self): def index(self):
return self.df_coord.index return self.df_coord.index
def __len__(self):
return len(self.df_coord)
def visualization(self): def visualization(self):
x, y = self.coord[:, 0], self.coord[:, 1] x, y = self.coord[:, 0], self.coord[:, 1]
plt.scatter(x, y) plt.scatter(x, y)
plt.show() plt.show()
# Magic Methods
def __len__(self):
return len(self.df_coord)
def __mul__(self, other: float):
self.df_coord *= other
return self
def __rmul__(self, other):
return self * other
\ No newline at end of file
...@@ -42,10 +42,18 @@ class AlpsStationCoordinatesBetweenZeroAndOne(AlpsStationCoordinates): ...@@ -42,10 +42,18 @@ class AlpsStationCoordinatesBetweenZeroAndOne(AlpsStationCoordinates):
normalizing_function=BetweenZeroAndOneNormalization()) normalizing_function=BetweenZeroAndOneNormalization())
class AlpsStationCoordinatesBetweenZeroAndTwo(AlpsStationCoordinatesBetweenZeroAndOne):
@classmethod
def from_csv(cls, csv_file='coord-lambert2'):
return 2 * super().from_csv(csv_file)
if __name__ == '__main__': if __name__ == '__main__':
# AlpsStationCoordinate.transform_txt_into_csv() # AlpsStationCoordinate.transform_txt_into_csv()
# coord = AlpsStationCoordinates.from_csv() # coord = AlpsStationCoordinates.from_csv()
# coord = AlpsStationCoordinates.from_nb_points(nb_points=60) # coord = AlpsStationCoordinates.from_nb_points(nb_points=60)
# coord = AlpsStationCoordinatesBetweenZeroAndOne.from_csv() # coord = AlpsStationCoordinatesBetweenZeroAndOne.from_csv()
coord = AlpsStationCoordinatesBetweenZeroAndOne.from_nb_points(nb_points=60) coord = AlpsStationCoordinatesBetweenZeroAndTwo.from_nb_points(nb_points=60)
# coord = coord * 2
coord.visualization() coord.visualization()
...@@ -7,11 +7,10 @@ from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import Abs ...@@ -7,11 +7,10 @@ from spatio_temporal_dataset.spatial_coordinates.abstract_coordinates import Abs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class CircleCoordinates(AbstractSpatialCoordinates): class CircleCoordinatesRadius1(AbstractSpatialCoordinates):
@classmethod @classmethod
def from_nb_points(cls, nb_points, **kwargs): def from_nb_points(cls, nb_points, max_radius=1.0):
max_radius = kwargs.get('max_radius', 1.0)
# Sample uniformly inside the circle # Sample uniformly inside the circle
r = get_loaded_r() r = get_loaded_r()
angles = np.array(r.runif(nb_points, max=2 * math.pi)) angles = np.array(r.runif(nb_points, max=2 * math.pi))
...@@ -28,6 +27,13 @@ class CircleCoordinates(AbstractSpatialCoordinates): ...@@ -28,6 +27,13 @@ class CircleCoordinates(AbstractSpatialCoordinates):
super().visualization() super().visualization()
class CircleCoordinatesRadius2(CircleCoordinatesRadius1):
@classmethod
def from_nb_points(cls, nb_points, max_radius=1.0):
return 2 * super().from_nb_points(nb_points, max_radius)
if __name__ == '__main__': if __name__ == '__main__':
coord = CircleCoordinates.from_nb_points(nb_points=500, max_radius=1) coord = CircleCoordinatesRadius1.from_nb_points(nb_points=500, max_radius=1)
coord.visualization() coord.visualization()
...@@ -5,7 +5,7 @@ from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import \ ...@@ -5,7 +5,7 @@ from extreme_estimator.R_fit.max_stable_fit.abstract_max_stable_model import \
from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick, Schlather, \ from extreme_estimator.R_fit.max_stable_fit.max_stable_models import Smith, BrownResnick, Schlather, \
Geometric, ExtremalT, ISchlather Geometric, ExtremalT, ISchlather
from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset from spatio_temporal_dataset.dataset.simulation_dataset import SimulatedDataset
from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinates from spatio_temporal_dataset.spatial_coordinates.generated_coordinates import CircleCoordinatesRadius1
class TestMaxStableFit(unittest.TestCase): class TestMaxStableFit(unittest.TestCase):
...@@ -13,7 +13,7 @@ class TestMaxStableFit(unittest.TestCase): ...@@ -13,7 +13,7 @@ class TestMaxStableFit(unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.spatial_coord = CircleCoordinates.from_nb_points(nb_points=5, max_radius=1) self.spatial_coord = CircleCoordinatesRadius1.from_nb_points(nb_points=5, max_radius=1)
self.max_stable_models = [] self.max_stable_models = []
for max_stable_class in self.MAX_STABLE_CLASSES: for max_stable_class in self.MAX_STABLE_CLASSES:
if issubclass(max_stable_class, AbstractMaxStableModelWithCovarianceFunction): if issubclass(max_stable_class, AbstractMaxStableModelWithCovarianceFunction):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment