From 958c21b3d1263a573328b15f640f317e7f53c97c Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Fri, 7 Jun 2019 15:47:48 +0200
Subject: [PATCH] [COMPARISON] add possibility to change trend_test_class

---
 .../stations_data/main_station_comparison.py  | 19 +++++++++++--------
 .../comparisons_visualization.py              |  5 +++--
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/experiment/meteo_france_data/stations_data/main_station_comparison.py b/experiment/meteo_france_data/stations_data/main_station_comparison.py
index 0405f624..04922666 100644
--- a/experiment/meteo_france_data/stations_data/main_station_comparison.py
+++ b/experiment/meteo_france_data/stations_data/main_station_comparison.py
@@ -1,8 +1,9 @@
 from experiment.meteo_france_data.scm_models_data.visualization.study_visualization.main_study_visualizer import \
     ALL_ALTITUDES_WITH_20_STATIONS_AT_LEAST
-from experiment.meteo_france_data.stations_data.comparison_analysis import ComparisonAnalysis
 from experiment.meteo_france_data.stations_data.visualization.comparisons_visualization.comparisons_visualization import \
     ComparisonsVisualization, path_backup_csv_file
+from experiment.trend_analysis.univariate_test.abstract_gev_change_point_test import GevLocationChangePointTest, \
+    GevScaleChangePointTest, GevShapeChangePointTest
 
 
 def visualize_all_stations():
@@ -11,11 +12,13 @@ def visualize_all_stations():
 
 
 def visualize_non_nan_station():
-    vizu = ComparisonsVisualization(altitudes=ALL_ALTITUDES_WITH_20_STATIONS_AT_LEAST,
-                                    keep_only_station_without_nan_values=True,
-                                    normalize_observations=False)
-    vizu.visualize_maximum(visualize_metric_only=True)
-    # vizu.visualize_gev()
+    for trend_test_class in [GevLocationChangePointTest, GevScaleChangePointTest, GevShapeChangePointTest][1:2]:
+        vizu = ComparisonsVisualization(altitudes=ALL_ALTITUDES_WITH_20_STATIONS_AT_LEAST,
+                                        keep_only_station_without_nan_values=True,
+                                        normalize_observations=False,
+                                        trend_test_class=trend_test_class)
+        vizu.visualize_maximum(visualize_metric_only=True)
+        # vizu.visualize_gev()
 
 
 def example():
@@ -55,8 +58,8 @@ if __name__ == '__main__':
     # wrong_example3()
     # visualize_fast_comparison()
     # visualize_all_stations()
-    quick_metric_analysis()
+    # quick_metric_analysis()
     # wrong_example2()
-    # visualize_non_nan_station()
+    visualize_non_nan_station()
     # example()
 
diff --git a/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py b/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py
index 5b82cec7..d300ddf1 100644
--- a/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py
+++ b/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py
@@ -42,7 +42,8 @@ MAE_COLUMN_NAME = 'mean absolute difference'
 class ComparisonsVisualization(VisualizationParameters):
 
     def __init__(self, altitudes=None, keep_only_station_without_nan_values=False, margin=150,
-                 normalize_observations=False):
+                 normalize_observations=False, trend_test_class=GevLocationChangePointTest):
+        self.trend_test_class = trend_test_class
         self.keep_only_station_without_nan_values = keep_only_station_without_nan_values
         if self.keep_only_station_without_nan_values:
             self.nb_columns = 5
@@ -235,7 +236,7 @@ class ComparisonsVisualization(VisualizationParameters):
             trend_test_res, best_idxs = compute_gev_change_point_test_results(multiprocessing=True,
                                                                               maxima=maxima,
                                                                               starting_years=starting_years,
-                                                                              trend_test_class=GevLocationChangePointTest,
+                                                                              trend_test_class=self.trend_test_class,
                                                                               years=years)
             best_idx = best_idxs[0]
             most_likely_year = years[best_idx]
-- 
GitLab