Commit 2872efd9 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[COORDINATES] some renaming in abstract coordinates and in linear margin model

parent 49c96658
No related merge requests found
Showing with 43 additions and 44 deletions
+43 -44
...@@ -11,7 +11,6 @@ from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \ ...@@ -11,7 +11,6 @@ from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator, \ from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator, \
AbstractMarginEstimator AbstractMarginEstimator
from extreme_estimator.extreme_models.margin_model.linear_margin_model import \ from extreme_estimator.extreme_models.margin_model.linear_margin_model import \
LinearAllParametersTwoFirstCoordinatesMarginModel, LinearAllTwoStatialCoordinatesLocationLinearMarginModel, \
LinearStationaryMarginModel, LinearNonStationaryLocationMarginModel LinearStationaryMarginModel, LinearNonStationaryLocationMarginModel
from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction
from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \ from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \
...@@ -46,8 +45,7 @@ class AbstractNonStationaryTrendTest(object): ...@@ -46,8 +45,7 @@ class AbstractNonStationaryTrendTest(object):
self._starting_point_to_estimator[starting_point] = estimator self._starting_point_to_estimator[starting_point] = estimator
return self._starting_point_to_estimator[starting_point] return self._starting_point_to_estimator[starting_point]
def load_estimator(self, starting_point) -> Union[ def load_estimator(self, starting_point) -> Union[AbstractFullEstimator, AbstractMarginEstimator]:
AbstractFullEstimator, AbstractMarginEstimator]:
margin_model_class = self.stationary_margin_model_class if starting_point is None else self.non_stationary_margin_model_class margin_model_class = self.stationary_margin_model_class if starting_point is None else self.non_stationary_margin_model_class
assert starting_point not in self._starting_point_to_estimator assert starting_point not in self._starting_point_to_estimator
margin_model = margin_model_class(coordinates=self.dataset.coordinates, starting_point=starting_point) margin_model = margin_model_class(coordinates=self.dataset.coordinates, starting_point=starting_point)
......
...@@ -118,26 +118,17 @@ class LinearAllParametersAllDimsMarginModel(LinearMarginModel): ...@@ -118,26 +118,17 @@ class LinearAllParametersAllDimsMarginModel(LinearMarginModel):
GevParams.SCALE: self.coordinates.coordinates_dims}) GevParams.SCALE: self.coordinates.coordinates_dims})
class LinearAllParametersTwoFirstCoordinatesMarginModel(LinearMarginModel): class LinearStationaryMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [0, 1], super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
GevParams.LOC: [0, 1], GevParams.LOC: self.coordinates.spatial_coordinates_dims,
GevParams.SCALE: [0, 1]}) GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
class LinearAllTwoStatialCoordinatesLocationLinearMarginModel(LinearMarginModel): class LinearNonStationaryLocationMarginModel(LinearMarginModel):
def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None):
super().load_margin_functions({GevParams.SHAPE: [0, 1], super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims,
GevParams.LOC: [0, 1, 2], GevParams.LOC: self.coordinates.coordinates_dims,
GevParams.SCALE: [0, 1]}) GevParams.SCALE: self.coordinates.spatial_coordinates_dims})
# Some renaming that defines the stationary and non-stationary models of reference
class LinearStationaryMarginModel(LinearAllParametersTwoFirstCoordinatesMarginModel):
pass
class LinearNonStationaryLocationMarginModel(LinearAllTwoStatialCoordinatesLocationLinearMarginModel):
pass
...@@ -103,7 +103,7 @@ class AbstractMarginFunction(object): ...@@ -103,7 +103,7 @@ class AbstractMarginFunction(object):
def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True): def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True):
assert gev_value_name in GevParams.SUMMARY_NAMES assert gev_value_name in GevParams.SUMMARY_NAMES
nb_coordinates_spatial = self.coordinates.nb_coordinates_spatial nb_coordinates_spatial = self.coordinates.nb_spatial_coordinates
has_temporal_coordinates = self.coordinates.has_temporal_coordinates has_temporal_coordinates = self.coordinates.has_temporal_coordinates
if nb_coordinates_spatial == 1 and not has_temporal_coordinates: if nb_coordinates_spatial == 1 and not has_temporal_coordinates:
self.visualize_1D(gev_value_name, ax, show) self.visualize_1D(gev_value_name, ax, show)
......
...@@ -65,14 +65,14 @@ class LinearMarginFunction(ParametricMarginFunction): ...@@ -65,14 +65,14 @@ class LinearMarginFunction(ParametricMarginFunction):
for gev_param_name in GevParams.PARAM_NAMES: for gev_param_name in GevParams.PARAM_NAMES:
linear_dims = self.gev_param_name_to_dims.get(gev_param_name, []) linear_dims = self.gev_param_name_to_dims.get(gev_param_name, [])
# Load spatial form_dict (only if we have some spatial coordinates) # Load spatial form_dict (only if we have some spatial coordinates)
if self.coordinates.coordinates_spatial_names: if self.coordinates.has_spatial_coordinates:
spatial_names = [name for name in self.coordinates.coordinates_spatial_names spatial_names = [name for name in self.coordinates.spatial_coordinates_names
if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims] if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims]
spatial_form = self.gev_param_name_to_coef[gev_param_name].spatial_form_dict(spatial_names) spatial_form = self.gev_param_name_to_coef[gev_param_name].spatial_form_dict(spatial_names)
form_dict.update(spatial_form) form_dict.update(spatial_form)
# Load temporal form dict (only if we have some temporal coordinates) # Load temporal form dict (only if we have some temporal coordinates)
if self.coordinates.coordinates_temporal_names: if self.coordinates.has_temporal_coordinates:
temporal_names = [name for name in self.coordinates.coordinates_temporal_names temporal_names = [name for name in self.coordinates.temporal_coordinates_names
if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims] if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims]
temporal_form = self.gev_param_name_to_coef[gev_param_name].temporal_form_dict(temporal_names) temporal_form = self.gev_param_name_to_coef[gev_param_name].temporal_form_dict(temporal_names)
# Specifying a formula '~ 1' creates a bug in fitspatgev of SpatialExtreme R package # Specifying a formula '~ 1' creates a bug in fitspatgev of SpatialExtreme R package
......
...@@ -167,7 +167,7 @@ class AbstractCoordinates(object): ...@@ -167,7 +167,7 @@ class AbstractCoordinates(object):
@property @property
def coordinates_names(self) -> List[str]: def coordinates_names(self) -> List[str]:
return self.coordinates_spatial_names + self.coordinates_temporal_names return self.spatial_coordinates_names + self.temporal_coordinates_names
@property @property
def nb_coordinates(self) -> int: def nb_coordinates(self) -> int:
...@@ -180,22 +180,26 @@ class AbstractCoordinates(object): ...@@ -180,22 +180,26 @@ class AbstractCoordinates(object):
# Spatial attributes # Spatial attributes
@property @property
def coordinates_spatial_names(self) -> List[str]: def spatial_coordinates_dims(self):
return list(range(self.nb_spatial_coordinates))
@property
def spatial_coordinates_names(self) -> List[str]:
return [name for name in self.COORDINATE_SPATIAL_NAMES if name in self.df_all_coordinates.columns] return [name for name in self.COORDINATE_SPATIAL_NAMES if name in self.df_all_coordinates.columns]
@property @property
def nb_coordinates_spatial(self) -> int: def nb_spatial_coordinates(self) -> int:
return len(self.coordinates_spatial_names) return len(self.spatial_coordinates_names)
@property @property
def has_spatial_coordinates(self) -> bool: def has_spatial_coordinates(self) -> bool:
return self.nb_coordinates_spatial > 0 return self.nb_spatial_coordinates > 0
def df_spatial_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame: def df_spatial_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame:
if self.nb_coordinates_spatial == 0: if self.nb_spatial_coordinates == 0:
return pd.DataFrame() return pd.DataFrame()
else: else:
return self.df_coordinates(split, transformed).loc[:, self.coordinates_spatial_names].drop_duplicates() return self.df_coordinates(split, transformed).loc[:, self.spatial_coordinates_names].drop_duplicates()
@property @property
def nb_stations(self, split: Split = Split.all) -> int: def nb_stations(self, split: Split = Split.all) -> int:
...@@ -212,22 +216,28 @@ class AbstractCoordinates(object): ...@@ -212,22 +216,28 @@ class AbstractCoordinates(object):
# Temporal attributes # Temporal attributes
@property @property
def coordinates_temporal_names(self) -> List[str]: def temporal_dims(self):
start = self.nb_spatial_coordinates
end = start + self.nb_temporal_coordinates
return list(range(start, end))
@property
def temporal_coordinates_names(self) -> List[str]:
return [self.COORDINATE_T] if self.COORDINATE_T in self.df_all_coordinates else [] return [self.COORDINATE_T] if self.COORDINATE_T in self.df_all_coordinates else []
@property @property
def nb_coordinates_temporal(self) -> int: def nb_temporal_coordinates(self) -> int:
return len(self.coordinates_temporal_names) return len(self.temporal_coordinates_names)
@property @property
def has_temporal_coordinates(self) -> bool: def has_temporal_coordinates(self) -> bool:
return self.nb_coordinates_temporal > 0 return self.nb_temporal_coordinates > 0
def df_temporal_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame: def df_temporal_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame:
if self.nb_coordinates_temporal == 0: if self.nb_temporal_coordinates == 0:
return pd.DataFrame() return pd.DataFrame()
else: else:
return self.df_coordinates(split, transformed=transformed).loc[:, self.coordinates_temporal_names] \ return self.df_coordinates(split, transformed=transformed).loc[:, self.temporal_coordinates_names] \
.drop_duplicates() .drop_duplicates()
def df_temporal_coordinates_for_fit(self, split=Split.all, starting_point=None) -> pd.DataFrame: def df_temporal_coordinates_for_fit(self, split=Split.all, starting_point=None) -> pd.DataFrame:
...@@ -297,27 +307,27 @@ class AbstractCoordinates(object): ...@@ -297,27 +307,27 @@ class AbstractCoordinates(object):
return self.df_all_coordinates[self.COORDINATE_T].values.copy() return self.df_all_coordinates[self.COORDINATE_T].values.copy()
def visualize(self): def visualize(self):
if self.nb_coordinates_spatial == 1: if self.nb_spatial_coordinates == 1:
self.visualization_1D() self.visualization_1D()
elif self.nb_coordinates_spatial == 2: elif self.nb_spatial_coordinates == 2:
self.visualization_2D() self.visualization_2D()
else: else:
self.visualization_3D() self.visualization_3D()
def visualization_1D(self): def visualization_1D(self):
assert self.nb_coordinates_spatial >= 1 assert self.nb_spatial_coordinates >= 1
x = self.x_coordinates x = self.x_coordinates
y = np.zeros(len(x)) y = np.zeros(len(x))
plt.scatter(x, y) plt.scatter(x, y)
plt.show() plt.show()
def visualization_2D(self): def visualization_2D(self):
assert self.nb_coordinates_spatial >= 2 assert self.nb_spatial_coordinates >= 2
plt.scatter(self.x_coordinates, self.y_coordinates) plt.scatter(self.x_coordinates, self.y_coordinates)
plt.show() plt.show()
def visualization_3D(self): def visualization_3D(self):
assert self.nb_coordinates_spatial == 3 assert self.nb_spatial_coordinates == 3
fig = plt.figure() fig = plt.figure()
ax = fig.add_subplot(111, projection='3d') # type: Axes3D ax = fig.add_subplot(111, projection='3d') # type: Axes3D
ax.scatter(self.x_coordinates, self.y_coordinates, self.z_coordinates, marker='^') ax.scatter(self.x_coordinates, self.y_coordinates, self.z_coordinates, marker='^')
......
...@@ -26,7 +26,7 @@ class TestMaxStableEstimators(unittest.TestCase): ...@@ -26,7 +26,7 @@ class TestMaxStableEstimators(unittest.TestCase):
def fit_max_stable_estimator_for_all_coordinates(self): def fit_max_stable_estimator_for_all_coordinates(self):
for coordinates in self.coordinates: for coordinates in self.coordinates:
for max_stable_model in self.max_stable_models: for max_stable_model in self.max_stable_models:
use_rmaxstab_with_2_coordinates = coordinates.nb_coordinates_spatial > 2 use_rmaxstab_with_2_coordinates = coordinates.nb_spatial_coordinates > 2
dataset = MaxStableDataset.from_sampling(nb_obs=self.nb_obs, dataset = MaxStableDataset.from_sampling(nb_obs=self.nb_obs,
max_stable_model=max_stable_model, max_stable_model=max_stable_model,
coordinates=coordinates, coordinates=coordinates,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment