import sys

import numpy as np
import unittest2 as unittest

from vsm.corpus import Corpus
from vsm.corpus.util.corpusbuilders import random_corpus
from vsm.model.ldacgsseq import *


class TestLdaCgsSeq(unittest.TestCase):
    """Tests for ``LdaCgsSeq`` save/load, seeding and ``LdaCgsQuerySampler`` setup."""

    # TODO: add behavioural tests beyond I/O, seeding and sampler setup.

    # -- helpers -----------------------------------------------------------

    def _save_and_load(self, model):
        """Round-trip *model* through a temporary ``.npz`` file.

        The temporary file handle is closed before ``model.save`` writes to
        the path (reopening a still-open named temporary file is not
        possible on every platform), and the file is removed after
        ``LdaCgsSeq.load`` returns.

        Returns the model produced by ``LdaCgsSeq.load``.
        """
        import os
        from tempfile import NamedTemporaryFile

        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
        tmp.close()
        try:
            model.save(tmp.name)
            return LdaCgsSeq.load(tmp.name)
        finally:
            os.remove(tmp.name)

    def _assert_seqs_equal(self, s0, s1, label):
        """Assert *s0* and *s1* have the same length and equal items.

        Items may be scalars or arrays; each pair is compared by value with
        ``np.array_equal``. *label* names the sequence in failure messages.
        """
        # Check lengths first so a mismatch fails cleanly instead of
        # raising IndexError part-way through the comparison.
        self.assertEqual(len(s0), len(s1), '%s lengths differ' % label)
        for i, (a, b) in enumerate(zip(s0, s1)):
            self.assertTrue(np.array_equal(a, b),
                            '%s[%d] differs: %r != %r' % (label, i, a, b))

    def _assert_rng_states_equal(self, s0, s1):
        """Assert two ``RandomState.get_state()``-style tuples are identical.

        Element 1 (an array) is compared element-wise; element 0 and the
        trailing elements are compared with ``==``.
        """
        self.assertEqual(s0[0], s1[0])
        np.testing.assert_array_equal(s0[1], s1[1])
        self.assertEqual(s0[2:], s1[2:])

    def _assert_models_equal(self, m0, m1):
        """Assert two ``LdaCgsSeq`` models hold identical settings and state.

        Covers context type, K, V, alpha, beta, log_probs, iteration count,
        the corpus, the topic assignments ``Z`` and the ``top_doc``,
        ``word_top`` and ``inv_top_sums`` arrays.
        """
        self.assertEqual(m0.context_type, m1.context_type)
        self.assertEqual(m0.K, m1.K)
        self.assertEqual(m0.V, m1.V)
        np.testing.assert_array_equal(m0.alpha, m1.alpha)
        np.testing.assert_array_equal(m0.beta, m1.beta)
        self.assertEqual(m0.log_probs, m1.log_probs)
        self.assertEqual(m0.iteration, m1.iteration)
        # Compare values, not truthiness: ``a.all() == b.all()`` passes for
        # any two arrays that both happen to contain no zeros.
        self._assert_seqs_equal(m0.corpus, m1.corpus, 'corpus')
        self._assert_seqs_equal(m0.Z, m1.Z, 'Z')
        np.testing.assert_array_equal(m0.top_doc, m1.top_doc)
        np.testing.assert_array_equal(m0.word_top, m1.word_top)
        np.testing.assert_array_equal(m0.inv_top_sums, m1.inv_top_sums)

    # -- tests -------------------------------------------------------------

    def test_LdaCgsSeq_IO(self):
        """A trained model survives a save/load round trip unchanged."""
        c = random_corpus(1000, 50, 6, 100)

        m0 = LdaCgsSeq(c, 'document', K=10)
        m0.train(n_iterations=20)
        m1 = self._save_and_load(m0)
        self._assert_models_equal(m0, m1)
        self.assertEqual(m0.seed, m1.seed)
        self._assert_rng_states_equal(m0._mtrand_state, m1._mtrand_state)

        m0 = LdaCgsSeq(c, 'document', K=10)
        m0.train(n_iterations=20)
        m1 = self._save_and_load(m0)
        # NOTE(review): asserts that a loaded model has no attribute named
        # 'log_prob' (singular, vs. 'log_probs' above); the reason is not
        # visible here -- confirm intent.
        self.assertFalse(hasattr(m1, 'log_prob'))

    def test_LdaCgsSeq_SeedTypes(self):
        """Regression test for issue #74: seed and RNG-state element types
        are preserved across a save/load round trip."""
        c = random_corpus(1000, 50, 6, 100)

        m0 = LdaCgsSeq(c, 'document', K=10)
        m0.train(n_iterations=20)
        m1 = self._save_and_load(m0)

        self.assertEqual(type(m0.seed), type(m1.seed))
        for i in range(5):
            self.assertEqual(type(m0._mtrand_state[i]),
                             type(m1._mtrand_state[i]),
                             'type mismatch in _mtrand_state[%d]' % i)

    def test_LdaCgsQuerySampler_init(self):
        """``LdaCgsQuerySampler`` copies model parameters and builds state
        for the new corpus."""
        old_corp = Corpus([], remove_empty=False)
        old_corp.corpus = np.array([0, 1, 1, 0, 0, 1], dtype='i')
        old_corp.context_data = [np.array([(3, ), (3, )], dtype=[('idx', 'i')])]
        old_corp.context_types = ['document']
        # NOTE(review): dtype='i' stores the words as integers, although
        # words_int is keyed by strings -- confirm this is intended.
        old_corp.words = np.array(['0', '1'], dtype='i')
        old_corp.words_int = {'0': 0, '1': 1}

        new_corp = Corpus([], remove_empty=False)
        new_corp.corpus = np.array([0, 0], dtype='i')
        new_corp.context_data = [np.array([(2, )], dtype=[('idx', 'i')])]
        new_corp.context_types = ['document']
        new_corp.words = np.array(['0', '1'], dtype='i')
        new_corp.words_int = {'0': 0, '1': 1}

        m = LdaCgsSeq(corpus=old_corp, context_type='document', K=2, V=2)
        m.Z[:] = np.array([0, 0, 0, 1, 1, 1], dtype='i')
        m.word_top[:] = np.array([[1.01, 2.01],
                                  [2.01, 1.01]], dtype='d')
        m.top_doc[:] = np.array([[3.01, 0.01],
                                 [0.01, 3.01]], dtype='d')
        m.inv_top_sums[:] = 1. / m.word_top.sum(0)

        q = LdaCgsQuerySampler(m, new_corpus=new_corp, old_corpus=old_corp)
        self.assertEqual(q.V, 2)
        self.assertEqual(q.K, 2)
        self.assertEqual(len(q.corpus), 2)
        np.testing.assert_array_equal(q.corpus, new_corp.corpus)
        self.assertEqual(len(q.indices), 1)
        np.testing.assert_array_equal(
            q.indices, new_corp.view_metadata('document')['idx'])
        self.assertEqual(q.word_top.shape, (2, 2))
        np.testing.assert_array_equal(q.word_top, m.word_top)
        self.assertEqual(q.top_doc.shape, (2, 1))
        np.testing.assert_array_equal(q.top_doc, [[0.01],
                                                  [0.01]])
        self.assertEqual(q.inv_top_sums.shape, (2, ))
        np.testing.assert_array_equal(q.inv_top_sums, m.inv_top_sums)
        self.assertEqual(q.alpha.shape, (2, 1))
        np.testing.assert_array_equal(q.alpha, m.alpha)
        self.assertEqual(q.beta.shape, (2, 1))
        np.testing.assert_array_equal(q.beta, m.beta)

    def test_randomSeed(self):
        """Two models built with the same seed train to identical states."""
        c = random_corpus(1000, 50, 0, 20, context_type='document',
                          metadata=True)

        m0 = LdaCgsSeq(c, 'document', K=10)
        self.assertIsNotNone(m0.seed)
        orig_seed = m0.seed

        m1 = LdaCgsSeq(c, 'document', K=10, seed=orig_seed)
        self.assertEqual(m0.seed, m1.seed)

        m0.train(n_iterations=50, verbose=0)
        m1.train(n_iterations=50, verbose=0)
        # Training must leave the stored seed untouched.
        self.assertEqual(m0.seed, orig_seed)
        self.assertEqual(m1.seed, orig_seed)

        # ref:http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.get_state.html
        self.assertEqual(m0._mtrand_state[0], 'MT19937')
        self.assertEqual(m1._mtrand_state[0], 'MT19937')
        self._assert_rng_states_equal(m0._mtrand_state, m1._mtrand_state)

        self._assert_models_equal(m0, m1)

    def test_continueTraining(self):
        """Training 2 + 3 iterations matches training 5 iterations at once."""
        c = random_corpus(1000, 50, 0, 20, context_type='document',
                          metadata=True)

        m0 = LdaCgsSeq(c, 'document', K=10)
        self.assertIsNotNone(m0.seed)
        orig_seed = m0.seed

        m1 = LdaCgsSeq(c, 'document', K=10, seed=orig_seed)
        self.assertEqual(m0.seed, m1.seed)

        m0.train(n_iterations=2, verbose=0)
        m1.train(n_iterations=5, verbose=0)
        self.assertEqual(m0.seed, orig_seed)
        self.assertEqual(m1.seed, orig_seed)
        # After different iteration counts the RNG states must differ.
        self.assertTrue((m0._mtrand_state[1] != m1._mtrand_state[1]).any())
        self.assertNotEqual(m0._mtrand_state[2:], m1._mtrand_state[2:])

        # m0 now reaches 5 iterations in total; resumed training must end
        # in exactly the same state as m1's single 5-iteration run.
        m0.train(n_iterations=3, verbose=0)

        # ref:http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.get_state.html
        self.assertEqual(m0._mtrand_state[0], 'MT19937')
        self.assertEqual(m1._mtrand_state[0], 'MT19937')
        self._assert_rng_states_equal(m0._mtrand_state, m1._mtrand_state)

        self._assert_models_equal(m0, m1)


if __name__ == '__main__':
    # Run this module's tests verbosely when executed as a script.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestLdaCgsSeq)
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    # Propagate the outcome to the shell; otherwise the script exits with
    # status 0 even when tests fail, hiding failures from CI/scripts.
    sys.exit(0 if result.wasSuccessful() else 1)