From 54462cff0e20727c449b27cf550df6540f67d392 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Fri, 22 Mar 2019 11:40:13 +0100
Subject: [PATCH] [SCM] add smooth spatio_temporal margin fit

---
 .../meteo_france_SCM_study/abstract_study.py  |  2 +-
 .../meteo_france_SCM_study/safran/safran.py   |  2 +-
 .../study_visualization/study_visualizer.py   | 29 ++++++++-----------
 .../abstract_margin_estimator.py              |  3 +-
 .../abstract_margin_function.py               | 24 ++++++++-------
 .../coordinates/abstract_coordinates.py       |  4 +--
 6 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/experiment/meteo_france_SCM_study/abstract_study.py b/experiment/meteo_france_SCM_study/abstract_study.py
index 89a3b59e..b0beb53d 100644
--- a/experiment/meteo_france_SCM_study/abstract_study.py
+++ b/experiment/meteo_france_SCM_study/abstract_study.py
@@ -72,7 +72,7 @@ class AbstractStudy(object):
         return year_to_dataset
 
     @property
-    def start_year_and_end_year(self) -> Tuple[int, int]:
+    def start_year_and_stop_year(self) -> Tuple[int, int]:
         ordered_years = list(self.year_to_dataset_ordered_dict.keys())
         return ordered_years[0], ordered_years[-1]
 
diff --git a/experiment/meteo_france_SCM_study/safran/safran.py b/experiment/meteo_france_SCM_study/safran/safran.py
index 5b8e19bc..98147a32 100644
--- a/experiment/meteo_france_SCM_study/safran/safran.py
+++ b/experiment/meteo_france_SCM_study/safran/safran.py
@@ -28,7 +28,7 @@ class SafranFrequency(Safran):
 
     @property
     def variable_name(self):
-        return super().variable_name + ' cumulated over {} days'.format(self.nb_consecutive_days)
+        return super().variable_name + ' cumulated over {} day(s)'.format(self.nb_consecutive_days)
 
     def annual_aggregation_function(self, *args, **kwargs):
         return np.sum(*args, **kwargs)
diff --git a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
index d767bc56..67f11692 100644
--- a/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/study_visualization/study_visualizer.py
@@ -14,12 +14,11 @@ from experiment.utils import average_smoothing_with_sliding_window
 from extreme_estimator.estimator.full_estimator.abstract_full_estimator import \
     FullEstimatorInASingleStepWithSmoothMargin
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import SmoothMarginEstimator
+from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearAllParametersAllDimsMarginModel
 from extreme_estimator.extreme_models.margin_model.margin_function.abstract_margin_function import \
     AbstractMarginFunction
 from extreme_estimator.extreme_models.margin_model.param_function.param_function import AbstractParamFunction
-from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearAllParametersAllDimsMarginModel
 from extreme_estimator.extreme_models.max_stable_model.abstract_max_stable_model import CovarianceFunction
-from extreme_estimator.extreme_models.max_stable_model.max_stable_models import BrownResnick
 from extreme_estimator.margin_fits.abstract_params import AbstractParams
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
@@ -69,8 +68,11 @@ class StudyVisualizer(object):
         self.subplot_space = 0.5
         self.coef_zoom_map = 1
 
+        # Modify some class attributes
         # Remove some assert
         AbstractParamFunction.OUT_OF_BOUNDS_ASSERT = False
+        # INCREASE THE TEMPORAL STEPS FOR VISUALIZATION
+        AbstractMarginFunction.VISUALIZATION_TEMPORAL_STEPS = 5
 
     @property
     def dataset(self):
@@ -85,8 +87,8 @@ class StudyVisualizer(object):
             if self.temporal_non_stationarity:
                 # Build spatio temporal dataset from a temporal dataset
                 df_spatial = coordinates.df_spatial_coordinates()
-                start, end = self.study.start_year_and_end_year
-                nb_steps = end - start + 1
+                start, stop = self.study.start_year_and_stop_year
+                nb_steps = stop - start + 1
                 coordinates = AbstractSpatioTemporalCoordinates.from_df_spatial_and_nb_steps(df_spatial=df_spatial,
                                                                                              nb_steps=nb_steps,
                                                                                              start=start)
@@ -246,6 +248,7 @@ class StudyVisualizer(object):
 
     def visualize_linear_margin_fit(self, only_first_max_stable=False):
         default_covariance_function = CovarianceFunction.powexp
+        margin_class = LinearAllParametersAllDimsMarginModel
         plot_name = 'Full Likelihood with Linear marginals and max stable dependency structure'
         plot_name += '\n(with {} covariance structure when a covariance is needed)'.format(
             str(default_covariance_function).split('.')[-1])
@@ -277,7 +280,6 @@ class StudyVisualizer(object):
             # Plot the margin fit independently on the additional row
             self.visualize_independent_margin_fits(threshold=None, axes=axes[-1], show=False)
 
-        margin_class = LinearAllParametersAllDimsMarginModel
         # Plot the smooth margin only
         margin_model = margin_class(coordinates=self.coordinates)
         estimator = SmoothMarginEstimator(dataset=self.dataset, margin_model=margin_model)
@@ -300,30 +302,23 @@ class StudyVisualizer(object):
         margin_fct._visualization_x_limits = self.study.visualization_x_limits
         margin_fct._visualization_y_limits = self.study.visualization_y_limits
         margin_fct.mask_2D = self.study.mask_french_alps
+        if self.temporal_non_stationarity:
+            margin_fct.add_future_temporal_steps = True
 
-        axes = margin_fct.visualize_function(show=False, axes=axes, title='') # type: np.ndarray
+        axes = margin_fct.visualize_function(show=False, axes=axes, title='')  # type: np.ndarray
 
         if axes.ndim == 1:
             self.visualize_contour_and_move_axes_limits(axes)
             self.clean_axes_write_title_on_the_left(axes, title)
         else:
             axes = np.transpose(axes)
