[0ff122b] | 1 | import unittest2 as unittest |
---|
| 2 | import numpy as np |
---|
| 3 | |
---|
| 4 | |
---|
| 5 | |
---|
| 6 | class TestBeagleContext(unittest.TestCase): |
---|
| 7 | |
---|
| 8 | def setUp(self): |
---|
| 9 | from vsm.corpus.util.corpusbuilders import random_corpus |
---|
| 10 | from vsm.model.beaglecontext import BeagleContextSeq, BeagleContextMulti |
---|
| 11 | from vsm.model.beagleenvironment import BeagleEnvironment |
---|
| 12 | |
---|
| 13 | self.ec = random_corpus(1000, 50, 0, 5, context_type='sentence') |
---|
| 14 | self.cc = self.ec.apply_stoplist(stoplist=[str(i) for i in xrange(0,50,7)]) |
---|
| 15 | |
---|
| 16 | self.e = BeagleEnvironment(self.ec, n_cols=5) |
---|
| 17 | self.e.train() |
---|
| 18 | |
---|
| 19 | self.ms = BeagleContextSeq(self.cc, self.ec, self.e.matrix) |
---|
| 20 | self.ms.train() |
---|
| 21 | ''' |
---|
| 22 | self.mm = BeagleContextMulti(self.cc, self.ec, self.e.matrix) |
---|
| 23 | self.mm.train(n_procs=2) |
---|
| 24 | ''' |
---|
| 25 | |
---|
| 26 | |
---|
| 27 | def test_BeagleContextSeq(self): |
---|
| 28 | from tempfile import NamedTemporaryFile |
---|
| 29 | import os |
---|
| 30 | |
---|
| 31 | from vsm.model.beaglecontext import BeagleContextSeq |
---|
| 32 | try: |
---|
| 33 | tmp = NamedTemporaryFile(delete=False, suffix='.npz') |
---|
| 34 | self.ms.save(tmp.name) |
---|
| 35 | tmp.close() |
---|
| 36 | m1 = BeagleContextSeq.load(tmp.name) |
---|
| 37 | self.assertTrue((self.ms.matrix == m1.matrix).all()) |
---|
| 38 | |
---|
| 39 | finally: |
---|
| 40 | os.remove(tmp.name) |
---|
| 41 | |
---|
| 42 | |
---|
| 43 | ''' |
---|
| 44 | def test_BeagleContextMulti(self): |
---|
| 45 | from tempfile import NamedTemporaryFile |
---|
| 46 | import os |
---|
| 47 | |
---|
| 48 | from vsm.model.beaglecontext import BeagleContextMulti |
---|
| 49 | try: |
---|
| 50 | tmp = NamedTemporaryFile(delete=False, suffix='.npz') |
---|
| 51 | self.mm.save(tmp.name) |
---|
| 52 | tmp.close() |
---|
| 53 | m1 = BeagleContextMulti.load(tmp.name) |
---|
| 54 | self.assertTrue((self.mm.matrix == m1.matrix).all()) |
---|
| 55 | |
---|
| 56 | finally: |
---|
| 57 | os.remove(tmp.name) |
---|
| 58 | |
---|
| 59 | |
---|
| 60 | |
---|
| 61 | def test_compare(self): |
---|
| 62 | |
---|
| 63 | print 'Training single processor model' |
---|
| 64 | ms = BeagleContextSeq(self.cc, self.ec, self.e.matrix) |
---|
| 65 | ms.train() |
---|
| 66 | |
---|
| 67 | print 'Training multiprocessor model' |
---|
| 68 | mm = BeagleContextMulti(self.cc, self.ec, self.e.matrix) |
---|
| 69 | mm.train() |
---|
| 70 | |
---|
| 71 | self.assertTrue(np.allclose(ms.matrix, mm.matrix)) |
---|
| 72 | ''' |
---|
| 73 | |
---|
# Gather every test method of TestBeagleContext into one suite and run it with
# verbose output.
# NOTE(review): this executes whenever the module is imported, not only when it
# is run as a script. An ``if __name__ == '__main__':`` guard would stop test
# collectors from running the suite a second time; confirm that nothing relies
# on the import-time run before adding one.
_loader = unittest.TestLoader()
suite = _loader.loadTestsFromTestCase(TestBeagleContext)
_runner = unittest.TextTestRunner(verbosity=2)
_runner.run(suite)
---|