83c4fe42aed3350c7b3d1b00962d6ac9caa7a35d
[opencv] / tests / swig_python / lsh_tests.py
1 #!/usr/bin/env python
2
3 # 2009-01-12, Xavier Delacour <xavier.delacour@gmail.com>
4
5 # gdb --cd ~/opencv-lsh/tests/python --args /usr/bin/python lsh_tests.py
6 # set env PYTHONPATH /home/x/opencv-lsh/debug/interfaces/swig/python:/home/x/opencv-lsh/debug/lib
7 # export PYTHONPATH=/home/x/opencv-lsh/debug/interfaces/swig/python:/home/x/opencv-lsh/debug/lib
8
9 import unittest
10 from numpy import *;
11 from numpy.linalg import *;
12 import sys;
13
14 import cvtestutils
15 from cv import *;
16 from adaptors import *;
17
def planted_neighbors(query_points, R = .4):
    """Return one synthetic "planted" neighbour for every query point.

    Row i of the result is query_points[i] displaced by a random offset whose
    length is drawn uniformly from [0, R) and whose direction is uniformly
    distributed on the unit sphere. Each planted point therefore lies within
    distance R of its query point.

    query_points -- 2-D array of shape (n, d), one query per row.
    R            -- upper bound (exclusive) on the offset length.

    Returns a float array with the same shape as query_points.
    """
    n, d = query_points.shape
    # Normalised Gaussian samples give a direction that is uniform on the
    # sphere. Normalising uniform [0, 1) samples would only ever point into
    # the positive orthant.
    directions = random.randn(n, d)
    norms = sqrt((directions ** 2).sum(axis=1))
    # Guard the (practically impossible) all-zero draw against 0/0.
    norms[norms == 0] = 1.0
    radii = random.rand(n) * R
    return query_points + directions * (radii / norms).reshape(n, 1)
