From 36ea1de255cd60c89136d102775f59411cdcfe9e Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 19 Mar 2020 19:46:40 +0100
Subject: [PATCH] [refactor] refactor param_function.py. organize the code for
 the abstact_quantile_function.

---
 .../abstract_margin_estimator.py              |  8 +--
 .../abstract_quantile_estimator.py            | 34 +++++++++----
 extreme_fit/function/abstract_function.py     |  6 ++-
 .../function/abstract_quantile_function.py    | 51 ++++++++++++++++---
 .../abstract_margin_function.py               |  2 +-
 .../independent_margin_function.py            |  2 +-
 .../function/param_function/param_function.py | 12 ++---
 .../quantile_regression_model.py              |  5 +-
 .../result_from_quantilreg.py                 |  5 +-
 extreme_fit/model/utils.py                    |  1 -
 .../test_estimator/test_quantile_estimator.py |  5 +-
 11 files changed, 90 insertions(+), 41 deletions(-)

diff --git a/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py b/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
index 117594e3..00285f89 100644
--- a/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
+++ b/extreme_fit/estimator/margin_estimator/abstract_margin_estimator.py
@@ -13,16 +13,16 @@ from spatio_temporal_dataset.slicer.split import Split
 
 class AbstractMarginEstimator(AbstractEstimator, ABC):
 
-    def __init__(self, dataset: AbstractDataset):
-        super().__init__(dataset)
+    def __init__(self, dataset: AbstractDataset, **kwargs):
+        super().__init__(dataset, **kwargs)
         assert self.dataset.maxima_gev() is not None
 
 
 class LinearMarginEstimator(AbstractMarginEstimator):
     """# with different type of marginals: cosntant, linear...."""
 
-    def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel):
-        super().__init__(dataset)
+    def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel, **kwargs):
+        super().__init__(dataset, **kwargs)
         assert isinstance(margin_model, LinearMarginModel)
         self.margin_model = margin_model
 
diff --git a/extreme_fit/estimator/quantile_estimator/abstract_quantile_estimator.py b/extreme_fit/estimator/quantile_estimator/abstract_quantile_estimator.py
index 5f0966b5..f308053b 100644
--- a/extreme_fit/estimator/quantile_estimator/abstract_quantile_estimator.py
+++ b/extreme_fit/estimator/quantile_estimator/abstract_quantile_estimator.py
@@ -1,16 +1,23 @@
+from abc import ABC
+
+import numpy as np
 from cached_property import cached_property
 
 from extreme_fit.estimator.abstract_estimator import AbstractEstimator
 from extreme_fit.estimator.margin_estimator.abstract_margin_estimator import LinearMarginEstimator
-from extreme_fit.function.abstract_quantile_function import AbstractQuantileFunction
+from extreme_fit.function.abstract_quantile_function import AbstractQuantileFunction, \
+    QuantileFunctionFromMarginFunction, QuantileFunctionFromParamFunction
 from extreme_fit.function.margin_function.abstract_margin_function import AbstractMarginFunction
+from extreme_fit.function.param_function.linear_coef import LinearCoef
+from extreme_fit.function.param_function.param_function import LinearParamFunction
 from extreme_fit.model.margin_model.linear_margin_model.linear_margin_model import LinearMarginModel
 from extreme_fit.model.quantile_model.quantile_regression_model import AbstractQuantileRegressionModel
 from extreme_fit.model.result_from_model_fit.abstract_result_from_model_fit import AbstractResultFromModelFit
+from extreme_fit.model.result_from_model_fit.result_from_quantilreg import ResultFromQuantreg
 from spatio_temporal_dataset.dataset.abstract_dataset import AbstractDataset
 
 
-class AbstractQuantileEstimator(AbstractEstimator):
+class AbstractQuantileEstimator(AbstractEstimator, ABC):
 
     def __init__(self, dataset: AbstractDataset, quantile: float, **kwargs):
         super().__init__(dataset, **kwargs)
@@ -18,30 +25,37 @@ class AbstractQuantileEstimator(AbstractEstimator):
         self.quantile = quantile
 
     @cached_property
-    def quantile_function_from_fit(self) -> AbstractQuantileFunction:
-        pass
+    def function_from_fit(self) -> AbstractQuantileFunction:
+        raise NotImplementedError
 
 
-class QuantileEstimatorFromMargin(AbstractQuantileEstimator, LinearMarginEstimator):
+class QuantileEstimatorFromMargin(LinearMarginEstimator, AbstractQuantileEstimator):
 
     def __init__(self, dataset: AbstractDataset, margin_model: LinearMarginModel, quantile):
         super().__init__(dataset=dataset, quantile=quantile, margin_model=margin_model)
 
     @cached_property
