diff --git a/extreme_fit/function/param_function/param_function.py b/extreme_fit/function/param_function/param_function.py index ce18997f923bca39b35230fd27f7c586eb867ce1..470b1c693644ecab0997d296571cf683b78f4abb 100644 --- a/extreme_fit/function/param_function/param_function.py +++ b/extreme_fit/function/param_function/param_function.py @@ -37,7 +37,7 @@ class LinearOneAxisParamFunction(AbstractParamFunction): class LinearParamFunction(AbstractParamFunction): - def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None): + def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef): self.linear_coef = linear_coef # Load each one axis linear function self.linear_one_axis_param_functions = [] # type: List[LinearOneAxisParamFunction] diff --git a/test/test_extreme_fit/test_function/test_param_function.py b/test/test_extreme_fit/test_function/test_param_function.py new file mode 100644 index 0000000000000000000000000000000000000000..fba861ce6490d0bd8e91f6d297bab3f271fc2f7a --- /dev/null +++ b/test/test_extreme_fit/test_function/test_param_function.py @@ -0,0 +1,32 @@ +import unittest + +import numpy as np + +from extreme_fit.function.param_function.linear_coef import LinearCoef +from extreme_fit.function.param_function.param_function import LinearParamFunction + + +class ParamFunction(unittest.TestCase): + + def test_out_of_bounds(self): + param_function = LinearParamFunction(dims=[0], coordinates=np.array([[0]]), linear_coef=LinearCoef()) + with self.assertRaises(AssertionError): + param_function.get_param_value(np.array([1.0])) + + def test_linear_param_function(self): + linear_coef = LinearCoef(idx_to_coef={0: 1}) + param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(), + linear_coef=linear_coef) + self.assertEqual(0.0, param_function.get_param_value(np.array([0.0]))) + self.assertEqual(1.0, param_function.get_param_value(np.array([1.0]))) + + def test_affine_param_function(self): + linear_coef = LinearCoef(idx_to_coef={-1: 1, 0: 1}) + param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(), + linear_coef=linear_coef) + self.assertEqual(1.0, param_function.get_param_value(np.array([0.0]))) + self.assertEqual(2.0, param_function.get_param_value(np.array([1.0]))) + + +if __name__ == '__main__': + unittest.main()