Spaces:
Sleeping
Sleeping
File size: 3,810 Bytes
d68c650 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import unittest
import numpy as np
from stnn.data.preprocessing import train_test_split
class TestTrainTestSplit(unittest.TestCase):
    """Unit tests for ``stnn.data.preprocessing.train_test_split``.

    Covers split sizes, return-type consistency for ndarray / list / tuple
    inputs, reproducibility under a fixed ``random_state``, and the
    ``ValueError`` cases for an invalid ``test_size``, mismatched inputs and
    empty inputs.
    """

    def setUp(self):
        # Four samples: each X row has two features, Y has one value per sample.
        self.X_array = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
        self.Y_array = np.array([1, 2, 3, 4])
        # List inputs: two arrays per argument.
        self.X_list = [self.X_array, self.X_array]
        self.Y_list = [self.Y_array, self.Y_array]
        # "Bad" lists mix the 1-D Y_array with the 2-D X_array; the
        # inconsistency tests below expect these to raise ValueError.
        self.X_list_bad = [self.Y_array, self.X_array]
        self.Y_list_bad = [self.Y_array, self.X_array]

    def test_basic_functionality_array(self):
        """A 0.25 test fraction of 4 samples gives 3 train / 1 test rows."""
        X_train, X_test, Y_train, Y_test = train_test_split(
            self.X_array, self.Y_array, test_size=0.25
        )
        self.assertEqual(len(X_train), 3)
        self.assertEqual(len(X_test), 1)
        self.assertEqual(len(Y_train), 3)
        self.assertEqual(len(Y_test), 1)

    def test_basic_functionality_list(self):
        """Every array in a list input is split 3 train / 1 test."""
        X_train, X_test, Y_train, Y_test = train_test_split(
            self.X_list, self.Y_list, test_size=0.25
        )
        # One split result is expected per input array.
        self.assertEqual(len(X_train), len(self.X_list))
        self.assertEqual(len(X_test), len(self.X_list))
        self.assertEqual(len(Y_train), len(self.Y_list))
        self.assertEqual(len(Y_test), len(self.Y_list))
        # Check every element, not just the first one.
        for index, (x_tr, x_te, y_tr, y_te) in enumerate(
            zip(X_train, X_test, Y_train, Y_test)
        ):
            with self.subTest(index=index):
                self.assertEqual(len(x_tr), 3)
                self.assertEqual(len(x_te), 1)
                self.assertEqual(len(y_tr), 3)
                self.assertEqual(len(y_te), 1)

    def test_return_type_consistency_array(self):
        """ndarray inputs yield ndarrays; a one-element list yields lists."""
        parts = train_test_split(self.X_array, self.Y_array, test_size=0.25)
        X_train, X_test, Y_train, Y_test = parts
        for name, part in zip(("X_train", "X_test", "Y_train", "Y_test"),
                              (X_train, X_test, Y_train, Y_test)):
            with self.subTest(input="ndarray", part=name):
                self.assertIsInstance(part, np.ndarray)
        # Wrapping the same arrays in single-element lists must switch the
        # return type to list.
        parts = train_test_split([self.X_array], [self.Y_array], test_size=0.25)
        X_train, X_test, Y_train, Y_test = parts
        for name, part in zip(("X_train", "X_test", "Y_train", "Y_test"),
                              (X_train, X_test, Y_train, Y_test)):
            with self.subTest(input="single-element list", part=name):
                self.assertIsInstance(part, list)

    def test_return_type_consistency_list(self):
        """Both list and tuple sequence inputs yield lists."""
        inputs = (
            ("list", self.X_list, self.Y_list),
            # noinspection PyTypeChecker
            ("tuple", tuple(self.X_list), tuple(self.Y_list)),
        )
        for label, X, Y in inputs:
            X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25)
            for name, part in zip(("X_train", "X_test", "Y_train", "Y_test"),
                                  (X_train, X_test, Y_train, Y_test)):
                with self.subTest(input=label, part=name):
                    self.assertIsInstance(part, list)

    def test_random_state(self):
        """The same random_state reproduces the same split."""
        X_train1, X_test1, Y_train1, Y_test1 = train_test_split(
            self.X_array, self.Y_array, test_size=0.25, random_state=42
        )
        X_train2, X_test2, Y_train2, Y_test2 = train_test_split(
            self.X_array, self.Y_array, test_size=0.25, random_state=42
        )
        pairs = (
            ("X_train", X_train1, X_train2),
            ("X_test", X_test1, X_test2),
            ("Y_train", Y_train1, Y_train2),
            ("Y_test", Y_test1, Y_test2),
        )
        for name, first, second in pairs:
            with self.subTest(part=name):
                np.testing.assert_array_equal(first, second)

    def test_invalid_test_size(self):
        """A test_size outside the valid range raises ValueError."""
        for test_size in (-0.1, 1.5):
            with self.subTest(test_size=test_size):
                with self.assertRaises(ValueError):
                    train_test_split(self.X_array, self.Y_array, test_size=test_size)

    def test_inconsistent_length(self):
        """Mismatched X/Y inputs raise ValueError."""
        X = np.array([[1, 2], [3, 4]])
        Y = np.array([1, 2, 3])
        cases = (
            ("array length mismatch", X, Y),
            ("both lists bad", self.X_list_bad, self.Y_list_bad),
            ("X list bad", self.X_list_bad, self.Y_list),
            ("Y list bad", self.X_list, self.Y_list_bad),
        )
        for label, X_in, Y_in in cases:
            with self.subTest(case=label):
                with self.assertRaises(ValueError):
                    train_test_split(X_in, Y_in)

    def test_empty(self):
        """Empty arrays or empty lists raise ValueError."""
        X_empty = np.zeros(0)
        Y_empty = np.zeros(0)
        cases = (
            ("empty arrays", X_empty, Y_empty),
            ("empty Y list", [X_empty], []),
            ("empty X list", [], [Y_empty]),
            ("lists of empty arrays", [X_empty], [Y_empty]),
        )
        for label, X_in, Y_in in cases:
            with self.subTest(case=label):
                with self.assertRaises(ValueError):
                    train_test_split(X_in, Y_in)
if __name__ == '__main__':
    # Script entry point: discover and run every TestCase defined in this module.
    unittest.main()
|