From 2e28c32544c7408f7ff7fb766e307fbccf1eaab8 Mon Sep 17 00:00:00 2001
From: Le Roux Erwan <erwan.le-roux@irstea.fr>
Date: Thu, 19 Mar 2020 20:59:22 +0100
Subject: [PATCH] [refactor] add test_param_function.py

---
 .../function/param_function/param_function.py |  2 +-
 .../test_function/test_param_function.py      | 32 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)
 create mode 100644 test/test_extreme_fit/test_function/test_param_function.py

diff --git a/extreme_fit/function/param_function/param_function.py b/extreme_fit/function/param_function/param_function.py
index ce18997f..470b1c69 100644
--- a/extreme_fit/function/param_function/param_function.py
+++ b/extreme_fit/function/param_function/param_function.py
@@ -37,7 +37,7 @@ class LinearOneAxisParamFunction(AbstractParamFunction):
 
 class LinearParamFunction(AbstractParamFunction):
 
-    def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None):
+    def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef):
         self.linear_coef = linear_coef
         # Load each one axis linear function
         self.linear_one_axis_param_functions = []  # type: List[LinearOneAxisParamFunction]
diff --git a/test/test_extreme_fit/test_function/test_param_function.py b/test/test_extreme_fit/test_function/test_param_function.py
new file mode 100644
index 00000000..fba861ce
--- /dev/null
+++ b/test/test_extreme_fit/test_function/test_param_function.py
@@ -0,0 +1,32 @@
+import unittest
+
+import numpy as np
+
+from extreme_fit.function.param_function.linear_coef import LinearCoef
+from extreme_fit.function.param_function.param_function import LinearParamFunction
+
+
+class ParamFunction(unittest.TestCase):
+
+    def test_out_of_bounds(self):
+        param_function = LinearParamFunction(dims=[0], coordinates=np.array([[0]]), linear_coef=LinearCoef())
+        with self.assertRaises(AssertionError):
+            param_function.get_param_value(np.array([1.0]))
+
+    def test_linear_param_function(self):
+        linear_coef = LinearCoef(idx_to_coef={0: 1})
+        param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(),
+                                             linear_coef=linear_coef)
+        self.assertEqual(0.0, param_function.get_param_value(np.array([0.0])))
+        self.assertEqual(1.0, param_function.get_param_value(np.array([1.0])))
+
+    def test_affine_param_function(self):
+        linear_coef = LinearCoef(idx_to_coef={-1: 1, 0: 1})
+        param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(),
+                                             linear_coef=linear_coef)
+        self.assertEqual(1.0, param_function.get_param_value(np.array([0.0])))
+        self.assertEqual(2.0, param_function.get_param_value(np.array([1.0])))
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab