From 2e28c32544c7408f7ff7fb766e307fbccf1eaab8 Mon Sep 17 00:00:00 2001 From: Le Roux Erwan <erwan.le-roux@irstea.fr> Date: Thu, 19 Mar 2020 20:59:22 +0100 Subject: [PATCH] [refactor] add test_param_function.py --- .../function/param_function/param_function.py | 2 +- .../test_function/test_param_function.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 test/test_extreme_fit/test_function/test_param_function.py diff --git a/extreme_fit/function/param_function/param_function.py b/extreme_fit/function/param_function/param_function.py index ce18997f..470b1c69 100644 --- a/extreme_fit/function/param_function/param_function.py +++ b/extreme_fit/function/param_function/param_function.py @@ -37,7 +37,7 @@ class LinearOneAxisParamFunction(AbstractParamFunction): class LinearParamFunction(AbstractParamFunction): - def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None): + def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef): self.linear_coef = linear_coef # Load each one axis linear function self.linear_one_axis_param_functions = [] # type: List[LinearOneAxisParamFunction] diff --git a/test/test_extreme_fit/test_function/test_param_function.py b/test/test_extreme_fit/test_function/test_param_function.py new file mode 100644 index 00000000..fba861ce --- /dev/null +++ b/test/test_extreme_fit/test_function/test_param_function.py @@ -0,0 +1,32 @@ +import unittest + +import numpy as np + +from extreme_fit.function.param_function.linear_coef import LinearCoef +from extreme_fit.function.param_function.param_function import LinearParamFunction + + +class ParamFunction(unittest.TestCase): + + def test_out_of_bounds(self): + param_function = LinearParamFunction(dims=[0], coordinates=np.array([[0]]), linear_coef=LinearCoef()) + with self.assertRaises(AssertionError): + param_function.get_param_value(np.array([1.0])) + + def test_linear_param_function(self): + linear_coef = LinearCoef(idx_to_coef={0: 1}) + param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(), + linear_coef=linear_coef) + self.assertEqual(0.0, param_function.get_param_value(np.array([0.0]))) + self.assertEqual(1.0, param_function.get_param_value(np.array([1.0]))) + + def test_affine_param_function(self): + linear_coef = LinearCoef(idx_to_coef={-1: 1, 0: 1}) + param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(), + linear_coef=linear_coef) + self.assertEqual(1.0, param_function.get_param_value(np.array([0.0]))) + self.assertEqual(2.0, param_function.get_param_value(np.array([1.0]))) + + +if __name__ == '__main__': + unittest.main() -- GitLab