From a1a41ac16606136f297a7964e03137ad8e790829 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 9 May 2019 18:00:42 +0200
Subject: [PATCH] [TEST][TEMPORAL MARGIN] add stationary test. add special
 method to print dataset

---
 extreme_estimator/extreme_models/utils.py     |  8 ++++-
 .../coordinates/abstract_coordinates.py       |  3 ++
 .../dataset/abstract_dataset.py               |  5 +++
 .../abstract_spatio_temporal_observations.py  |  5 +++
 .../test_margin_temporal.py                   | 33 ++++++++++---------
 5 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/extreme_estimator/extreme_models/utils.py b/extreme_estimator/extreme_models/utils.py
index 581723ad..8ca53f60 100644
--- a/extreme_estimator/extreme_models/utils.py
+++ b/extreme_estimator/extreme_models/utils.py
@@ -27,6 +27,11 @@ r.library('ismev')
 # the best solution for debugging is to copy/paste the code module into a file that belongs to me, and then
 # I can put print & stop in the code, and I can understand where are the problems
 
+def set_seed_for_test(seed=42):
+    set_seed_r(seed=seed)
+    random.seed(seed)
+
+
 def set_seed_r(seed=42):
     r("set.seed({})".format(seed))
 
@@ -42,7 +47,8 @@ class WarningMaximumAbsoluteValueTooHigh(Warning):
     pass
 
 
-def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs_value=100, **parameters) -> robjects.ListVector:
+def safe_run_r_estimator(function, data=None, use_start=False, threshold_max_abs_value=100,
+                         **parameters) -> robjects.ListVector:
     # Some checks for Spatial Extremes
     if data is not None:
         # Raise warning if the maximum absolute value is above a threshold
diff --git a/spatio_temporal_dataset/coordinates/abstract_coordinates.py b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
index cff19947..5ced1941 100644
--- a/spatio_temporal_dataset/coordinates/abstract_coordinates.py
+++ b/spatio_temporal_dataset/coordinates/abstract_coordinates.py
@@ -274,3 +274,6 @@ class AbstractCoordinates(object):
 
     def __eq__(self, other):
         return self.df_merged.equals(other.df_merged)
+
+    def __str__(self):
+        return self.df_all_coordinates.__str__()
diff --git a/spatio_temporal_dataset/dataset/abstract_dataset.py b/spatio_temporal_dataset/dataset/abstract_dataset.py
index d18ed288..7370746e 100644
--- a/spatio_temporal_dataset/dataset/abstract_dataset.py
+++ b/spatio_temporal_dataset/dataset/abstract_dataset.py
@@ -102,6 +102,11 @@ class AbstractDataset(object):
             column_idxs = [idx for idx in range(self.observations.nb_obs) if idx % nb_subsets == subset_id]
             self.subset_id_to_column_idxs[subset_id] = column_idxs
 
+    # Special methods
+
+    def __str__(self) -> str:
+        return 'coordinates: {}\nobservations: {}'.format(self.coordinates.__str__(), self.observations.__str__())
+
 
 def get_subset_dataset(dataset: AbstractDataset, subset_id) -> AbstractDataset:
     columns_idxs = dataset.subset_id_to_column_idxs[subset_id]
diff --git a/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py b/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
index 483db946..9ac704c8 100644
--- a/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
+++ b/spatio_temporal_dataset/spatio_temporal_observations/abstract_spatio_temporal_observations.py
@@ -95,3 +95,8 @@ class AbstractSpatioTemporalObservations(object):
                          slicer: AbstractSlicer = None):
         df = df_sliced(self.df_maxima_frech, split, slicer)
         df.loc[:] = maxima_frech_values
+
+    def __str__(self) -> str:
+        return self._df_maxima.__str__()
+
+
diff --git a/test/test_extreme_estimator/test_extreme_models/test_margin_temporal.py b/test/test_extreme_estimator/test_extreme_models/test_margin_temporal.py
index 606d52b9..23677042 100644
--- a/test/test_extreme_estimator/test_extreme_models/test_margin_temporal.py
+++ b/test/test_extreme_estimator/test_extreme_models/test_margin_temporal.py
@@ -1,13 +1,15 @@
+import random
 import unittest
 
 import numpy as np
 import pandas as pd
 
 from extreme_estimator.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
-from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearNonStationaryLocationMarginModel
+from extreme_estimator.extreme_models.margin_model.linear_margin_model import LinearNonStationaryLocationMarginModel, \
+    LinearStationaryMarginModel
 from extreme_estimator.extreme_models.margin_model.temporal_linear_margin_model import StationaryStationModel, \
     NonStationaryStationModel
-from extreme_estimator.extreme_models.utils import r, set_seed_r
+from extreme_estimator.extreme_models.utils import r, set_seed_r, set_seed_for_test
 from extreme_estimator.margin_fits.gev.gevmle_fit import GevMleFit
 from extreme_estimator.margin_fits.gev.ismev_gev_fit import IsmevGevFit
 from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
@@ -23,6 +25,7 @@ from test.test_utils import load_test_spatiotemporal_coordinates, load_smooth_ma
 class TestMarginTemporal(unittest.TestCase):
 
     def setUp(self) -> None:
+        set_seed_for_test(seed=42)
         self.nb_points = 2
         self.nb_steps = 5
         self.nb_obs = 1
@@ -31,23 +34,23 @@ class TestMarginTemporal(unittest.TestCase):
         self.start_year = 2
         smooth_margin_models = LinearNonStationaryLocationMarginModel(coordinates=self.coordinates,
                                                                       starting_point=self.start_year)
+
         self.dataset = MarginDataset.from_sampling(nb_obs=self.nb_obs,
                                                    margin_model=smooth_margin_models,
                                                    coordinates=self.coordinates)
 
-    def test_loading_dataset(self):
-        self.assertTrue(True)
-
-    # def test_gev_temporal_margin_fit_stationary(self):
-    #     # Create estimator
-    #     margin_model = StationaryStationModel(self.coordinates)
-    #     estimator = LinearMarginEstimator(self.dataset, margin_model)
-    #     estimator.fit()
-    #     ref = {'loc': 0.0219, 'scale': 1.0347, 'shape': 0.8295}
-    #     for year in range(1, 3):
-    #         mle_params_estimated = estimator.margin_function_fitted.get_gev_params(np.array([year])).to_dict()
-    #         for key in ref.keys():
-    #             self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
+    def test_margin_fit_stationary(self):
+        # Create estimator
+        margin_model = LinearStationaryMarginModel(self.coordinates)
+        estimator = LinearMarginEstimator(self.dataset, margin_model)
+        estimator.fit()
+        ref = {'loc': 2.2985600257321295, 'scale': 8.937484202730161, 'shape': 5.744352285758161}
+        for year in range(1, 3):
+            coordinate = np.array([0.0, 0.0, year])
+            mle_params_estimated = estimator.margin_function_fitted.get_gev_params(coordinate).to_dict()
+            print(mle_params_estimated)
+            for key in ref.keys():
+                self.assertAlmostEqual(ref[key], mle_params_estimated[key], places=3)
     #
     # def test_gev_temporal_margin_fit_nonstationary(self):
     #     # Create estimator
-- 
GitLab