diff --git a/experiment/simulation/abstract_simulation.py b/experiment/simulation/abstract_simulation.py
index 92d16edcd2649d8631a0a1e25463474558903307..4319f7842c346a6db8145aa9e4cabd449e6e9637 100644
--- a/experiment/simulation/abstract_simulation.py
+++ b/experiment/simulation/abstract_simulation.py
@@ -120,9 +120,10 @@ class AbstractSimulation(object):
 
         data = self.mean_error_dict[gev_value_name].values
         data_min, data_max = data.min(), data.max()
-        nb_bins = 10
+        nb_bins = 1
         limits = np.linspace(data_min, data_max, num=nb_bins + 1)
         limits[-1] += 0.01
+        limits[0] -= 0.01
         # Binary color should
         colors = cm.binary((limits - data_min / (data_max - data_min)))
 
@@ -134,8 +135,9 @@ class AbstractSimulation(object):
                     self.full_dataset.coordinates.coordinates_index(split)].values
                 data_filter = np.logical_and(left_limit <= data_ind, data_ind < right_limit)
 
+                # todo: fix binary color problem
                 self.margin_function_sample.set_datapoint_display_parameters(split, datapoint_marker=marker,
-                                                                             filter=data_filter, color=color)
+                                                                             filter=data_filter)
                 self.margin_function_sample.visualize_single_param(gev_value_name, ax, show=False)
 
         # Display the individual fitted curve
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index bc593b7ef976f81cf1548693620fdbaec29ff0a5..a2475ef6780d449b147c53d5091fde7f1a91b453 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -95,15 +95,17 @@ class AbstractMarginFunction(object):
             plt.show()
 
     def grid_1D(self, x):
-        if self._grid_1D is None:
-            self._grid_1D = self.get_grid_values_1D(x)
-        return self._grid_1D
+        # if self._grid_1D is None:
+        #     self._grid_1D = self.get_grid_values_1D(x)
+        # return self._grid_1D
+        return self.get_grid_values_1D(x, self.spatio_temporal_split)
 
-    def get_grid_values_1D(self, x):
+    def get_grid_values_1D(self, x, spatio_temporal_split):
         # TODO: to avoid getting the value several times, I could cache the results
         if self.datapoint_display:
             # todo: keep only the index of interest here
-            linspace = self.coordinates.coordinates_values(self.spatio_temporal_split)[:, 0]
+            linspace = self.coordinates.coordinates_values(spatio_temporal_split)[:, 0]
+            print(self.spatio_temporal_split, linspace)
             if self.filter is not None:
                 linspace = linspace[self.filter]
             resolution = len(linspace)
@@ -142,9 +144,9 @@ class AbstractMarginFunction(object):
             plt.show()
 
     def grid_2D(self, x, y):
-        if self._grid_2D is None:
-            self._grid_2D = self.get_grid_2D(x, y)
-        return self._grid_2D
+        # if self._grid_2D is None:
+        #     self._grid_2D = self.get_grid_2D(x, y)
+        return self.get_grid_2D(x, y)
 
     def get_grid_2D(self, x, y):
         resolution = 100