Commit 2e28c325 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] add test_param_function.py

parent c0d698e3
No related merge requests found
Showing with 33 additions and 1 deletion
+33 -1
......@@ -37,7 +37,7 @@ class LinearOneAxisParamFunction(AbstractParamFunction):
class LinearParamFunction(AbstractParamFunction):
def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef = None):
def __init__(self, dims: List[int], coordinates: np.ndarray, linear_coef: LinearCoef):
self.linear_coef = linear_coef
# Load each one axis linear function
self.linear_one_axis_param_functions = [] # type: List[LinearOneAxisParamFunction]
......
import unittest
import numpy as np
from extreme_fit.function.param_function.linear_coef import LinearCoef
from extreme_fit.function.param_function.param_function import LinearParamFunction
class ParamFunction(unittest.TestCase):
def test_out_of_bounds(self):
param_function = LinearParamFunction(dims=[0], coordinates=np.array([[0]]), linear_coef=LinearCoef())
with self.assertRaises(AssertionError):
param_function.get_param_value(np.array([1.0]))
def test_linear_param_function(self):
linear_coef = LinearCoef(idx_to_coef={0: 1})
param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(),
linear_coef=linear_coef)
self.assertEqual(0.0, param_function.get_param_value(np.array([0.0])))
self.assertEqual(1.0, param_function.get_param_value(np.array([1.0])))
def test_affine_param_function(self):
linear_coef = LinearCoef(idx_to_coef={-1: 1, 0: 1})
param_function = LinearParamFunction(dims=[0], coordinates=np.array([[-1, 0, 1]]).transpose(),
linear_coef=linear_coef)
self.assertEqual(1.0, param_function.get_param_value(np.array([0.0])))
self.assertEqual(2.0, param_function.get_param_value(np.array([1.0])))
if __name__ == '__main__':
unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment