From 990754ac6b01519b52873203fa496e459f4c1097 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Mon, 13 May 2019 11:47:38 +0200 Subject: [PATCH] [COORDINATES] add test_coordinate_sensitivity --- .../margin_model/parametric_margin_model.py | 1 - .../coordinates/abstract_coordinates.py | 13 +++++++--- ..._bug.py => test_coordinate_sensitivity.py} | 26 +++++++++---------- 3 files changed, 21 insertions(+), 19 deletions(-) rename test/test_experiment/{test_weird_bug.py => test_coordinate_sensitivity.py} (58%) diff --git a/extreme_estimator/extreme_models/margin_model/parametric_margin_model.py b/extreme_estimator/extreme_models/margin_model/parametric_margin_model.py index 7acacf83..feb0f46b 100644 --- a/extreme_estimator/extreme_models/margin_model/parametric_margin_model.py +++ b/extreme_estimator/extreme_models/margin_model/parametric_margin_model.py @@ -39,7 +39,6 @@ class ParametricMarginModel(AbstractMarginModel, ABC): # Enforce a starting point for the temporal trend if self.transformed_starting_point is not None: # Compute the indices to modify - print('transformed starting point', self.transformed_starting_point) ind_to_modify = df_coordinates_temp.iloc[:, 0] <= self.transformed_starting_point # type: pd.Series # Assert that some coordinates are selected but not all (at least 20 data should be left for temporal trend) assert 0 < sum(ind_to_modify) < len(ind_to_modify) - 20 diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py index 84afada1..a51fe218 100644 --- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py +++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py @@ -36,6 +36,7 @@ class AbstractCoordinates(object): # Coordinates columns COORDINATES_NAMES = COORDINATE_SPATIAL_NAMES + [COORDINATE_T] # Coordinate type + ALL_COORDINATES_ACCEPTED_TYPES = ['int64', 'float64'] COORDINATE_TYPE = 'float64' def __init__(self, df: pd.DataFrame, slicer_class: type, s_split_spatial: pd.Series = None, @@ -46,8 +47,10 @@ class AbstractCoordinates(object): # Sort coordinates according to a specified order sorted_coordinates_columns = [c for c in self.COORDINATES_NAMES if c in coordinate_columns] self.df_all_coordinates = df.loc[:, sorted_coordinates_columns].copy() # type: pd.DataFrame - # Cast df_all_coordinates to the desired type - self.df_all_coordinates = self.df_all_coordinates.astype(self.COORDINATE_TYPE) + # Check the data type of the coordinate columns + accepted_dtypes = ['int'] + assert len(self.df_all_coordinates.select_dtypes(include=self.ALL_COORDINATES_ACCEPTED_TYPES).columns) \ + == len(coordinate_columns), 'coordinates columns dtypes should belong to {}'.format(accepted_dtypes) # Slicing attributes self.s_split_spatial = s_split_spatial # type: pd.Series @@ -128,12 +131,14 @@ class AbstractCoordinates(object): # Normalize def transform(self, coordinate: np.ndarray) -> np.ndarray: - return self.transformation.transform_array(coordinate=coordinate) + coordinate_float = coordinate.astype(self.COORDINATE_TYPE) + return self.transformation.transform_array(coordinate=coordinate_float) # Split def df_coordinates(self, split: Split = Split.all) -> pd.DataFrame: - df_transformed_coordinates = self.transformation.transform_df(df_coord=self.df_all_coordinates) + df_all_coordinate_as_float = self.df_all_coordinates.astype(self.COORDINATE_TYPE) # type: pd.DataFrame + df_transformed_coordinates = self.transformation.transform_df(df_all_coordinate_as_float) return df_sliced(df=df_transformed_coordinates, split=split, slicer=self.slicer) def coordinates_values(self, split: Split = Split.all) -> np.ndarray: diff --git a/test/test_experiment/test_weird_bug.py b/test/test_experiment/test_coordinate_sensitivity.py similarity index 58% rename from test/test_experiment/test_weird_bug.py rename to test/test_experiment/test_coordinate_sensitivity.py index 77ad1869..8fcc5363 100644 --- a/test/test_experiment/test_weird_bug.py +++ b/test/test_experiment/test_coordinate_sensitivity.py @@ -12,28 +12,26 @@ from utils import get_display_name_from_object_type class TestCoordinateSensitivity(unittest.TestCase): + DISPLAY = False - def test_weird(self): - # todo: maybe the code does not like negative coordinates - # todo: maybe not that the sign of the x coordinate are all negative and the other are all positive, it is easier to find the perfect spatial structure + def test_coordinate_normalization_sensitivity(self): altitudes = [3000] transformation_classes = [BetweenZeroAndOneNormalization, BetweenMinusOneAndOneNormalization][:] for transformation_class in transformation_classes: study_classes = [CrocusSwe] for study in study_iterator_global(study_classes, altitudes=altitudes, verbose=False): - print('\n\n') study_visualizer = StudyVisualizer(study, transformation_class=transformation_class) study_visualizer.temporal_non_stationarity = True - print(study_visualizer.coordinates) - # trend_test = ConditionalIndedendenceLocationTrendTest(study_visualizer.dataset) - # # years = [1960, 1990] - # # mu1s = [trend_test.get_mu1(year) for year in years] - # # print('Stationary') - # # print(trend_test.get_estimator(trend_test.stationary_margin_model_class, starting_point=None).margin_function_fitted.coef_dict) - # print('Non Stationary') - # print(trend_test.get_estimator(trend_test.non_stationary_margin_model_class, starting_point=1960).margin_function_fitted.coef_dict) - # # print(get_display_name_from_object_type(type(transformation_2D)), 'mu1s: ', mu1s) - # # self.assertTrue(0.0 not in mu1s) + trend_test = ConditionalIndedendenceLocationTrendTest(study_visualizer.dataset) + years = [1960, 1990] + mu1s = [trend_test.get_mu1(year) for year in years] + if self.DISPLAY: + print('Stationary') + print(trend_test.get_estimator(trend_test.stationary_margin_model_class, starting_point=None).margin_function_fitted.coef_dict) + print('Non Stationary') + print(trend_test.get_estimator(trend_test.non_stationary_margin_model_class, starting_point=1960).margin_function_fitted.coef_dict) + print(get_display_name_from_object_type(transformation_class), 'mu1s: ', mu1s) + self.assertTrue(0.0 not in mu1s) if __name__ == '__main__': -- GitLab