1 | import unittest2 as unittest |
---|
2 | import numpy as np |
---|
3 | |
---|
4 | from vsm.model import tf |
---|
5 | from multiprocessing import Process |
---|
6 | import platform |
---|
7 | |
---|
class MPTester:
    """Fixture plus check for ``tf.TfMulti`` that is run in a child process.

    This is deliberately not a ``unittest.TestCase``: ``TestTf`` launches
    ``test_TfMulti_train`` as the target of a ``multiprocessing.Process``.
    Because of that, ``test_TfMulti_train`` builds its own fixture by
    calling ``setUp`` explicitly.
    """

    def setUp(self):
        # Fifteen tokens over word ids 0..3, split into four documents.
        # The third document, slice(13, 13), is empty.
        first_doc = [0, 1, 3, 1, 1, 0, 3, 0, 3]
        second_doc = [3, 0, 1, 0]
        fourth_doc = [1, 3]
        self.corpus = np.array(first_doc + second_doc + fourth_doc)
        self.docs = [slice(0, 9), slice(9, 13),
                     slice(13, 13), slice(13, 15)]
        self.V = 4
        # Expected counts: row index = word id, column index = document.
        self.cnt_mat = np.array([
            [3, 2, 0, 0],
            [3, 1, 0, 1],
            [0, 0, 0, 0],
            [3, 1, 0, 1],
        ])

    def test_TfMulti_train(self):
        """Train ``tf.TfMulti`` on the fixture and check its count matrix."""
        self.setUp()
        model = tf.TfMulti()
        model.corpus, model.docs, model.V = self.corpus, self.docs, self.V
        # NOTE(review): the argument 2 is presumably the number of worker
        # processes -- confirm against tf.TfMulti.train.
        model.train(2)

        produced = model.matrix.toarray()
        assert (self.cnt_mat == produced).all()
---|
30 | |
---|
31 | class TestTf(unittest.TestCase): |
---|
32 | |
---|
33 | def setUp(self): |
---|
34 | self.corpus = np.array([0, 1, 3, 1, 1, 0, 3, 0, 3, |
---|
35 | 3, 0, 1, 0, |
---|
36 | 1, 3]) |
---|
37 | self.docs = [slice(0,9), slice(9,13), |
---|
38 | slice(13,13), slice(13,15)] |
---|
39 | self.V = 4 |
---|
40 | self.cnt_mat = np.array([[3, 2, 0, 0], |
---|
41 | [3, 1, 0, 1], |
---|
42 | [0, 0, 0, 0], |
---|
43 | [3, 1, 0, 1]]) |
---|
44 | |
---|
45 | def test_TF_proper_class(self): |
---|
46 | m = tf.TF(multiprocessing=True) |
---|
47 | if platform.system() == 'Windows': |
---|
48 | self.assertTrue(isinstance(m,tf.TfSeq)) |
---|
49 | else: |
---|
50 | self.assertTrue(isinstance(m,tf.TfMulti)) |
---|
51 | |
---|
52 | def test_TfSeq_train(self): |
---|
53 | m = tf.TfSeq() |
---|
54 | m.corpus = self.corpus |
---|
55 | m.docs = self.docs |
---|
56 | m.V = self.V |
---|
57 | m.train() |
---|
58 | self.assertTrue((self.cnt_mat == m.matrix.toarray()).all()) |
---|
59 | |
---|
60 | def test_demo_TfMulti_train(self): |
---|
61 | t = MPTester() |
---|
62 | p = Process(target=t.test_TfMulti_train, args=()) |
---|
63 | p.start() |
---|
64 | p.join() |
---|
65 | |
---|
66 | |
---|
67 | |
---|
# Build and run the TestTf suite when this file is executed directly.
if __name__ == '__main__':
    loader = unittest.TestLoader()
    tf_suite = loader.loadTestsFromTestCase(TestTf)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(tf_suite)
---|