26
class lsh_test(unittest.TestCase):
    """Regression tests for the SWIG-wrapped in-memory LSH index.

    The cv* / LSHSize / Ipl2NumPy names are presumably provided by the star
    imports of the ``cv`` and ``adaptors`` modules at the top of this file.
    """

    def test_basic(self):
        """Most queries should return their planted neighbour as the top hit.

        Data row i is planted within distance 0.4 of query row i by
        planted_neighbors(), so index i is the expected answer for query i.
        """
        n = 10000
        d = 64
        query_points = random.rand(n, d) * 2 - 1
        data = planted_neighbors(query_points)

        lsh = cvCreateMemoryLSH(d, n)
        cvLSHAdd(lsh, data)
        # NOTE(review): 1 is presumably k (neighbours per query) and 100 a
        # probe/candidate budget -- confirm against cvLSHQuery.
        indices, _dist = cvLSHQuery(lsh, query_points, 1, 100)
        # NOTE(review): test_sensitivity converts the query result with
        # Ipl2NumPy before use, while this test iterates it directly.
        # Confirm both forms are supported.
        correct = sum([i == j for j, i in enumerate(indices)])
        self.assertTrue(correct >= n * .75,
                        "only %s of %d queries found their planted neighbour"
                        % (correct, n))

    def test_sensitivity(self):
        """LSH candidates should be concentrated among the nearer points.

        For each of the first `trials` queries, every candidate returned by
        the index is located in the brute-force distance ordering of all data
        points. The rank positions of those candidates are histogrammed into
        4 bins (spanning their min..max position). A trial counts as good
        when the bin counts strictly decrease.
        """
        n = 10000
        d = 64
        query_points = random.rand(n, d)
        data = random.rand(n, d)

        # NOTE(review): the meaning of the 1000, 10, 10 arguments is not
        # visible here -- confirm against cvCreateMemoryLSH.
        lsh = cvCreateMemoryLSH(d, 1000, 10, 10)
        cvLSHAdd(lsh, data)

        good = 0
        trials = 20
        sys.stdout.write("\n")
        for x in query_points[0:trials]:
            x1 = asmatrix(x)  # PyArray_to_CvArr doesn't like 1-dim arrays
            indices, _dist = cvLSHQuery(lsh, x1, n, n)
            indices = Ipl2NumPy(indices)
            # Keep only non-negative entries. Negatives presumably mark empty
            # result slots.
            indices = unique(indices[where(indices >= 0)])
            if len(indices) == 0:
                # No candidates at all: count the trial as failed instead of
                # letting vstack() raise on an empty list.
                continue

            # Distance from x to every data point, computed in one pass and
            # shared by the brute-force and LSH rows so their ties agree.
            dists = sqrt(((data - x) ** 2).sum(axis=1))
            # Rows are (distance, data index, 0 = brute force / 1 = LSH hit).
            brute = column_stack((dists, arange(len(data)), zeros(len(data))))
            lshp = vstack([(dists[i], i, 1) for i in indices])
            combined = vstack((brute, lshp))
            combined = combined[argsort(combined[:, 0])]

            # Positions of the LSH hits within the merged distance ordering.
            spread = [i for i, a in enumerate(combined[:, 2]) if a == 1]
            # NOTE(review): the `new=` keyword was accepted by the NumPy of
            # the time; later NumPy releases removed it.
            spread = histogram(spread, bins=4, new=True)[0]
            decreases = sum(diff(spread) < 0)
            sys.stdout.write("%s %s\n" % (spread, decreases))
            if decreases == 3:
                good = good + 1
        sys.stdout.write("%s pass\n" % good)
        self.assertTrue(good > trials * .75,
                        "only %s of %d trials passed" % (good, trials))

    def test_remove(self):
        """Removing the first half of the inserted points halves the size."""
        n = 10000
        d = 64
        query_points = random.rand(n, d) * 2 - 1
        data = planted_neighbors(query_points)
        lsh = cvCreateMemoryLSH(d, n)
        indices = cvLSHAdd(lsh, data)
        self.assertEqual(LSHSize(lsh), n)
        # Floor division keeps the slice bound an int under Python 3 and
        # under `from __future__ import division`.
        cvLSHRemove(lsh, indices[0:n // 2])
        self.assertEqual(LSHSize(lsh), n // 2)

    def test_destroy(self):
        """Smoke test: create an empty index and drop it."""
        n = 10000
        d = 64
        # Presumably exercises the wrapper's destructor when `lsh` goes out
        # of scope at method exit.
        lsh = cvCreateMemoryLSH(d, n)

    def test_destroy2(self):
        """Smoke test: populate an index, then drop it."""
        n = 10000
        d = 64
        query_points = random.rand(n, d) * 2 - 1
        data = planted_neighbors(query_points)
        lsh = cvCreateMemoryLSH(d, n)
        indices = cvLSHAdd(lsh, data)
95
96 # move this to another file
97
98 # img1 = cvLoadImage(img1_fn);
99 # img2 = cvLoadImage(img2_fn);
100 # pts1,desc1 = cvExtractSURF(img1); # * make util routine to extract points and descriptors
101 # pts2,desc2 = cvExtractSURF(img2);
102 # lsh = cvCreateMemoryLSH(d, n);
103 # cvLSHAdd(lsh, desc1);
104 # indices,dist = cvLSHQuery(lsh, desc2, 2, 100);
105 # matches = [((pts1[x[0]].pt.x,pts1[x[0]].pt.y),(pts2[j].pt.x,pts2[j].pt.y)) \
106 #            for j,x in enumerate(hstack((indices,dist))) \
107 #            if x[2] and (not x[3] or x[2]/x[3]>.6)]
108 # out = cvCloneImage(img1);
109 # for p1,p2 in matches:
110 #     cvCircle(out,p1,3,CV_RGB(255,0,0));
111 #     cvLine(out,p1,p2,CV_RGB(100,100,100));
112 # cvNamedWindow("matches");
113 # cvShowImage("matches",out);
114 # cvWaitKey(0);
115
116         
def suite():
    """Return a TestSuite holding every ``test*`` method of lsh_test."""
    loader = unittest.TestLoader()
    tests = loader.loadTestsFromTestCase(lsh_test)
    return tests
119
if __name__ == '__main__':
    # Run the suite verbosely and report the outcome through the process
    # exit status, so shell scripts and CI can detect failing tests.
    result = unittest.TextTestRunner(verbosity=2).run(suite())
    sys.exit(0 if result.wasSuccessful() else 1)
122