diff --git a/experiment/trend_analysis/non_stationary_trends.py b/experiment/trend_analysis/non_stationary_trends.py index b3a5b920a09b7ac721497f694dbf707fa490a147..e1eecad912e836392c57448ea886427cd2819bd4 100644 --- a/experiment/trend_analysis/non_stationary_trends.py +++ b/experiment/trend_analysis/non_stationary_trends.py @@ -11,7 +11,6 @@ from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \ from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator, \ AbstractMarginEstimator from extreme_estimator.extreme_models.margin_model.linear_margin_model import \ - LinearAllParametersTwoFirstCoordinatesMarginModel, LinearAllTwoStatialCoordinatesLocationLinearMarginModel, \ LinearStationaryMarginModel, LinearNonStationaryLocationMarginModel from extreme_estimator.extreme_models.margin_model.margin_function.linear_margin_function import LinearMarginFunction from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \ @@ -46,8 +45,7 @@ class AbstractNonStationaryTrendTest(object): self._starting_point_to_estimator[starting_point] = estimator return self._starting_point_to_estimator[starting_point] - def load_estimator(self, starting_point) -> Union[ - AbstractFullEstimator, AbstractMarginEstimator]: + def load_estimator(self, starting_point) -> Union[AbstractFullEstimator, AbstractMarginEstimator]: margin_model_class = self.stationary_margin_model_class if starting_point is None else self.non_stationary_margin_model_class assert starting_point not in self._starting_point_to_estimator margin_model = margin_model_class(coordinates=self.dataset.coordinates, starting_point=starting_point) diff --git a/extreme_estimator/extreme_models/margin_model/linear_margin_model.py b/extreme_estimator/extreme_models/margin_model/linear_margin_model.py index 7956be5c2c7546ac626b3495760bcf7b0cd7e999..1cf924def4a29307e88ee41cc8bbba007c289e95 100644 --- a/extreme_estimator/extreme_models/margin_model/linear_margin_model.py +++ b/extreme_estimator/extreme_models/margin_model/linear_margin_model.py @@ -118,26 +118,17 @@ class LinearAllParametersAllDimsMarginModel(LinearMarginModel): GevParams.SCALE: self.coordinates.coordinates_dims}) -class LinearAllParametersTwoFirstCoordinatesMarginModel(LinearMarginModel): +class LinearStationaryMarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): - super().load_margin_functions({GevParams.SHAPE: [0, 1], - GevParams.LOC: [0, 1], - GevParams.SCALE: [0, 1]}) + super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims, + GevParams.LOC: self.coordinates.spatial_coordinates_dims, + GevParams.SCALE: self.coordinates.spatial_coordinates_dims}) -class LinearAllTwoStatialCoordinatesLocationLinearMarginModel(LinearMarginModel): +class LinearNonStationaryLocationMarginModel(LinearMarginModel): def load_margin_functions(self, margin_function_class: type = None, gev_param_name_to_dims=None): - super().load_margin_functions({GevParams.SHAPE: [0, 1], - GevParams.LOC: [0, 1, 2], - GevParams.SCALE: [0, 1]}) - - -# Some renaming that defines the stationary and non-stationary models of reference -class LinearStationaryMarginModel(LinearAllParametersTwoFirstCoordinatesMarginModel): - pass - - -class LinearNonStationaryLocationMarginModel(LinearAllTwoStatialCoordinatesLocationLinearMarginModel): - pass + super().load_margin_functions({GevParams.SHAPE: self.coordinates.spatial_coordinates_dims, + GevParams.LOC: self.coordinates.coordinates_dims, + GevParams.SCALE: self.coordinates.spatial_coordinates_dims}) diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py index ada3da81f6f2f3112dda56942902b07c2d1d90eb..7d9c2a7af6d1cd7ed4c0c2992664cae008054bdb 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py @@ -103,7 +103,7 @@ class AbstractMarginFunction(object): def visualize_single_param(self, gev_value_name=GevParams.LOC, ax=None, show=True): assert gev_value_name in GevParams.SUMMARY_NAMES - nb_coordinates_spatial = self.coordinates.nb_coordinates_spatial + nb_coordinates_spatial = self.coordinates.nb_spatial_coordinates has_temporal_coordinates = self.coordinates.has_temporal_coordinates if nb_coordinates_spatial == 1 and not has_temporal_coordinates: self.visualize_1D(gev_value_name, ax, show) diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py index ad975235f95926f6b46f6221a53f4e17beded1e5..ac760eb6ec0685a3e099b03276ebe1a53e3b7ec5 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/linear_margin_function.py @@ -65,14 +65,14 @@ class LinearMarginFunction(ParametricMarginFunction): for gev_param_name in GevParams.PARAM_NAMES: linear_dims = self.gev_param_name_to_dims.get(gev_param_name, []) # Load spatial form_dict (only if we have some spatial coordinates) - if self.coordinates.coordinates_spatial_names: - spatial_names = [name for name in self.coordinates.coordinates_spatial_names + if self.coordinates.has_spatial_coordinates: + spatial_names = [name for name in self.coordinates.spatial_coordinates_names if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims] spatial_form = self.gev_param_name_to_coef[gev_param_name].spatial_form_dict(spatial_names) form_dict.update(spatial_form) # Load temporal form dict (only if we have some temporal coordinates) - if self.coordinates.coordinates_temporal_names: - temporal_names = [name for name in self.coordinates.coordinates_temporal_names + if self.coordinates.has_temporal_coordinates: + temporal_names = [name for name in self.coordinates.temporal_coordinates_names if self.coefficient_name_to_dim(self.coordinates)[name] in linear_dims] temporal_form = self.gev_param_name_to_coef[gev_param_name].temporal_form_dict(temporal_names) # Specifying a formula '~ 1' creates a bug in fitspatgev of SpatialExtreme R package diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py index 1f1dd4f67267d483bf2edd9cd32f29530958699d..a4135fadfec26fc6a7631bda3baf52be6bfb2703 100644 --- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py +++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py @@ -167,7 +167,7 @@ class AbstractCoordinates(object): @property def coordinates_names(self) -> List[str]: - return self.coordinates_spatial_names + self.coordinates_temporal_names + return self.spatial_coordinates_names + self.temporal_coordinates_names @property def nb_coordinates(self) -> int: @@ -180,22 +180,26 @@ class AbstractCoordinates(object): # Spatial attributes @property - def coordinates_spatial_names(self) -> List[str]: + def spatial_coordinates_dims(self): + return list(range(self.nb_spatial_coordinates)) + + @property + def spatial_coordinates_names(self) -> List[str]: return [name for name in self.COORDINATE_SPATIAL_NAMES if name in self.df_all_coordinates.columns] @property - def nb_coordinates_spatial(self) -> int: - return len(self.coordinates_spatial_names) + def nb_spatial_coordinates(self) -> int: + return len(self.spatial_coordinates_names) @property def has_spatial_coordinates(self) -> bool: - return self.nb_coordinates_spatial > 0 + return self.nb_spatial_coordinates > 0 def df_spatial_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame: - if self.nb_coordinates_spatial == 0: + if self.nb_spatial_coordinates == 0: return pd.DataFrame() else: - return self.df_coordinates(split, transformed).loc[:, self.coordinates_spatial_names].drop_duplicates() + return self.df_coordinates(split, transformed).loc[:, self.spatial_coordinates_names].drop_duplicates() @property def nb_stations(self, split: Split = Split.all) -> int: @@ -212,22 +216,28 @@ class AbstractCoordinates(object): # Temporal attributes @property - def coordinates_temporal_names(self) -> List[str]: + def temporal_dims(self): + start = self.nb_spatial_coordinates + end = start + self.nb_temporal_coordinates + return list(range(start, end)) + + @property + def temporal_coordinates_names(self) -> List[str]: return [self.COORDINATE_T] if self.COORDINATE_T in self.df_all_coordinates else [] @property - def nb_coordinates_temporal(self) -> int: - return len(self.coordinates_temporal_names) + def nb_temporal_coordinates(self) -> int: + return len(self.temporal_coordinates_names) @property def has_temporal_coordinates(self) -> bool: - return self.nb_coordinates_temporal > 0 + return self.nb_temporal_coordinates > 0 def df_temporal_coordinates(self, split: Split = Split.all, transformed=True) -> pd.DataFrame: - if self.nb_coordinates_temporal == 0: + if self.nb_temporal_coordinates == 0: return pd.DataFrame() else: - return self.df_coordinates(split, transformed=transformed).loc[:, self.coordinates_temporal_names] \ + return self.df_coordinates(split, transformed=transformed).loc[:, self.temporal_coordinates_names] \ .drop_duplicates() def df_temporal_coordinates_for_fit(self, split=Split.all, starting_point=None) -> pd.DataFrame: @@ -297,27 +307,27 @@ class AbstractCoordinates(object): return self.df_all_coordinates[self.COORDINATE_T].values.copy() def visualize(self): - if self.nb_coordinates_spatial == 1: + if self.nb_spatial_coordinates == 1: self.visualization_1D() - elif self.nb_coordinates_spatial == 2: + elif self.nb_spatial_coordinates == 2: self.visualization_2D() else: self.visualization_3D() def visualization_1D(self): - assert self.nb_coordinates_spatial >= 1 + assert self.nb_spatial_coordinates >= 1 x = self.x_coordinates y = np.zeros(len(x)) plt.scatter(x, y) plt.show() def visualization_2D(self): - assert self.nb_coordinates_spatial >= 2 + assert self.nb_spatial_coordinates >= 2 plt.scatter(self.x_coordinates, self.y_coordinates) plt.show() def visualization_3D(self): - assert self.nb_coordinates_spatial == 3 + assert self.nb_spatial_coordinates == 3 fig = plt.figure() ax = fig.add_subplot(111, projection='3d') # type: Axes3D ax.scatter(self.x_coordinates, self.y_coordinates, self.z_coordinates, marker='^') diff --git a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py index 6c7fa02574bfe9ca7f1d70915b6e3106e2e12a0c..8ba17afc9435306d014bb0a006f6fb314750a6dc 100644 --- a/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py +++ b/test/test_extreme_estimator/test_estimator/test_max_stable_estimators.py @@ -26,7 +26,7 @@ class TestMaxStableEstimators(unittest.TestCase): def fit_max_stable_estimator_for_all_coordinates(self): for coordinates in self.coordinates: for max_stable_model in self.max_stable_models: - use_rmaxstab_with_2_coordinates = coordinates.nb_coordinates_spatial > 2 + use_rmaxstab_with_2_coordinates = coordinates.nb_spatial_coordinates > 2 dataset = MaxStableDataset.from_sampling(nb_obs=self.nb_obs, max_stable_model=max_stable_model, coordinates=coordinates,