From da18a7f8910b3b0df2e0898c3bb450fe1a822b1a Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 6 Jun 2019 18:09:54 +0200
Subject: [PATCH] [COMPARISON] add some labels to the graph. add trend type
 agreement/classification graph. refactor column name

---
 .../stations_data/main_station_comparison.py |  9 +--
 .../comparisons_visualization.py             | 64 +++++++++++++------
 .../abstract_univariate_test.py              |  4 ++
 3 files changed, 55 insertions(+), 22 deletions(-)

diff --git a/experiment/meteo_france_data/stations_data/main_station_comparison.py b/experiment/meteo_france_data/stations_data/main_station_comparison.py
index e8dd8bb4..0405f624 100644
--- a/experiment/meteo_france_data/stations_data/main_station_comparison.py
+++ b/experiment/meteo_france_data/stations_data/main_station_comparison.py
@@ -2,7 +2,7 @@ from experiment.meteo_france_data.scm_models_data.visualization.study_visualizat
     ALL_ALTITUDES_WITH_20_STATIONS_AT_LEAST
 from experiment.meteo_france_data.stations_data.comparison_analysis import ComparisonAnalysis
 from experiment.meteo_france_data.stations_data.visualization.comparisons_visualization.comparisons_visualization import \
-    ComparisonsVisualization
+    ComparisonsVisualization, path_backup_csv_file
 
 
 def visualize_all_stations():
@@ -48,13 +48,14 @@ def wrong_example3():
 
 
 def quick_metric_analysis():
-    ComparisonsVisualization.visualize_metric()
+    ComparisonsVisualization.visualize_metric(csv_filepath=path_backup_csv_file)
+    # ComparisonsVisualization.visualize_metric()
 
 
 if __name__ == '__main__':
     # wrong_example3()
     # visualize_fast_comparison()
-    visualize_all_stations()
-    # quick_metric_analysis()
+    # visualize_all_stations()
+    quick_metric_analysis()
     # wrong_example2()
     # visualize_non_nan_station()
     # example()
diff --git a/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py b/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py
index 0e5a4914..5b82cec7 100644
--- a/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py
+++ b/experiment/meteo_france_data/stations_data/visualization/comparisons_visualization/comparisons_visualization.py
@@ -1,3 +1,5 @@
+import os.path as op
+from sklearn.metrics import confusion_matrix
 from collections import OrderedDict
 from itertools import chain
 from typing import Dict, List
@@ -20,12 +22,23 @@ from extreme_estimator.extreme_models.utils import r, safe_run_r_estimator, ro
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 from utils import classproperty
 
-MAE_COLUMN_NAME = 'mean absolute difference'
-DISTANCE_COLUMN_NAME = 'distance'
-path_df_location_to_value_csv_example = r'/home/erwan/Documents/projects/spatiotemporalextremes/experiment/meteo_france_data/stations_data/csv/example.csv'
+path = r'/home/erwan/Documents/projects/spatiotemporalextremes/experiment/meteo_france_data/stations_data/csv'
+path_last_csv_file = op.join(path, 'example.csv')
+path_backup_csv_file = op.join(path, 'example_backup.csv')
+
+
+def get_reanalysis_column_name(station_column_name):
+    return station_column_name + ' ' + REANALYSE_STR
+
+
+WELL_CLASSIFIED_CNAME = 'well classified'
+DISTANCE_COLUMN_NAME = 'distance'
+TREND_TYPE_CNAME = 'display trend type'
+SAFRAN_TREND_TYPE_CNAME = get_reanalysis_column_name('display trend type')
+MAE_COLUMN_NAME = 'mean absolute difference'
 
 
 class ComparisonsVisualization(VisualizationParameters):
     def __init__(self, altitudes=None, keep_only_station_without_nan_values=False, margin=150,
@@ -99,24 +112,39 @@ class ComparisonsVisualization(VisualizationParameters):
         return df
 
     @classmethod
-    def visualize_metric(cls, df=None):
+    def visualize_metric(cls, df=None, csv_filepath=path_last_csv_file):
         # Load or update df value from example file
         if df is None:
-            df = pd.read_csv(path_df_location_to_value_csv_example, index_col=[0, 1, 2])
+            df = pd.read_csv(csv_filepath, index_col=[0, 1, 2])
         else:
-            df.to_csv(path_df_location_to_value_csv_example)
-
-        # Compute some column like a classication boolean
-
-        # Display some score spatially
-        df_score = df.groupby([MASSIF_COLUMN_NAME]).mean()
-        s_mae = df_score[MAE_COLUMN_NAME]
-        massif_name_to_value = s_mae.to_dict()
-        AbstractStudy.visualize_study(massif_name_to_value=massif_name_to_value,
+            df.to_csv(csv_filepath)
+
+        if TREND_TYPE_CNAME in df.columns:
+            # Display the confusion matrix
+            print(df[TREND_TYPE_CNAME].values)
+            print(AbstractUnivariateTest.three_main_trend_types())
+            m = confusion_matrix(y_true=df[TREND_TYPE_CNAME].values,
+                                 y_pred=df[SAFRAN_TREND_TYPE_CNAME].values,
+                                 labels=AbstractUnivariateTest.three_main_trend_types())
+            print(m)
+
+            # Display the classification score per massif
+            df[WELL_CLASSIFIED_CNAME] = df[TREND_TYPE_CNAME] == df[SAFRAN_TREND_TYPE_CNAME]
+            serie_classification = df.groupby([MASSIF_COLUMN_NAME]).mean()[WELL_CLASSIFIED_CNAME] * 100
+            AbstractStudy.visualize_study(massif_name_to_value=serie_classification.to_dict(),
+                                          default_color_for_missing_massif='b',
+                                          cmap=plt.cm.Greens,
+                                          vmin=0,
+                                          vmax=100,
+                                          label='agreement on trend type classification (%)')
+
+        # Display the mae score
+        serie_mae = df.groupby([MASSIF_COLUMN_NAME]).mean()[MAE_COLUMN_NAME]
+        AbstractStudy.visualize_study(massif_name_to_value=serie_mae.to_dict(),
                                       default_color_for_missing_massif='b',
                                       cmap=plt.cm.Reds,
-                                      vmin=s_mae.min(),
-                                      vmax=s_mae.max())
+                                      vmin=0,
+                                      vmax=65,
+                                      label='average absolute difference between annual maxima snowfall (mm)')
 
     def _visualize_ax_main(self, plot_function, comparison: ComparisonAnalysis, massif, ax=None, show=False):
         if ax is None:
@@ -164,7 +192,7 @@ class ComparisonsVisualization(VisualizationParameters):
 
             if isinstance(plot_ordered_value_dict, dict):
                 if REANALYSE_STR in i:
-                    plot_station_ordered_value_dict = OrderedDict([(k + ' ' + REANALYSE_STR, v)
+                    plot_station_ordered_value_dict = OrderedDict([(get_reanalysis_column_name(k), v)
                                                                    for k, v in plot_ordered_value_dict.items()])
                 else:
                     ordered_value_dict.update(plot_ordered_value_dict)
@@ -214,7 +242,7 @@ class ComparisonsVisualization(VisualizationParameters):
             most_likely_trend_type = trend_test_res[best_idx][0]
             display_trend_type = AbstractUnivariateTest.get_display_trend_type(real_trend_type=most_likely_trend_type)
             label += "\n {} starting in {}".format(display_trend_type, most_likely_year)
-            ordered_dict['display trend type'] = display_trend_type
+            ordered_dict[TREND_TYPE_CNAME] = display_trend_type
             ordered_dict['most likely year'] = most_likely_year
         # Display the nllh against the starting year
         step = 1
diff --git a/experiment/trend_analysis/univariate_test/abstract_univariate_test.py b/experiment/trend_analysis/univariate_test/abstract_univariate_test.py
index bfbe4733..d412e7c1 100644
--- a/experiment/trend_analysis/univariate_test/abstract_univariate_test.py
+++ b/experiment/trend_analysis/univariate_test/abstract_univariate_test.py
@@ -49,6 +49,10 @@ class AbstractUnivariateTest(object):
         return [cls.POSITIVE_TREND, cls.NEGATIVE_TREND,
                cls.SIGNIFICATIVE_POSITIVE_TREND, cls.SIGNIFICATIVE_NEGATIVE_TREND, cls.NO_TREND]
 
+    @classmethod
+    def three_main_trend_types(cls):
+        return [cls.SIGNIFICATIVE_NEGATIVE_TREND, cls.NON_SIGNIFICATIVE_TREND, cls.SIGNIFICATIVE_POSITIVE_TREND]
+
     @classmethod
     def display_trend_type_to_style(cls):
         d = OrderedDict()
-- 
GitLab
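
For readers who want to sanity-check the new agreement logic without loading the full study pipeline, here is a minimal standalone sketch of what visualize_metric now computes: the confusion matrix between station-based and reanalysis-based trend types, and the per-massif agreement in percent. The massif names and trend-type label strings in the toy DataFrame are hypothetical placeholders, not project values; only the pandas / scikit-learn calls mirror the patch.

# Minimal sketch of the agreement computation added in visualize_metric.
# Assumption: toy data and illustrative trend-type strings, not project values.
import pandas as pd
from sklearn.metrics import confusion_matrix

# Column names mirror the constants introduced in comparisons_visualization.py.
MASSIF_COLUMN_NAME = 'massif'  # assumption: the real constant lives elsewhere in the project
TREND_TYPE_CNAME = 'display trend type'
SAFRAN_TREND_TYPE_CNAME = TREND_TYPE_CNAME + ' reanalysis'  # stands in for get_reanalysis_column_name(...)
WELL_CLASSIFIED_CNAME = 'well classified'

# Hypothetical labels playing the role of AbstractUnivariateTest.three_main_trend_types().
three_main_trend_types = ['significative negative', 'non significative', 'significative positive']

df = pd.DataFrame({
    MASSIF_COLUMN_NAME: ['Vercors', 'Vercors', 'Chartreuse', 'Chartreuse'],
    TREND_TYPE_CNAME: ['significative negative', 'non significative',
                       'significative positive', 'non significative'],
    SAFRAN_TREND_TYPE_CNAME: ['significative negative', 'significative positive',
                              'significative positive', 'non significative'],
})

# Confusion matrix between station trend types (rows) and reanalysis trend types (columns).
m = confusion_matrix(y_true=df[TREND_TYPE_CNAME],
                     y_pred=df[SAFRAN_TREND_TYPE_CNAME],
                     labels=three_main_trend_types)
print(m)

# Per-massif agreement in percent: the quantity mapped with cmap=plt.cm.Greens in the patch.
df[WELL_CLASSIFIED_CNAME] = df[TREND_TYPE_CNAME] == df[SAFRAN_TREND_TYPE_CNAME]
agreement = df.groupby(MASSIF_COLUMN_NAME)[WELL_CLASSIFIED_CNAME].mean() * 100
print(agreement.to_dict())  # here: {'Chartreuse': 100.0, 'Vercors': 50.0}

In the patch itself, the agreement series and the MAE series are then converted with to_dict() and passed to AbstractStudy.visualize_study as massif_name_to_value, which handles the rendering of the massif map.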