-    def quantile_function_from_fit(self) -> AbstractQuantileFunction:
+    def function_from_fit(self) -> AbstractQuantileFunction:
         linear_margin_function = super().function_from_fit  # type: AbstractMarginFunction
-        return AbstractQuantileFunction(linear_margin_function, self.quantile)
+        return QuantileFunctionFromMarginFunction(self.dataset.coordinates, linear_margin_function, self.quantile)
 
 
 class QuantileRegressionEstimator(AbstractQuantileEstimator):
 
     def __init__(self, dataset: AbstractDataset, quantile: float, quantile_regression_model_class: type, **kwargs):
         super().__init__(dataset, quantile, **kwargs)
-        self.quantile_regression_model = quantile_regression_model_class(dataset, quantile) # type: AbstractQuantileRegressionModel
+        self.quantile_regression_model = quantile_regression_model_class(dataset, quantile)  # type: AbstractQuantileRegressionModel
 
     def _fit(self) -> AbstractResultFromModelFit:
         return self.quantile_regression_model.fit()
 
     @cached_property
-    def quantile_function_from_fit(self) -> AbstractQuantileFunction:
-        return self.result_from_model_fit.quantile_function
+    def function_from_fit(self) -> AbstractQuantileFunction:
+        result_from_model_fit = self.result_from_model_fit  # type: ResultFromQuantreg
+        coefs = result_from_model_fit.coefficients
+        dims = list(np.arange(len(coefs)) - 1)
+        linear_coef = LinearCoef('quantile', idx_to_coef=dict(zip(dims, coefs)))
+        param_function = LinearParamFunction(dims=dims, coordinates=self.dataset.coordinates.coordinates_values(),
+                                             linear_coef=linear_coef)
+        return QuantileFunctionFromParamFunction(coordinates=self.dataset.coordinates,
+                                                 param_function=param_function)
diff --git a/extreme_fit/function/abstract_function.py b/extreme_fit/function/abstract_function.py
index 0cf3bcfb..e563c3d0 100644
--- a/extreme_fit/function/abstract_function.py
+++ b/extreme_fit/function/abstract_function.py
@@ -1,4 +1,8 @@
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
 class AbstractFunction(object):
-    pass
\ No newline at end of file
+
+    def __init__(self, coordinates: AbstractCoordinates):
+        self.coordinates = coordinates
+
diff --git a/extreme_fit/function/abstract_quantile_function.py b/extreme_fit/function/abstract_quantile_function.py
index 62596039..d86b8555 100644
--- a/extreme_fit/function/abstract_quantile_function.py
+++ b/extreme_fit/function/abstract_quantile_function.py
@@ -2,19 +2,58 @@ import numpy as np
 
 from extreme_fit.function.abstract_function import AbstractFunction
 from extreme_fit.function.margin_function.abstract_margin_function import AbstractMarginFunction
+import matplotlib.pyplot as plt
+
+from extreme_fit.function.param_function.param_function import AbstractParamFunction
+from spatio_temporal_dataset.coordinates.abstract_coordinates import AbstractCoordinates
 
 
 class AbstractQuantileFunction(AbstractFunction):
 
-    def __init__(self, margin_function: AbstractMarginFunction, quantile: float):
+    def get_quantile(self, coordinate: np.ndarray) -> float:
+        raise NotImplementedError
+
+    def visualize(self, show=True):
+        if self.coordinates.nb_coordinates == 1:
+            self.visualize_1D(show=show)
+        elif self.coordinates.nb_coordinates == 2:
+            self.visualize_2D()
+        else:
+            return
+            # raise NotImplementedError
+
+    def visualize_1D(self, ax=None, show=True):
+        if ax is None:
+            ax = plt.gca()
+        x = self.coordinates.coordinates_values()
+        resolution = 100
+        x = np.linspace(x.min(), x.max(), resolution)
+        y = [self.get_quantile(np.array([e])) for e in x]
+        ax.plot(x, y)
+        if show:
+            plt.show()
+
+    def visualize_2D(self):
+        return
+
+
+class QuantileFunctionFromParamFunction(AbstractQuantileFunction):
+
+    def __init__(self, coordinates: AbstractCoordinates, param_function: AbstractParamFunction):
+        super().__init__(coordinates)
+        self.param_function = param_function
+
+    def get_quantile(self, coordinate: np.ndarray) -> float:
+        return self.param_function.get_param_value(coordinate)
+
+
+class QuantileFunctionFromMarginFunction(AbstractQuantileFunction):
+
+    def __init__(self, coordinates: AbstractCoordinates, margin_function: AbstractMarginFunction, quantile: float):
+        super().__init__(coordinates)
         self.margin_function = margin_function
         self.quantile = quantile
 
     def get_quantile(self, coordinate: np.ndarray) -> float:
         gev_params = self.margin_function.get_gev_params(coordinate)
         return gev_params.quantile(self.quantile)