-            for axes_line in axes:
+            for temporal_step, axes_line in zip(margin_fct.temporal_steps, axes):
                 self.visualize_contour_and_move_axes_limits(axes_line)
-                self.clean_axes_write_title_on_the_left(axes_line, title, left_border=False)
+                self.clean_axes_write_title_on_the_left(axes_line, str(temporal_step) + title, left_border=False)
 
     def visualize_contour_and_move_axes_limits(self, axes):
-        def get_lim_array(ax_with_lim_to_measure):
-            return np.array([np.array(ax_with_lim_to_measure.get_xlim()), np.array(ax_with_lim_to_measure.get_ylim())])
-
         for ax in axes:
-            # old_lim = get_lim_array(ax)
             self.study.visualize_study(ax, fill=False, show=False)
-            # new_lim = get_lim_array(ax)
-            # assert 0 <= self.coef_zoom_map <= 1
-            # updated_lim = new_lim * self.coef_zoom_map + (1 - self.coef_zoom_map) * old_lim
-            # for i, method in enumerate([ax.set_xlim, ax.set_ylim]):
-            #     method(updated_lim[i, 0], updated_lim[i, 1])
 
     @staticmethod
     def clean_axes_write_title_on_the_left(axes, title, left_border=True):
diff --git a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
index 8ffb8bf2..ab686e12 100644
--- a/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
+++ b/extreme_estimator/estimator/margin_estimator/abstract_margin_estimator.py
@@ -18,6 +18,7 @@ class AbstractMarginEstimator(AbstractEstimator, ABC):
     @property
     def margin_function_fitted(self) -> AbstractMarginFunction:
         assert self._margin_function_fitted is not None, 'Error: estimator has not been fitted'
+        assert isinstance(self._margin_function_fitted, AbstractMarginFunction)
         return self._margin_function_fitted
 
 
@@ -44,4 +45,4 @@ class SmoothMarginEstimator(AbstractMarginEstimator):
                                                                             df_coordinates_spat=df_coordinates_spat,
                                                                             df_coordinates_temp=df_coordinates_temp)
         self.extract_fitted_models_from_fitted_params(self.margin_model.margin_function_start_fit, self.fitted_values)
-        assert isinstance(self.margin_function_fitted, AbstractMarginFunction)
+
diff --git a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
index 2428f5fd..a2b5b3d0 100644
--- a/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
+++ b/extreme_estimator/extreme_models/margin_model/margin_function/abstract_margin_function.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List
 
 import matplotlib.cm as cm
 import matplotlib.pyplot as plt
@@ -37,6 +37,7 @@ class AbstractMarginFunction(object):
         self.temporal_step_to_grid_2D = {}
         self._grid_1D = None
         self.title = None
+        self.add_future_temporal_steps = False
 
         # Visualization limits
         self._visualization_x_limits = None
@@ -220,22 +221,25 @@ class AbstractMarginFunction(object):
 
     # Visualization 3D
 
-    def visualize_2D_spatial_1D_temporal(self, gev_param_name=GevParams.LOC, axes=None, show=True,
-                                         add_future_temporal_steps=False):
+    def visualize_2D_spatial_1D_temporal(self, gev_param_name=GevParams.LOC, axes=None, show=True):
         if axes is None:
             axes = create_adjusted_axes(self.VISUALIZATION_TEMPORAL_STEPS, 1)
         assert len(axes) == self.VISUALIZATION_TEMPORAL_STEPS
 
         # Build temporal_steps a list of time steps
-        future_temporal_steps = [10, 100] if add_future_temporal_steps else []
-        nb_past_temporal_step = self.VISUALIZATION_TEMPORAL_STEPS - len(future_temporal_steps)
-        start, stop = self.coordinates.df_temporal_range()
-        temporal_steps = list(np.linspace(start, stop, num=nb_past_temporal_step)) + future_temporal_steps
-        assert len(temporal_steps) == self.VISUALIZATION_TEMPORAL_STEPS
-
-        for ax, temporal_step in zip(axes, temporal_steps):
+        assert len(self.temporal_steps) == self.VISUALIZATION_TEMPORAL_STEPS
+        for ax, temporal_step in zip(axes, self.temporal_steps):
             self.visualize_2D(gev_param_name, ax, show=False, temporal_step=temporal_step)
             self.set_title(ax, gev_param_name)
 
         if show:
             plt.show()
+
+    @cached_property
+    def temporal_steps(self) -> List[int]:
+        future_temporal_steps = [10, 100] if self.add_future_temporal_steps else []
+        nb_past_temporal_step = self.VISUALIZATION_TEMPORAL_STEPS - len(future_temporal_steps)
+        start, stop = self.coordinates.df_temporal_range()
+        temporal_steps = [int(step) for step in np.linspace(start, stop, num=nb_past_temporal_step)]
+        temporal_steps += [stop + step for step in future_temporal_steps]
+        return temporal_steps
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index 3d259a95..5b1a0ea1 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -202,9 +202,9 @@ class AbstractCoordinates(object):
         else:
             return self.df_coordinates(split).loc[:, self.coordinates_temporal_names].drop_duplicates()
 
-    def df_temporal_range(self, split: Split = Split.all) -> Tuple[float, float]:
+    def df_temporal_range(self, split: Split = Split.all) -> Tuple[int, int]:
         df_temporal_coordinates = self.df_temporal_coordinates(split)
-        return float(df_temporal_coordinates.min()), float(df_temporal_coordinates.max()),
+        return int(df_temporal_coordinates.min()), int(df_temporal_coordinates.max()),
 
     #  Visualization
 
-- 
GitLab