File size: 3,810 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import unittest
import numpy as np
from stnn.data.preprocessing import train_test_split


class TestTrainTestSplit(unittest.TestCase):
	"""Unit tests for ``train_test_split`` from ``stnn.data.preprocessing``."""

	def setUp(self):
		"""Create the array and list fixtures shared by the test methods."""
		features = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
		targets = np.array([1, 2, 3, 4])
		self.X_array = features
		self.Y_array = targets
		self.X_list = [features, features]
		self.Y_list = [targets, targets]
		# The "bad" lists pair the 1-D target array with the 2-D feature array;
		# test_inconsistent_length expects any call involving them to be rejected.
		self.X_list_bad = [targets, features]
		self.Y_list_bad = [targets, features]

	def _assert_all_instances(self, split, expected_type):
		"""Check that each of the four returned pieces is an ``expected_type``."""
		X_train, X_test, Y_train, Y_test = split
		for piece in (X_train, X_test, Y_train, Y_test):
			self.assertIsInstance(piece, expected_type)

	def _assert_all_raise_value_error(self, argument_pairs):
		"""Check that ``train_test_split(X, Y)`` raises ValueError for every pair."""
		for X, Y in argument_pairs:
			with self.assertRaises(ValueError):
				train_test_split(X, Y)

	def test_basic_functionality_array(self):
		"""Four samples with test_size=0.25 should give 3 train and 1 test item."""
		split = train_test_split(self.X_array, self.Y_array, test_size=0.25)
		X_train, X_test, Y_train, Y_test = split
		expected_sizes = ((X_train, 3), (X_test, 1), (Y_train, 3), (Y_test, 1))
		for piece, size in expected_sizes:
			self.assertEqual(len(piece), size)

	def test_basic_functionality_list(self):
		"""With list inputs, the first array of each output should be 3 / 1 long."""
		split = train_test_split(self.X_list, self.Y_list, test_size=0.25)
		X_train, X_test, Y_train, Y_test = split
		expected_sizes = ((X_train, 3), (X_test, 1), (Y_train, 3), (Y_test, 1))
		for piece, size in expected_sizes:
			# Only the first array in each returned list is inspected.
			self.assertEqual(len(piece[0]), size)

	def test_return_type_consistency_array(self):
		"""Array inputs should yield arrays; one-element lists should yield lists."""
		from_arrays = train_test_split(self.X_array, self.Y_array, test_size=0.25)
		self._assert_all_instances(from_arrays, np.ndarray)

		from_singletons = train_test_split([self.X_array], [self.Y_array], test_size=0.25)
		self._assert_all_instances(from_singletons, list)

	def test_return_type_consistency_list(self):
		"""List inputs and tuple inputs should both yield lists."""
		from_lists = train_test_split(self.X_list, self.Y_list, test_size=0.25)
		self._assert_all_instances(from_lists, list)

		x_tuple = tuple(self.X_list)
		y_tuple = tuple(self.Y_list)
		# noinspection PyTypeChecker
		from_tuples = train_test_split(x_tuple, y_tuple, test_size=0.25)
		self._assert_all_instances(from_tuples, list)

	def test_random_state(self):
		"""Two calls with the same random_state should return equal splits."""
		X_train1, X_test1, Y_train1, Y_test1 = train_test_split(
			self.X_array, self.Y_array, test_size=0.25, random_state=42
		)
		X_train2, X_test2, Y_train2, Y_test2 = train_test_split(
			self.X_array, self.Y_array, test_size=0.25, random_state=42
		)
		first_run = (X_train1, X_test1, Y_train1, Y_test1)
		second_run = (X_train2, X_test2, Y_train2, Y_test2)
		for first_piece, second_piece in zip(first_run, second_run):
			np.testing.assert_array_equal(first_piece, second_piece)

	def test_invalid_test_size(self):
		"""A negative test_size and a test_size above 1 should raise ValueError."""
		for bad_size in (-0.1, 1.5):
			with self.assertRaises(ValueError):
				train_test_split(self.X_array, self.Y_array, test_size=bad_size)

	def test_inconsistent_length(self):
		"""Mismatched X / Y inputs should raise ValueError."""
		two_rows = np.array([[1, 2], [3, 4]])
		three_labels = np.array([1, 2, 3])
		self._assert_all_raise_value_error([
			(two_rows, three_labels),
			(self.X_list_bad, self.Y_list_bad),
			(self.X_list_bad, self.Y_list),
			(self.X_list, self.Y_list_bad),
		])

	def test_empty(self):
		"""Empty arrays and empty lists should raise ValueError."""
		no_features = np.zeros(0)
		no_targets = np.zeros(0)
		self._assert_all_raise_value_error([
			(no_features, no_targets),
			([no_features], []),
			([], [no_targets]),
			([no_features], [no_targets]),
		])


# Run this module's tests when the file is executed directly as a script;
# nothing happens here when the module is merely imported.
if __name__ == '__main__':
	unittest.main()