-
-    def visualize(self):
-        pass
-        # for coordine
-        # self.margin_function.
\ No newline at end of file
diff --git a/extreme_fit/function/margin_function/abstract_margin_function.py b/extreme_fit/function/margin_function/abstract_margin_function.py
index 98f04528..df30892b 100644
--- a/extreme_fit/function/margin_function/abstract_margin_function.py
+++ b/extreme_fit/function/margin_function/abstract_margin_function.py
@@ -21,7 +21,7 @@ class AbstractMarginFunction(AbstractFunction):
     VISUALIZATION_TEMPORAL_STEPS = 2
 
     def __init__(self, coordinates: AbstractCoordinates):
-        self.coordinates = coordinates
+        super().__init__(coordinates)
         self.mask_2D = None
 
         # Visualization parameters
diff --git a/extreme_fit/function/margin_function/independent_margin_function.py b/extreme_fit/function/margin_function/independent_margin_function.py
index 74fe3131..ab8f7917 100644
--- a/extreme_fit/function/margin_function/independent_margin_function.py
+++ b/extreme_fit/function/margin_function/independent_margin_function.py
@@ -29,7 +29,7 @@ class IndependentMarginFunction(AbstractMarginFunction):
         gev_params = {}
         for gev_param_name in GevParams.PARAM_NAMES:
             param_function = self.gev_param_name_to_param_function[gev_param_name]
-            gev_params[gev_param_name] = param_function.get_gev_param_value(transformed_coordinate)
+            gev_params[gev_param_name] = param_function.get_param_value(transformed_coordinate)
         return GevParams.from_dict(gev_params)
 
     def transform(self, coordinate: np.ndarray) -> np.ndarray:
diff --git a/extreme_fit/function/param_function/param_function.py b/extreme_fit/function/param_function/param_function.py
index 38e4f6b5..ce18997f 100644
--- a/extreme_fit/function/param_function/param_function.py
+++ b/extreme_fit/function/param_function/param_function.py
@@ -7,7 +7,7 @@ from extreme_fit.function.param_function.spline_coef import SplineCoef
 class AbstractParamFunction(object):
     OUT_OF_BOUNDS_ASSERT = True
 
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+    def get_param_value(self, coordinate: np.ndarray) -> float:
         pass
 
 
@@ -16,7 +16,7 @@ class ConstantParamFunction(AbstractParamFunction):
     def __init__(self, constant):
         self.constant = constant
 
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+    def get_param_value(self, coordinate: np.ndarray) -> float:
         return self.constant
 
 
@@ -28,7 +28,7 @@ class LinearOneAxisParamFunction(AbstractParamFunction):
         self.t_max = coordinates[:, dim].max()
         self.coef = coef
 
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+    def get_param_value(self, coordinate: np.ndarray) -> float:
         t = coordinate[self.dim]
         if self.OUT_OF_BOUNDS_ASSERT:
             assert self.t_min <= t <= self.t_max, '{} is out of bounds ({}, {})'.format(t, self.t_min, self.t_max)
@@ -46,11 +46,11 @@ class LinearParamFunction(AbstractParamFunction):
                                                         coef=self.linear_coef.get_coef(idx=dim))
             self.linear_one_axis_param_functions.append(param_function)
 
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+    def get_param_value(self, coordinate: np.ndarray) -> float:
         # Add the intercept and the value with respect to each axis
         gev_param_value = self.linear_coef.intercept
         for linear_one_axis_param_function in self.linear_one_axis_param_functions:
-            gev_param_value += linear_one_axis_param_function.get_gev_param_value(coordinate)
+            gev_param_value += linear_one_axis_param_function.get_param_value(coordinate)
         return gev_param_value
 
 
@@ -66,7 +66,7 @@ class SplineParamFunction(AbstractParamFunction):
     def m(self) -> int:
         return int((self.degree + 1) / 2)
 
