source: consulta_publica/vsm/unit_tests/tests_ldacgsseq.py @ 7095598

baseconstituyenteestudiantesgeneralplan_patriasala
Last change on this file since 7095598 was 0ff122b, checked in by rudmanmrrod <rudman22@…>, 7 años ago

Agregado módulo de gestión de perfiles de procesamiento, incorporado el módulo de visualización de modelado de tópicos

  • Propiedad mode establecida a 100644
File size: 8.8 KB
Línea 
1import unittest2 as unittest
2import numpy as np
3
4from vsm.corpus import Corpus
5from vsm.corpus.util.corpusbuilders import random_corpus
6from vsm.model.ldacgsseq import *
7
8
9class TestLdaCgsSeq(unittest.TestCase):
10
11    def setUp(self):
12        pass
13
14    ##TODO: write actual test cases.
15
16    def test_LdaCgsSeq_IO(self):
17
18        from tempfile import NamedTemporaryFile
19        import os
20   
21        c = random_corpus(1000, 50, 6, 100)
22        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
23        try:
24            m0 = LdaCgsSeq(c, 'document', K=10)
25            m0.train(n_iterations=20)
26            m0.save(tmp.name)
27            m1 = LdaCgsSeq.load(tmp.name)
28            self.assertTrue(m0.context_type == m1.context_type)
29            self.assertTrue(m0.K == m1.K)
30            self.assertTrue((m0.alpha == m1.alpha).all())
31            self.assertTrue((m0.beta == m1.beta).all())
32            self.assertTrue(m0.log_probs == m1.log_probs)
33            for i in xrange(max(len(m0.corpus), len(m1.corpus))):
34                self.assertTrue(m0.corpus[i].all() == m1.corpus[i].all())
35            self.assertTrue(m0.V == m1.V)
36            self.assertTrue(m0.iteration == m1.iteration)
37            for i in xrange(max(len(m0.Z), len(m1.Z))):
38                self.assertTrue(m0.Z[i].all() == m1.Z[i].all())
39            self.assertTrue(m0.top_doc.all() == m1.top_doc.all())
40            self.assertTrue(m0.word_top.all() == m1.word_top.all())
41            self.assertTrue(m0.inv_top_sums.all() == m1.inv_top_sums.all())
42
43            self.assertTrue(m0.seed == m1.seed)
44            self.assertTrue(m0._mtrand_state[0] == m1._mtrand_state[0])
45            self.assertTrue((m0._mtrand_state[1] == m1._mtrand_state[1]).all())
46            self.assertTrue(m0._mtrand_state[2:] == m1._mtrand_state[2:])
47           
48
49            m0 = LdaCgsSeq(c, 'document', K=10)
50            m0.train(n_iterations=20)
51            m0.save(tmp.name)
52            m1 = LdaCgsSeq.load(tmp.name)
53            self.assertTrue(not hasattr(m1, 'log_prob'))
54        finally:
55            os.remove(tmp.name)
56   
57    def test_LdaCgsSeq_SeedTypes(self):
58        """ Test for issue #74 issues. """
59
60        from tempfile import NamedTemporaryFile
61        import os
62   
63        c = random_corpus(1000, 50, 6, 100)
64        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
65        try:
66            m0 = LdaCgsSeq(c, 'document', K=10)
67            m0.train(n_iterations=20)
68            m0.save(tmp.name)
69            m1 = LdaCgsSeq.load(tmp.name)
70
71            self.assertTrue(type(m0.seed) == type(m1.seed))
72            self.assertTrue(type(m0._mtrand_state[0]) == type(m1._mtrand_state[0]))
73            self.assertTrue(type(m0._mtrand_state[1]) == type(m1._mtrand_state[1]))
74            self.assertTrue(type(m0._mtrand_state[2]) == type(m1._mtrand_state[2]))
75            self.assertTrue(type(m0._mtrand_state[3]) == type(m1._mtrand_state[3]))
76            self.assertTrue(type(m0._mtrand_state[4]) == type(m1._mtrand_state[4]))
77        finally:
78            os.remove(tmp.name)
79
80
81    def test_LdaCgsQuerySampler_init(self):
82
83        old_corp = Corpus([], remove_empty=False)
84        old_corp.corpus = np.array([ 0, 1, 1, 0, 0, 1 ], dtype='i')
85        old_corp.context_data = [ np.array([(3, ), (3, )], dtype=[('idx', 'i')]) ]
86        old_corp.context_types = [ 'document' ]
87        old_corp.words = np.array([ '0', '1' ], dtype='i')
88        old_corp.words_int = { '0': 0, '1': 1 }
89
90        new_corp = Corpus([], remove_empty=False)
91        new_corp.corpus = np.array([ 0, 0 ], dtype='i')
92        new_corp.context_data = [ np.array([(2, )], dtype=[('idx', 'i')]) ]
93        new_corp.context_types = [ 'document' ]
94        new_corp.words = np.array([ '0', '1' ], dtype='i')
95        new_corp.words_int = { '0': 0, '1': 1 }
96
97        m = LdaCgsSeq(corpus=old_corp, context_type='document', K=2, V=2)
98        m.Z[:] = np.array([0, 0, 0, 1, 1, 1], dtype='i')
99        m.word_top[:] = np.array([[ 1.01, 2.01 ],
100                                  [ 2.01, 1.01 ]], dtype='d')
101        m.top_doc[:] = np.array([[ 3.01, 0.01 ], 
102                                 [ 0.01, 3.01 ]], dtype='d')
103        m.inv_top_sums[:] = 1. / m.word_top.sum(0)
104
105        q = LdaCgsQuerySampler(m, new_corpus=new_corp, old_corpus=old_corp)
106        self.assertTrue(q.V==2)
107        self.assertTrue(q.K==2)
108        self.assertTrue(len(q.corpus)==2)
109        self.assertTrue((q.corpus==new_corp.corpus).all())
110        self.assertTrue(len(q.indices)==1)
111        self.assertTrue((q.indices==
112                         new_corp.view_metadata('document')['idx']).all())
113        self.assertTrue(q.word_top.shape==(2, 2))
114        self.assertTrue((q.word_top==m.word_top).all())
115        self.assertTrue(q.top_doc.shape==(2, 1))
116        self.assertTrue((q.top_doc==[[ 0.01 ],
117                                     [ 0.01 ]]).all())
118        self.assertTrue(q.inv_top_sums.shape==(2, ))
119        self.assertTrue((q.inv_top_sums==m.inv_top_sums).all())
120        self.assertTrue(q.alpha.shape==(2, 1))
121        self.assertTrue((q.alpha==m.alpha).all())
122        self.assertTrue(q.beta.shape==(2, 1))
123        self.assertTrue((q.beta==m.beta).all())
124
125    def test_randomSeed(self):
126        from vsm.corpus.util.corpusbuilders import random_corpus
127        from vsm.model.ldacgsseq import LdaCgsSeq
128
129        c = random_corpus(1000, 50, 0, 20, context_type='document',
130                            metadata=True)
131
132        m0 = LdaCgsSeq(c, 'document', K=10)
133        assert m0.seed is not None
134        orig_seed = m0.seed
135
136        m1 = LdaCgsSeq(c, 'document', K=10, seed=orig_seed)
137        assert m0.seed == m1.seed
138
139        m0.train(n_iterations=50, verbose=0)
140        m1.train(n_iterations=50, verbose=0)
141        assert m0.seed == orig_seed
142        assert m1.seed == orig_seed
143       
144        # ref:http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.get_state.html
145        assert m0._mtrand_state[0] == 'MT19937'
146        assert m1._mtrand_state[0] == 'MT19937'
147        assert (m0._mtrand_state[1] == m1._mtrand_state[1]).all()
148        assert m0._mtrand_state[2:] == m1._mtrand_state[2:]
149
150        self.assertTrue(m0.context_type == m1.context_type)
151        self.assertTrue(m0.K == m1.K)
152        self.assertTrue((m0.alpha == m1.alpha).all())
153        self.assertTrue((m0.beta == m1.beta).all())
154        self.assertTrue(m0.log_probs == m1.log_probs)
155        for i in xrange(max(len(m0.corpus), len(m1.corpus))):
156            self.assertTrue(m0.corpus[i].all() == m1.corpus[i].all())
157        self.assertTrue(m0.V == m1.V)
158        self.assertTrue(m0.iteration == m1.iteration)
159        for i in xrange(max(len(m0.Z), len(m1.Z))):
160            self.assertTrue(m0.Z[i].all() == m1.Z[i].all())
161        self.assertTrue(m0.top_doc.all() == m1.top_doc.all())
162        self.assertTrue(m0.word_top.all() == m1.word_top.all())
163        self.assertTrue(m0.inv_top_sums.all() == m1.inv_top_sums.all())
164   
165    def test_continueTraining(self):
166        from vsm.corpus.util.corpusbuilders import random_corpus
167        from vsm.model.ldacgsseq import LdaCgsSeq
168
169        c = random_corpus(1000, 50, 0, 20, context_type='document',
170                            metadata=True)
171
172        m0 = LdaCgsSeq(c, 'document', K=10)
173        assert m0.seed is not None
174        orig_seed = m0.seed
175
176        m1 = LdaCgsSeq(c, 'document', K=10, seed=orig_seed)
177        assert m0.seed == m1.seed
178
179        m0.train(n_iterations=2, verbose=0)
180        m1.train(n_iterations=5, verbose=0)
181        assert m0.seed == orig_seed
182        assert m1.seed == orig_seed
183        assert (m0._mtrand_state[1] != m1._mtrand_state[1]).any()
184        assert m0._mtrand_state[2:] != m1._mtrand_state[2:]
185
186        m0.train(n_iterations=3, verbose=0)
187       
188        # ref:http://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.get_state.html
189        assert m0._mtrand_state[0] == 'MT19937'
190        assert m1._mtrand_state[0] == 'MT19937'
191        assert (m0._mtrand_state[1] == m1._mtrand_state[1]).all()
192        assert m0._mtrand_state[2:] == m1._mtrand_state[2:]
193
194        self.assertTrue(m0.context_type == m1.context_type)
195        self.assertTrue(m0.K == m1.K)
196        self.assertTrue((m0.alpha == m1.alpha).all())
197        self.assertTrue((m0.beta == m1.beta).all())
198        self.assertTrue(m0.log_probs == m1.log_probs)
199        for i in xrange(max(len(m0.corpus), len(m1.corpus))):
200            self.assertTrue(m0.corpus[i].all() == m1.corpus[i].all())
201        self.assertTrue(m0.V == m1.V)
202        self.assertTrue(m0.iteration == m1.iteration)
203        for i in xrange(max(len(m0.Z), len(m1.Z))):
204            self.assertTrue(m0.Z[i].all() == m1.Z[i].all())
205        self.assertTrue(m0.top_doc.all() == m1.top_doc.all())
206        self.assertTrue(m0.word_top.all() == m1.word_top.all())
207        self.assertTrue(m0.inv_top_sums.all() == m1.inv_top_sums.all())
208
209if __name__ == '__main__':
210    suite = unittest.TestLoader().loadTestsFromTestCase(TestLdaCgsSeq)
211    unittest.TextTestRunner(verbosity=2).run(suite)
Nota: Vea TracBrowser para ayuda de uso del navegador del repositorio.