From 9b5b77e77d52aa808e4d12c0ffb9852e2b8d44ba Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Mon, 20 May 2019 16:55:31 +0200
Subject: [PATCH] [SCM][TREND TEST] add try catch for gev trend test.

---
 .../main_studies_visualizer.py                |  4 +-
 .../studies_visualizer.py                     |  9 +++--
 .../abstract_gev_trend_test.py                | 38 ++++++++++++++-----
 .../abstract_trend_test.py                    | 22 +++++------
 extreme_estimator/extreme_models/utils.py     |  5 ++-
 5 files changed, 51 insertions(+), 27 deletions(-)

diff --git a/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py b/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py
index ea254cd6..966b87a5 100644
--- a/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/studies_visualization/main_studies_visualizer.py
@@ -44,9 +44,9 @@ def altitude_trends_significant():
     # altitudes that have 20 massifs at least
     altitudes = ALL_ALTITUDES[3:-6]
     # altitudes = ALL_ALTITUDES[3:5]
-    # altitudes = ALL_ALTITUDES[2:4]
+    altitudes = ALL_ALTITUDES[2:4]
     for study_class in SCM_STUDIES[:1]:
-        trend_test_classes = [MannKendallTrendTest, GevLocationTrendTest, GevScaleTrendTest, GevShapeTrendTest][2:]
+        trend_test_classes = [MannKendallTrendTest, GevLocationTrendTest, GevScaleTrendTest, GevShapeTrendTest][3:]
         visualizers = [StudyVisualizer(study, temporal_non_stationarity=True, verbose=False)
                        for study in study_iterator_global(study_classes=[study_class], only_first_one=only_first_one,
                                                           altitudes=altitudes)]
diff --git a/experiment/meteo_france_SCM_study/visualization/studies_visualization/studies_visualizer.py b/experiment/meteo_france_SCM_study/visualization/studies_visualization/studies_visualizer.py
index 8ca0156c..810f64d6 100644
--- a/experiment/meteo_france_SCM_study/visualization/studies_visualization/studies_visualizer.py
+++ b/experiment/meteo_france_SCM_study/visualization/studies_visualization/studies_visualizer.py
@@ -176,7 +176,7 @@ class AltitudeVisualizer(object):
         fig, ax = plt.subplots(1, 1, figsize=self.any_study_visualizer.figsize)
 
         # Create one display for each trend test class
-        markers = ['o', '+']
+        markers = ['o', '*', 's', 'D']
         assert len(markers) >= len(trend_test_classes)
         # Add a second legend for the color and to explain the line
 
@@ -185,7 +185,8 @@ class AltitudeVisualizer(object):
 
         # Add the color legend
         handles, labels = ax.get_legend_handles_labels()
-        handles_ax, labels_ax = handles[:5], labels[:5]
+        nb_trend_types = len(AbstractTrendTest.trend_type_to_style())
+        handles_ax, labels_ax = handles[:nb_trend_types], labels[:nb_trend_types]
         ax.legend(handles_ax, labels_ax, markerscale=0.0, loc=1)
         ax.set_xticks(self.altitudes)
         ax.set_yticks(list(range(0, 101, 10)))
@@ -193,7 +194,7 @@ class AltitudeVisualizer(object):
 
         # Add the marker legend
         names = [get_display_name_from_object_type(c) for c in trend_test_classes]
-        handles_ax2, labels_ax2 = handles[::5], names
+        handles_ax2, labels_ax2 = handles[::nb_trend_types], names
         ax2 = ax.twinx()
         ax2.legend(handles_ax2, labels_ax2, loc=2)
         ax2.set_yticks([])
@@ -224,7 +225,7 @@ class AltitudeVisualizer(object):
             s = study_visualizer.serie_mean_trend_test_count(trend_test_class, starting_year_to_weights)
             altitude_to_serie_with_mean_percentages[altitude] = s
         # Plot lines
-        for trend_type, style in AbstractTrendTest.TREND_TYPE_TO_STYLE.items():
+        for trend_type, style in AbstractTrendTest.trend_type_to_style().items():
             percentages = [v.loc[trend_type] if trend_type in v.index else 0.0
                            for v in altitude_to_serie_with_mean_percentages.values()]
             if set(percentages) == {0.0}:
diff --git a/experiment/trend_analysis/univariate_trend_test/abstract_gev_trend_test.py b/experiment/trend_analysis/univariate_trend_test/abstract_gev_trend_test.py
index b5b3fd3e..97caaa58 100644
--- a/experiment/trend_analysis/univariate_trend_test/abstract_gev_trend_test.py
+++ b/experiment/trend_analysis/univariate_trend_test/abstract_gev_trend_test.py
@@ -6,6 +6,7 @@ from experiment.trend_analysis.univariate_trend_test.abstract_trend_test import
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
 from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \
     NonStationaryLocationStationModel, NonStationaryScaleStationModel, NonStationaryShapeStationModel
+from extreme_estimator.extreme_models.utils import SafeRunException
 from extreme_estimator.margin_fits.gev.gev_params import GevParams
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from spatio_temporal_dataset.coordinates.temporal_coordinates.abstract_temporal_coordinates import \
@@ -20,24 +21,28 @@ from spatio_temporal_dataset.spatio_temporal_observations.abstract_spatio_tempor
 
 
 class AbstractGevTrendTest(AbstractTrendTest):