-    def get_gev_param_value(self, coordinate: np.ndarray) -> float:
+    def get_param_value(self, coordinate: np.ndarray) -> float:
         gev_param_value = self.spline_coef.intercept
         # Polynomial part
         for dim in self.dims:
diff --git a/extreme_fit/model/quantile_model/quantile_regression_model.py b/extreme_fit/model/quantile_model/quantile_regression_model.py
index 2652d5bb..66c99e29 100644
--- a/extreme_fit/model/quantile_model/quantile_regression_model.py
+++ b/extreme_fit/model/quantile_model/quantile_regression_model.py
@@ -19,12 +19,9 @@ class AbstractQuantileRegressionModel(AbstractModel):
 
     @property
     def first_column_of_observation(self):
-        return self.data.colnames[1]
-        # print(self.dataset.df_dataset.columns)
-        # return str(self.dataset.df_dataset.columns[0])
+        return self.data.colnames[0]
 
     def fit(self):
-        print(self.data)
         parameters = {
             'tau': self.quantile,
             'data': self.data,
diff --git a/extreme_fit/model/result_from_model_fit/result_from_quantilreg.py b/extreme_fit/model/result_from_model_fit/result_from_quantilreg.py
index 2be5c564..d04af928 100644
--- a/extreme_fit/model/result_from_model_fit/result_from_quantilreg.py
+++ b/extreme_fit/model/result_from_model_fit/result_from_quantilreg.py
@@ -1,5 +1,6 @@
 from cached_property import cached_property
 
+from extreme_fit.function.param_function.param_function import LinearParamFunction
 from extreme_fit.model.result_from_model_fit.abstract_result_from_model_fit import AbstractResultFromModelFit
 
 
@@ -8,7 +9,3 @@ class ResultFromQuantreg(AbstractResultFromModelFit):
     @property
     def coefficients(self):
         return self.name_to_value['coefficients']
-
-    @cached_property
-    def quantile_function(self):
-        print(self.coefficients)
\ No newline at end of file
diff --git a/extreme_fit/model/utils.py b/extreme_fit/model/utils.py
index 582544e7..227f328f 100644
--- a/extreme_fit/model/utils.py
+++ b/extreme_fit/model/utils.py
@@ -90,7 +90,6 @@ def safe_run_r_estimator(function, data=None, use_start=False, max_ratio_between
         if isinstance(data, np.ndarray):
             # Raise warning if the gap is too important between the two biggest values of data
             sorted_data = sorted(data.flatten())
-            print(data)
             if sorted_data[-2] * max_ratio_between_two_extremes_values < sorted_data[-1]:
                 msg = "maxmimum absolute value in data {} is too high, i.e. above the defined threshold {}" \
                     .format(sorted_data[-1], max_ratio_between_two_extremes_values)
diff --git a/test/test_extreme_fit/test_estimator/test_quantile_estimator.py b/test/test_extreme_fit/test_estimator/test_quantile_estimator.py
index 0a304b27..eea9482c 100644
--- a/test/test_extreme_fit/test_estimator/test_quantile_estimator.py
+++ b/test/test_extreme_fit/test_estimator/test_quantile_estimator.py
@@ -13,7 +13,7 @@ class TestQuantileEstimator(unittest.TestCase):
     def test_smooth_margin_estimator_spatial(self):
         self.nb_points = 20
         self.nb_obs = 1
-        self.coordinates = load_test_1D_and_2D_spatial_coordinates(nb_points=self.nb_points)
+        self.coordinates = load_test_1D_and_2D_spatial_coordinates(nb_points=self.nb_points)[:1]
 
     def test_smooth_margin_estimator_spatio_temporal(self):
         self.nb_points = 2
@@ -28,7 +28,6 @@ class TestQuantileEstimator(unittest.TestCase):
             dataset = MarginDataset.from_sampling(nb_obs=self.nb_obs,
                                                   margin_model=constant_margin_model,
                                                   coordinates=coordinates)
-            print(dataset)
             # Load quantile estimators
             quantile_estimators = [
                 QuantileEstimatorFromMargin(dataset, constant_margin_model, quantile),
@@ -40,7 +39,7 @@ class TestQuantileEstimator(unittest.TestCase):
             # Fit quantile estimators
             for quantile_estimator in quantile_estimators:
                 quantile_estimator.fit()
-                print(quantile_estimator.quantile_function_from_fit)
+                quantile_estimator.function_from_fit.visualize(show=self.DISPLAY)
 
         self.assertTrue(True)
 
-- 
GitLab