diff --git a/experiment/simulation/abstract_simulation.py b/experiment/simulation/abstract_simulation.py index 92d16edcd2649d8631a0a1e25463474558903307..4319f7842c346a6db8145aa9e4cabd449e6e9637 100644 --- a/experiment/simulation/abstract_simulation.py +++ b/experiment/simulation/abstract_simulation.py @@ -120,9 +120,10 @@ class AbstractSimulation(object): data = self.mean_error_dict[gev_value_name].values data_min, data_max = data.min(), data.max() - nb_bins = 10 + nb_bins = 1 limits = np.linspace(data_min, data_max, num=nb_bins + 1) limits[-1] += 0.01 + limits[0] -= 0.01 # Binary color should colors = cm.binary((limits - data_min / (data_max - data_min))) @@ -134,8 +135,9 @@ class AbstractSimulation(object): self.full_dataset.coordinates.coordinates_index(split)].values data_filter = np.logical_and(left_limit <= data_ind, data_ind < right_limit) + # todo: fix binary color problem self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker, - filter=data_filter, color=color) + filter=data_filter) self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False) # Display the individual fitted curve diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py index bc593b7ef976f81cf1548693620fdbaec29ff0a5..a2475ef6780d449b147c53d5091fde7f1a91b453 100644 --- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py +++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py @@ -95,15 +95,17 @@ class AbstractMarginFunction(object): plt.show() def grid_1D(self, x): - if self._grid_1D is None: - self._grid_1D = self.get_grid_values_1D(x) - return self._grid_1D + # if self._grid_1D is None: + # self._grid_1D = self.get_grid_values_1D(x) + # return self._grid_1D + return self.get_grid_values_1D(x, self.spatio_temporal_split) - def get_grid_values_1D(self, x): + def get_grid_values_1D(self, x, spatio_temporal_split): # TODO: to avoid getting the value several times, I could cache the results if self.datapoint_display: # todo: keep only the index of interest here - linspace = self.coordinates.coordinates_values(self.spatio_temporal_split)[:, 0] + linspace = self.coordinates.coordinates_values(spatio_temporal_split)[:, 0] + print(self.spatio_temporal_split, linspace) if self.filter is not None: linspace = linspace[self.filter] resolution = len(linspace) @@ -142,9 +144,9 @@ class AbstractMarginFunction(object): plt.show() def grid_2D(self, x, y): - if self._grid_2D is None: - self._grid_2D = self.get_grid_2D(x, y) - return self._grid_2D + # if self._grid_2D is None: + # self._grid_2D = self.get_grid_2D(x, y) + return self.get_grid_2D(x, y) def get_grid_2D(self, x, y): resolution = 100