1 | import unittest2 as unittest |
---|
2 | import numpy as np |
---|
3 | |
---|
4 | from vsm.spatial import * |
---|
5 | |
---|
6 | #TODO: add tests for recently added methods. |
---|
def KL(p, q):
    """Reference Kullback-Leibler divergence D(p || q), in bits.

    Computes ``sum_i p_i * log2(p_i / q_i)`` along axis 0. Terms with
    ``p_i == 0`` contribute 0 (the usual 0 * log 0 = 0 convention). The raw
    formula would otherwise yield nan for them. A term with ``p_i > 0`` and
    ``q_i == 0`` yields inf.

    Parameters
    ----------
    p, q : array_like
        Distributions to compare; they must broadcast against each other.

    Returns
    -------
    The divergence, summed over axis 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Division/log of zeros is expected here; the zero-p terms are fixed below.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = p * np.log2(p / q)
    terms = np.where(p == 0, 0.0, terms)
    return np.sum(terms, axis=0)
---|
def partial_KL(p, q):
    """Elementwise terms ``p_i * log2(2 * p_i / (p_i + q_i))``.

    These are the per-component contributions of ``p`` to KL(p || (p+q)/2).
    Entries where ``p_i == 0`` are 0 (the 0 * log 0 = 0 convention) instead
    of nan.

    Parameters
    ----------
    p, q : array_like
        Distributions; they must broadcast against each other.

    Returns
    -------
    numpy.ndarray of the elementwise terms (not summed).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # 0/0 and log2(0) are expected for zero entries; they are replaced below.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = p * np.log2((2 * p) / (p + q))
    return np.where(p == 0, 0.0, terms)
---|
def JS(p, q):
    """Reference Jensen-Shannon divergence between ``p`` and ``q``.

    This is the mean of KL(p || m) and KL(q || m), where m is the midpoint
    distribution (p + q) / 2.
    """
    midpoint = (p + q) * 0.5
    total = KL(p, midpoint) + KL(q, midpoint)
    return total * 0.5
---|
def JSD(p, q):
    """Reference Jensen-Shannon distance: the square root of ``JS(p, q)``.

    Reuses ``JS`` rather than duplicating its formula, so the two reference
    helpers cannot drift apart.
    """
    return JS(p, q) ** 0.5
---|
15 | |
---|
16 | |
---|
class TestSpatial(unittest.TestCase):
    """Tests for the divergence and count-matrix helpers of vsm.spatial."""

    def setUp(self):
        # Two random probability distributions over 5 outcomes. A fixed seed
        # on a private RandomState keeps any failure reproducible without
        # touching NumPy's global random state.
        rng = np.random.RandomState(0)
        self.p = rng.random_sample((5,))
        self.q = rng.random_sample((5,))

        # Normalize so each vector sums to 1.
        self.p /= self.p.sum()
        self.q /= self.q.sum()

    def test_KL_div(self):
        # Same tolerances as np.allclose; assert_allclose reports mismatches.
        np.testing.assert_allclose(KL_div(self.p, self.q),
                                   KL(self.p, self.q),
                                   rtol=1e-05, atol=1e-08, equal_nan=False)

    def test_JS_div(self):
        np.testing.assert_allclose(JS_div(self.p, self.q),
                                   JS(self.p, self.q),
                                   rtol=1e-05, atol=1e-08, equal_nan=False)

    def test_JS_dist(self):
        np.testing.assert_allclose(JS_dist(self.p, self.q),
                                   JSD(self.p, self.q),
                                   rtol=1e-05, atol=1e-08, equal_nan=False)

    def test_KL_div_old(self):
        # p is compared against each column of Q.T (i.e. each row of Q).
        # Expected divergences: identical -> 0, uniform -> 1 bit,
        # disjoint support -> inf.
        p = np.array([0, 1])
        Q = np.array([[0, 1],
                      [.5, .5],
                      [1, 0]])
        out = np.array([0., 1., np.inf])

        np.testing.assert_allclose(KL_div(p, Q.T), out,
                                   rtol=1e-05, atol=1e-08, equal_nan=False)

    def test_count_matrix(self):
        # Per the expected matrix below, column j counts the values found in
        # arr[slices[j]], with one row per value in range(m).
        arr = [1, 2, 4, 2, 1]
        slices = [slice(0, 1), slice(1, 3), slice(3, 3), slice(3, 5)]
        m = 6
        # A dense expected array avoids depending on `coo_matrix` being
        # re-exported through the wildcard import of vsm.spatial.
        expected = np.array([[0, 0, 0, 0],
                             [1, 0, 0, 1],
                             [0, 1, 0, 1],
                             [0, 0, 0, 0],
                             [0, 1, 0, 0],
                             [0, 0, 0, 0]])

        np.testing.assert_array_equal(
            count_matrix(arr, slices, m).toarray(), expected)
---|
62 | |
---|
63 | |
---|
64 | |
---|
65 | |
---|
# Build and run the TestSpatial suite with verbose output.
# NOTE(review): this executes as a side effect of importing the module, so
# test discovery tools that import it will run these tests an extra time.
# An `if __name__ == "__main__":` guard would avoid that; confirm no caller
# relies on import-time execution before adding one.
loader = unittest.TestLoader()
suite = loader.loadTestsFromTestCase(TestSpatial)
unittest.TextTestRunner(verbosity=2).run(suite)
---|