From b79c5b30c0e54ae2120aaeb8ad451d1124011dd4 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Rouby <pierre-antoine.rouby@inrae.fr>
Date: Fri, 27 Oct 2023 11:50:36 +0200
Subject: [PATCH] tests: Add tests for flatten function.

---
 src/test_pamhyr.py | 64 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/src/test_pamhyr.py b/src/test_pamhyr.py
index 32ec8cd6..7066a653 100644
--- a/src/test_pamhyr.py
+++ b/src/test_pamhyr.py
@@ -20,7 +20,69 @@ import os
 import unittest
 import tempfile
 
-from tools import parse_command_line
+from tools import flatten, parse_command_line
+
+
+class FlattenTestCase(unittest.TestCase):
+    def test_flatten_0(self):
+        input = []
+        output = []
+
+        res = flatten(input)
+
+        self.assertEqual(len(res), len(output))
+        for i, o in enumerate(output):
+            self.assertEqual(res[i], o)
+
+    def test_flatten_1(self):
+        input = [['foo']]
+        output = ['foo']
+
+        res = flatten(input)
+
+        self.assertEqual(len(res), len(output))
+        for i, o in enumerate(output):
+            self.assertEqual(res[i], o)
+
+    def test_flatten_2(self):
+        input = [['foo', 'bar']]
+        output = ['foo', 'bar']
+
+        res = flatten(input)
+
+        self.assertEqual(len(res), len(output))
+        for i, o in enumerate(output):
+            self.assertEqual(res[i], o)
+
+    def test_flatten_3(self):
+        input = [['foo'], ['bar']]
+        output = ['foo', 'bar']
+
+        res = flatten(input)
+
+        self.assertEqual(len(res), len(output))
+        for i, o in enumerate(output):
+            self.assertEqual(res[i], o)
+
+    def test_flatten_4(self):
+        input = [['foo'], ['bar', 'baz'], ['bazz']]
+        output = ['foo', 'bar', 'baz', 'bazz']
+
+        res = flatten(input)
+
+        self.assertEqual(len(res), len(output))
+        for i, o in enumerate(output):
+            self.assertEqual(res[i], o)
+
+    def test_flatten_5(self):
+        input = [['foo'], ['bar', ['baz']], ['bazz']]
+        output = ['foo', 'bar', ['baz'], 'bazz']
+
+        res = flatten(input)
+
+        self.assertEqual(len(res), len(output))
+        for i, o in enumerate(output):
+            self.assertEqual(res[i], o)
 
 
 class ToolsCMDParserTestCase(unittest.TestCase):
-- 
GitLab