[0ff122b] | 1 | import unittest2 as unittest |
---|
| 2 | import numpy as np |
---|
| 3 | |
---|
| 4 | from vsm.viewer.labeleddata import * |
---|
| 5 | |
---|
| 6 | |
---|
class TestLabeleddata(unittest.TestCase):
    """Tests for the display containers in ``vsm.viewer.labeleddata``.

    Exercises ``LabeledColumn`` and ``DataTable`` directly, and checks the
    shape of the array returned by ``LdaCgsViewer.dismat_top``.
    """

    # Text type that ``__str__`` of the classes under test is expected to
    # return: ``unicode`` on Python 2, ``str`` on Python 3.
    _text_type = type(u'')

    def setUp(self):
        """Build ``self.v``, a structured array of (word, random value) rows."""
        words = ['row', 'row', 'row', 'your', 'boat', 'gently', 'down', 'the',
                 'stream', 'merrily', 'merrily', 'merrily', 'merrily', 'life',
                 'is', 'but', 'a', 'dream']
        values = [np.random.random() for _ in words]
        # Field dtypes are inferred from the data itself so the string field
        # is wide enough for the longest word.
        fields = [('i', np.array(words).dtype),
                  ('value', np.array(values).dtype)]
        # Materialize the pairs: ``zip`` is lazy on Python 3, and NumPy needs
        # a real sequence of tuples to build a structured array.
        self.v = np.array(list(zip(words, values)), dtype=fields)

    def test_LabeledColumn(self):
        """Check string rendering, widths and defaults of ``LabeledColumn``."""
        arr = self.v.view(LabeledColumn)
        arr.subcol_headers = ['Word', 'Value']
        arr.col_header = 'Song lets make this longer than subcol headers'
        arr.col_len = 10
        # A second, unconfigured view used to check the default attributes.
        arr1 = self.v.view(LabeledColumn)

        self.assertIsInstance(arr.__str__(), self._text_type)
        self.assertLessEqual(sum(arr.subcol_widths), arr.col_width)
        # With no explicit ``col_len``, the full column length is expected.
        self.assertEqual(arr.shape[0], arr1.col_len)
        self.assertFalse(arr1.col_header)
        self.assertFalse(arr1.subcol_headers)

    def test_DataTable(self):
        """Check ``DataTable`` rendering and header in both view modes."""
        template = LabeledColumn(self.v)
        template.subcol_widths = [30, 20]
        template.col_len = 10

        columns = []
        for i in range(5):
            column = template.copy()
            column.col_header = 'Iteration ' + str(i)
            columns.append(column)

        schc = ['Topic', 'Word']
        schf = ['Word', 'Value']
        table = DataTable(columns, 'Song',
                          subcolhdr_compact=schc, subcolhdr_full=schf)

        # The original used ``assertTrue('Song', table.table_header)``, which
        # always passes because the second argument is only the failure
        # message. ``assertEqual`` actually compares the header.
        self.assertIsInstance(table.__str__(), self._text_type)
        self.assertEqual('Song', table.table_header)

        # Repeat the same checks with the compact view switched off.
        table.compact_view = False
        self.assertIsInstance(table.__str__(), self._text_type)
        self.assertEqual('Song', table.table_header)

    def test_IndexedSymmArray(self):
        """Check that ``dismat_top`` returns one row per requested label."""
        # Imported locally so the other tests do not depend on these modules.
        from vsm.corpus.util.corpusbuilders import random_corpus
        from vsm.model.ldacgsseq import LdaCgsSeq
        from vsm.viewer.ldacgsviewer import LdaCgsViewer

        # NOTE(review): the meaning of the ``random_corpus`` arguments is not
        # visible from this file -- confirm against its definition.
        c = random_corpus(50000, 1000, 0, 50)
        m = LdaCgsSeq(c, 'document', K=20)
        viewer = LdaCgsViewer(c, m)

        # Presumably topic labels (the model is built with K=20) -- verify.
        li = ['0', '1', '10']
        isa = viewer.dismat_top(li)

        self.assertEqual(isa.shape[0], len(li))
| 73 | |
---|
| 74 | |
---|
| 75 | |
---|
| 76 | |
---|
# Build the suite for TestLabeleddata and run it with verbose output.
# NOTE: this executes whenever the module is loaded, not only when it is
# run as a script.
_loader = unittest.TestLoader()
suite = _loader.loadTestsFromTestCase(TestLabeleddata)
_runner = unittest.TextTestRunner(verbosity=2)
_runner.run(suite)