source: consulta_publica/vsm/unit_tests/tests_ldacgsmulti.py @ 7095598

baseconstituyenteestudiantesgeneralplan_patriasala
Last change on this file since 7095598 was 0ff122b, checked in by rudmanmrrod <rudman22@…>, 7 años ago

Agregado módulo de gestión de perfiles de procesamiento, incorporado el módulo de visualización de modelado de tópicos

  • Propiedad mode establecida a 100644
File size: 11.3 KB
Línea 
1import unittest2 as unittest
2import numpy as np
3
4from vsm.corpus import Corpus
5from vsm.corpus.util.corpusbuilders import random_corpus
6from vsm.model.ldacgsseq import *
7from vsm.model.ldacgsmulti import *
8from multiprocessing import Process
9
10class MPTester:
11    def test_demo_LdaCgsMulti(self):
12        from vsm.model.ldacgsmulti import demo_LdaCgsMulti
13        demo_LdaCgsMulti()
14   
15    def test_LdaCgsMulti_IO(self):
16        from tempfile import NamedTemporaryFile
17        import os
18   
19        c = random_corpus(1000, 50, 6, 100)
20        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
21        try:
22            m0 = LdaCgsMulti(c, 'document', K=10)
23            m0.train(n_iterations=20)
24            m0.save(tmp.name)
25            m1 = LdaCgsMulti.load(tmp.name)
26            assert m0.context_type == m1.context_type
27            assert m0.K == m1.K
28            assert (m0.alpha == m1.alpha).all()
29            assert (m0.beta == m1.beta).all()
30            assert m0.log_probs == m1.log_probs
31            for i in xrange(max(len(m0.corpus), len(m1.corpus))):
32                assert m0.corpus[i].all() == m1.corpus[i].all()
33            assert m0.V == m1.V
34            assert m0.iteration == m1.iteration
35            for i in xrange(max(len(m0.Z), len(m1.Z))):
36                assert m0.Z[i].all() == m1.Z[i].all()
37            assert m0.top_doc.all() == m1.top_doc.all()
38            assert m0.word_top.all() == m1.word_top.all()
39            assert m0.inv_top_sums.all() == m1.inv_top_sums.all()
40
41            assert m0.seeds == m1.seeds
42            for s0, s1 in zip(m0._mtrand_states,m1._mtrand_states):
43                assert s0[0] == s1[0]
44                assert (s0[1] == s1[1]).all()
45                assert s0[2:] == s1[2:]
46
47            m0 = LdaCgsMulti(c, 'document', K=10)
48            m0.train(n_iterations=20)
49            m0.save(tmp.name)
50            m1 = LdaCgsMulti.load(tmp.name)
51            assert not hasattr(m1, 'log_prob')
52        finally:
53            os.remove(tmp.name)
54
55    def test_LdaCgsMulti_SeedTypes(self):
56        """ Test for issue #74 issues. """
57
58        from tempfile import NamedTemporaryFile
59        import os
60   
61        c = random_corpus(1000, 50, 6, 100)
62        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
63        try:
64            m0 = LdaCgsMulti(c, 'document', K=10)
65            m0.train(n_iterations=20)
66            m0.save(tmp.name)
67            m1 = LdaCgsMulti.load(tmp.name)
68
69            for s0, s1 in zip(m0.seeds, m1.seeds):
70                assert type(s0) == type(s1)
71            for s0, s1 in zip(m0._mtrand_states,m1._mtrand_states):
72                for i in range(5):
73                    assert type(s0[i]) == type(s1[i])
74        finally:
75            os.remove(tmp.name)
76
77    def test_LdaCgsMulti_random_seeds(self):
78        from vsm.corpus.util.corpusbuilders import random_corpus
79
80        c = random_corpus(1000, 50, 0, 20, context_type='document',
81                            metadata=True)
82
83        m0 = LdaCgsMulti(c, 'document', K=10)
84        assert m0.seeds is not None
85        orig_seeds = m0.seeds
86
87        m1 = LdaCgsMulti(c, 'document', K=10, seeds=orig_seeds)
88        assert m0.seeds == m1.seeds
89
90        m0.train(n_iterations=5, verbose=0)
91        m1.train(n_iterations=5, verbose=0)
92        assert m0.seeds == orig_seeds
93        assert m1.seeds == orig_seeds
94
95        # ref:http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.get_state.html
96        for s0, s1 in zip(m0._mtrand_states,m1._mtrand_states):
97            assert s0[0] == 'MT19937'
98            assert s1[0] == 'MT19937'
99            assert (s0[1] == s1[1]).all()
100            assert s0[2:] == s1[2:]
101
102        assert m0.context_type == m1.context_type
103        assert m0.K == m1.K
104        assert (m0.alpha == m1.alpha).all()
105        assert (m0.beta == m1.beta).all()
106        assert m0.log_probs == m1.log_probs
107        for i in xrange(max(len(m0.corpus), len(m1.corpus))):
108            assert m0.corpus[i].all() == m1.corpus[i].all()
109        assert m0.V == m1.V
110        assert m0.iteration == m1.iteration
111        for i in xrange(max(len(m0.Z), len(m1.Z))):
112            assert m0.Z[i].all() == m1.Z[i].all()
113        assert m0.top_doc.all() == m1.top_doc.all()
114        assert m0.word_top.all() == m1.word_top.all()
115        assert m0.inv_top_sums.all() == m1.inv_top_sums.all()
116
117    def test_LdaCgsMulti_continue_training(self):
118        from vsm.corpus.util.corpusbuilders import random_corpus
119
120        c = random_corpus(1000, 50, 0, 20, context_type='document',
121                            metadata=True)
122
123        m0 = LdaCgsMulti(c, 'document', K=10)
124        assert m0.seeds is not None
125        orig_seeds = m0.seeds
126
127        m1 = LdaCgsMulti(c, 'document', K=10, seeds=orig_seeds)
128        assert m0.seeds == m1.seeds
129
130        m0.train(n_iterations=2, verbose=0)
131        m1.train(n_iterations=5, verbose=0)
132        assert m0.seeds == orig_seeds
133        assert m1.seeds == orig_seeds
134        for s0, s1 in zip(m0._mtrand_states,m1._mtrand_states):
135            assert (s0[1] != s1[1]).any()
136            assert s0[2:] != s1[2:]
137
138        m0.train(n_iterations=3, verbose=0)
139        # ref:http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.get_state.html
140        for s0, s1 in zip(m0._mtrand_states,m1._mtrand_states):
141            assert s0[0] == 'MT19937'
142            assert s1[0] == 'MT19937'
143            assert (s0[1] == s1[1]).all()
144            assert s0[2:] == s1[2:]
145
146        assert m0.context_type == m1.context_type
147        assert m0.K == m1.K
148        assert (m0.alpha == m1.alpha).all()
149        assert (m0.beta == m1.beta).all()
150        assert m0.log_probs == m1.log_probs
151        for i in xrange(max(len(m0.corpus), len(m1.corpus))):
152            assert m0.corpus[i].all() == m1.corpus[i].all()
153        assert m0.V == m1.V
154        assert m0.iteration == m1.iteration
155        for i in xrange(max(len(m0.Z), len(m1.Z))):
156            assert m0.Z[i].all() == m1.Z[i].all()
157        assert m0.top_doc.all() == m1.top_doc.all()
158        assert m0.word_top.all() == m1.word_top.all()
159        assert m0.inv_top_sums.all() == m1.inv_top_sums.all()
160       
161
162
163    def test_LdaCgsMulti_remove_Seq_props(self):
164        from vsm.corpus.util.corpusbuilders import random_corpus
165
166        c = random_corpus(1000, 50, 0, 20, context_type='document',
167                            metadata=True)
168
169        m0 = LdaCgsMulti(c, 'document', K=10)
170
171        assert getattr(m0, 'seed', None) is None
172        assert getattr(m0, '_mtrand_state', None) is None
173
174    def test_LdaCgsMulti_eq_LdaCgsSeq(self):
175        from tempfile import NamedTemporaryFile
176        import os
177   
178        c = random_corpus(1000, 50, 6, 100, seed=2)
179        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
180        m0 = LdaCgsMulti(c, 'document', K=10, n_proc=1, seeds=[2])
181        m1 = LdaCgsSeq(c, 'document', K=10, seed=2)
182        for iteration in range(20):
183            m0.train(n_iterations=1, verbose=0)
184            m1.train(n_iterations=1, verbose=0)
185           
186            assert m0.context_type == m1.context_type
187            assert m0.K == m1.K
188            assert (m0.alpha == m1.alpha).all()
189            assert (m0.beta == m1.beta).all()
190            for i in xrange(max(len(m0.corpus), len(m1.corpus))):
191                assert m0.corpus[i].all() == m1.corpus[i].all()
192            assert m0.V == m1.V
193            assert m0.iteration == m1.iteration
194            assert (m0.Z[i] == m1.Z[i]).all()
195            assert (m0.top_doc == m1.top_doc).all()
196            assert (m0.word_top == m1.word_top).all()
197            assert (np.isclose(m0.inv_top_sums, m1.inv_top_sums)).all()
198   
199            assert m0.seeds[0] == m1.seed
200            assert m0._mtrand_states[0][0] == m1._mtrand_state[0]
201            for s0,s1 in zip(m0._mtrand_states[0][1], m1._mtrand_state[1]):
202                assert s0 == s1
203            assert m0._mtrand_states[0][2] == m1._mtrand_state[2]
204            assert m0._mtrand_states[0][3] == m1._mtrand_state[3]
205            assert m0._mtrand_states[0][4] == m1._mtrand_state[4]
206            print iteration, m0.log_probs[-1], m1.log_probs[-1] 
207            for i in range(iteration):
208                assert np.isclose(m0.log_probs[i][1], m1.log_probs[i][1])
209   
210    def test_LdaCgsMulti_eq_LdaCgsSeq_multi(self):
211        from tempfile import NamedTemporaryFile
212        import os
213   
214        c = random_corpus(1000, 50, 6, 100, seed=2)
215        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
216        m0 = LdaCgsMulti(c, 'document', K=10, n_proc=1, seeds=[2])
217        m1 = LdaCgsSeq(c, 'document', K=10, seed=2)
218        for iteration in range(20):
219            m0.train(n_iterations=2, verbose=0)
220            m1.train(n_iterations=2, verbose=0)
221           
222            assert m0.context_type == m1.context_type
223            assert m0.K == m1.K
224            assert (m0.alpha == m1.alpha).all()
225            assert (m0.beta == m1.beta).all()
226            for i in xrange(max(len(m0.corpus), len(m1.corpus))):
227                assert m0.corpus[i].all() == m1.corpus[i].all()
228            assert m0.V == m1.V
229            assert m0.iteration == m1.iteration
230            assert (m0.Z[i] == m1.Z[i]).all()
231            assert (m0.top_doc == m1.top_doc).all()
232            assert (m0.word_top == m1.word_top).all()
233            assert (np.isclose(m0.inv_top_sums, m1.inv_top_sums)).all()
234   
235            assert m0.seeds[0] == m1.seed
236            assert m0._mtrand_states[0][0] == m1._mtrand_state[0]
237            for s0,s1 in zip(m0._mtrand_states[0][1], m1._mtrand_state[1]):
238                assert s0 == s1
239            assert m0._mtrand_states[0][2] == m1._mtrand_state[2]
240            assert m0._mtrand_states[0][3] == m1._mtrand_state[3]
241            assert m0._mtrand_states[0][4] == m1._mtrand_state[4]
242            print iteration, m0.log_probs[-1], m1.log_probs[-1] 
243            for i in range(iteration):
244                assert np.isclose(m0.log_probs[i][1], m1.log_probs[i][1])
245
246
247class TestLdaCgsMulti(unittest.TestCase):
248    def setUp(self):
249        pass
250   
251    def test_demo_LdaCgsMulti(self):
252        t = MPTester()
253        p = Process(target=t.test_demo_LdaCgsMulti, args=())
254        p.start()
255        p.join()
256   
257    def test_LdaCgsMulti_IO(self):
258        t = MPTester()
259        p = Process(target=t.test_LdaCgsMulti_IO, args=())
260        p.start()
261        p.join()
262   
263    def test_LdaCgsMulti_SeedTypes(self):
264        t = MPTester()
265        p = Process(target=t.test_LdaCgsMulti_SeedTypes, args=())
266        p.start()
267        p.join()
268   
269    def test_LdaCgsMulti_random_seeds(self):
270        t = MPTester()
271        p = Process(target=t.test_LdaCgsMulti_random_seeds, args=())
272        p.start()
273        p.join()
274   
275    def test_LdaCgsMulti_remove_Seq_props(self):
276        t = MPTester()
277        p = Process(target=t.test_LdaCgsMulti_remove_Seq_props, args=())
278        p.start()
279        p.join()
280   
281    def test_LdaCgsMulti_continue_training(self):
282        t = MPTester()
283        p = Process(target=t.test_LdaCgsMulti_continue_training, args=())
284        p.start()
285        p.join()
286   
287    def test_LdaCgsMulti_eq_LdaCgsSeq(self):
288        t = MPTester()
289        p = Process(target=t.test_LdaCgsMulti_eq_LdaCgsSeq, args=())
290        p.start()
291        p.join()
292   
293    def test_LdaCgsMulti_eq_LdaCgsSeq_multi(self):
294        t = MPTester()
295        p = Process(target=t.test_LdaCgsMulti_eq_LdaCgsSeq_multi, args=())
296        p.start()
297        p.join()
298   
299
300if __name__ == '__main__':
301    suite = unittest.TestLoader().loadTestsFromTestCase(TestLdaCgsMulti)
302    unittest.TextTestRunner(verbosity=2).run(suite)
Nota: Vea TracBrowser para ayuda de uso del navegador del repositorio.