+    RRunTimeError_TREND = 'R RunTimeError trend'
 
     def __init__(self, years_after_change_point, maxima_after_change_point, non_stationary_model_class):
         super().__init__(years_after_change_point, maxima_after_change_point)
         df = pd.DataFrame({AbstractCoordinates.COORDINATE_T: years_after_change_point})
         df_maxima_gev = pd.DataFrame(maxima_after_change_point, index=df.index)
         observations = AbstractSpatioTemporalObservations(df_maxima_gev=df_maxima_gev)
-        self.coordinates = AbstractTemporalCoordinates.from_df(df, transformation_class=BetweenZeroAndOneNormalization)
-        # self.coordinates = AbstractTemporalCoordinates.from_df(df, transformation_class=CenteredScaledNormalization)
+        self.coordinates = AbstractTemporalCoordinates.from_df(df, transformation_class=CenteredScaledNormalization)
         self.dataset = AbstractDataset(observations=observations, coordinates=self.coordinates)
 
-        # Fit stationary model
-        self.stationary_estimator = LinearMarginEstimator(self.dataset, StationaryStationModel(self.coordinates))
-        self.stationary_estimator.fit()
+        try:
+            # Fit stationary model
+            self.stationary_estimator = LinearMarginEstimator(self.dataset, StationaryStationModel(self.coordinates))
+            self.stationary_estimator.fit()
 
-        # Fit non stationary model
-        self.non_stationary_estimator = LinearMarginEstimator(self.dataset,
-                                                              non_stationary_model_class(self.coordinates))
-        self.non_stationary_estimator.fit()
+            # Fit non stationary model
+            self.non_stationary_estimator = LinearMarginEstimator(self.dataset,
+                                                                  non_stationary_model_class(self.coordinates))
+            self.non_stationary_estimator.fit()
+            self.crashed = False
+        except SafeRunException:
+            self.crashed = True
 
     @property
     def likelihood_ratio(self):
@@ -48,6 +53,21 @@ class AbstractGevTrendTest(AbstractTrendTest):
     def is_significant(self) -> bool:
         return self.likelihood_ratio > chi2.ppf(q=1 - self.SIGNIFICANCE_LEVEL, df=1)
 
+    # Add a trend type that correspond to run that crashed
+
+    @classmethod
+    def trend_type_to_style(cls):
+        trend_type_to_style = super().trend_type_to_style()
+        trend_type_to_style[cls.RRunTimeError_TREND] = 'b:'
+        return trend_type_to_style
+
+    @property
+    def test_trend_type(self) -> str:
+        if self.crashed:
+            return self.RRunTimeError_TREND
+        else:
+            return super().test_trend_type
+
 
 class GevLocationTrendTest(AbstractGevTrendTest):
 
diff --git a/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py b/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py
index a87598b3..c2afeb6e 100644
--- a/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py
+++ b/experiment/trend_analysis/univariate_trend_test/abstract_trend_test.py
@@ -1,4 +1,5 @@
 import random
+from collections import OrderedDict
 
 import numpy as np
 
@@ -17,16 +18,15 @@ class AbstractTrendTest(object):
 
     SIGNIFICANCE_LEVEL = 0.05
 
-    # todo: maybe create ordered dict
-    TREND_TYPE_TO_STYLE = {
-        NO_TREND: 'k--',
-        POSITIVE_TREND: 'g--',
-        SIGNIFICATIVE_POSITIVE_TREND: 'g-',
-        SIGNIFICATIVE_NEGATIVE_TREND: 'r-',
-        NEGATIVE_TREND: 'r--',
-    }
-
-    TREND_TYPES = list(TREND_TYPE_TO_STYLE.keys())
+    @classmethod
+    def trend_type_to_style(cls):
+        d = OrderedDict()
+        d[cls.POSITIVE_TREND] = 'g--'
+        d[cls.NEGATIVE_TREND] = 'r--'
+        d[cls.SIGNIFICATIVE_POSITIVE_TREND] = 'g-'
+        d[cls.SIGNIFICATIVE_NEGATIVE_TREND] = 'r-'
+        d[cls.NO_TREND] = 'k--'
+        return d
 
     def __init__(self, years_after_change_point, maxima_after_change_point):
         self.years_after_change_point = years_after_change_point
@@ -47,7 +47,7 @@ class AbstractTrendTest(object):
             trend_type = self.POSITIVE_TREND if test_sign > 0 else self.NEGATIVE_TREND
             if self.is_significant:
                 trend_type = self.SIGNIFICATIVE + ' ' + trend_type
-        assert trend_type in self.TREND_TYPE_TO_STYLE
+        assert trend_type in self.trend_type_to_style()
         return trend_type
 
     @property
diff --git a/extreme_estimator/extreme_models/utils.py b/extreme_estimator/extreme_models/utils.py
index 6b4fe7ba..7928ca73 100644
--- a/extreme_estimator/extreme_models/utils.py
+++ b/extreme_estimator/extreme_models/utils.py
@@ -59,6 +59,9 @@ class OptimizationConstants(object):
 
     USE_MAXIT = False
 
+class SafeRunException(Exception):
+    pass
+
 
 def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs_value=100, maxit=1000000,
                          **parameters) -> robjects.ListVector:
@@ -96,7 +99,7 @@ def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs
                     use_start = True
                     continue
                 elif isinstance(e, RRuntimeError):
-                    raise Exception('Some R exception have been launched at RunTime: \n {}'.format(e.__repr__()))
+                    raise SafeRunException('Some R exception have been launched at RunTime: \n {}'.format(e.__repr__()))
                 if isinstance(e, RRuntimeWarning):
                     warnings.warn(e.__repr__(), WarningWhileRunningR)
     return res
-- 
GitLab