[0ff122b] | 1 | import unittest2 as unittest |
---|
| 2 | import numpy as np |
---|
| 3 | |
---|
| 4 | from vsm.spatial import * |
---|
| 5 | |
---|
| 6 | #TODO: add tests for recently added methods. |
---|
def KL(p, q):
    """Reference Kullback-Leibler divergence D(p || q), in bits.

    Terms where ``p`` is zero contribute 0 (the standard 0*log(0) = 0
    convention) instead of turning the whole result into NaN. Where
    ``p > 0`` and ``q == 0`` the divergence is +inf.

    Inputs are converted to float so integer arrays (e.g. ``[0, 1]``) are
    never subject to integer division. The terms are summed over axis 0,
    the same reduction the builtin ``sum`` performed previously.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # log2(0) and 0 * -inf would warn; those entries are replaced below.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = p * np.log2(p / q)
    terms = np.where(p == 0, 0.0, terms)
    return np.sum(terms, axis=0)
---|
def partial_KL(p, q):
    """Elementwise terms of KL(p || m) with m = (p + q) / 2, in bits.

    Summing the result gives KL(p || (p + q) / 2), one half of the
    Jensen-Shannon divergence before averaging. Entries where ``p`` is zero
    are 0 (0*log(0) = 0 convention) rather than NaN. Inputs are converted
    to float so integer arrays are never subject to integer division.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # 2*p / (p + q) is 0 (or 0/0) where p == 0; those entries are zeroed below.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = p * np.log2((2 * p) / (p + q))
    return np.where(p == 0, 0.0, terms)
---|
def JS(p, q):
    """Reference Jensen-Shannon divergence of p and q.

    Averages the KL divergence of each distribution from their midpoint
    (p + q) / 2, using the module-level ``KL`` helper.
    """
    midpoint = (p + q) * 0.5
    from_p = KL(p, midpoint)
    from_q = KL(q, midpoint)
    return 0.5 * (from_p + from_q)
---|
def JSD(p, q):
    """Reference Jensen-Shannon distance: the square root of ``JS(p, q)``.

    Delegates to ``JS`` rather than repeating its formula, so the divergence
    and the distance cannot drift apart.
    """
    return JS(p, q) ** 0.5
---|
| 15 | |
---|
| 16 | |
---|
class TestSpatial(unittest.TestCase):
    """Check vsm.spatial's divergence and count helpers.

    The divergence functions are compared against this module's reference
    ``KL``/``JS``/``JSD`` implementations and against hand-computed values.
    """

    def setUp(self):
        # Two random probability vectors of length 5. A seeded generator
        # keeps any failure reproducible from run to run.
        rng = np.random.RandomState(0)
        self.p = rng.random_sample((5,))
        self.q = rng.random_sample((5,))

        # Normalize so each vector sums to 1.
        self.p /= self.p.sum()
        self.q /= self.q.sum()

    def _assert_allclose(self, a, b):
        """Fail unless ``np.allclose(a, b)``, reporting both values."""
        self.assertTrue(np.allclose(a, b), '%r is not close to %r' % (a, b))

    def test_KL_div(self):
        self._assert_allclose(KL_div(self.p, self.q), KL(self.p, self.q))

    def test_JS_div(self):
        self._assert_allclose(JS_div(self.p, self.q), JS(self.p, self.q))

    def test_JS_dist(self):
        self._assert_allclose(JS_dist(self.p, self.q), JSD(self.p, self.q))

    def test_KL_div_old(self):
        p = np.array([0, 1])
        Q = np.array([[0, 1],
                      [.5, .5],
                      [1, 0]])
        # Identical -> 0 bits; half-mass -> 1 bit; zero support in q -> inf.
        out = np.array([0., 1., np.inf])

        self._assert_allclose(out, KL_div(p, Q.T))

    def test_count_matrix(self):
        arr = [1, 2, 4, 2, 1]
        # slice(3, 3) is empty, so its column is expected to be all zeros.
        slices = [slice(0, 1), slice(1, 3), slice(3, 3), slice(3, 5)]
        m = 6
        # Row i, column j: how often value i appears in arr[slices[j]].
        expected = np.array([[0, 0, 0, 0],
                             [1, 0, 0, 1],
                             [0, 1, 0, 1],
                             [0, 0, 0, 0],
                             [0, 1, 0, 0],
                             [0, 0, 0, 0]])

        np.testing.assert_array_equal(
            count_matrix(arr, slices, m).toarray(), expected)
---|
| 62 | |
---|
| 63 | |
---|
| 64 | |
---|
| 65 | |
---|
# Load every test method of TestSpatial and run them with verbose output.
# NOTE(review): these statements run whenever the module is imported (for
# example during test discovery), not only when executed as a script;
# consider guarding the run with `if __name__ == '__main__':` once it is
# confirmed nothing relies on import-time execution.
suite = (unittest.TestLoader()
         .loadTestsFromTestCase(TestSpatial))
unittest.TextTestRunner(
    verbosity=2,
).run(suite)
---|