1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
13 // Copyright (C) 2008, Xavier Delacour, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
42 // 2008-05-13, Xavier Delacour <xavier.delacour@gmail.com>
44 #ifndef __cv_kdtree_h__
45 #define __cv_kdtree_h__
// MSVC warning C4512 is presumably triggered by median_pr below: its
// const-reference member (pivot) prevents an implicit copy-assignment operator.
57 #pragma warning(disable: 4512) // suppress "assignment operator could not be generated"
60 // J.S. Beis and D.G. Lowe. Shape indexing using approximate nearest-neighbor search
61 // in high-dimensional spaces. In Proc. IEEE Conf. Comp. Vision Patt. Recog.,
62 // pages 1000--1006, 1997. http://citeseer.ist.psu.edu/beis97shape.html
// CvKDTree: kd-tree (k = point_dim) over values of type __valuetype, searched
// with the best-bin-first scheme of the Beis & Lowe paper cited above. Point
// coordinates are read on demand through the __deref functor, which must
// define scalar_type and accum_type and be callable as deref(value, dim).
// Nodes are kept in a std::vector and refer to each other by index, with -1
// meaning "no node".
// NOTE(review): this copy of the header is missing lines (the class head,
// access specifiers, several statements and closing braces); compare with the
// upstream OpenCV source before changing any code here.
66 template < class __valuetype, class __deref >
69 typedef __deref deref_type;
// scalar_type is what deref(value, dim) yields; accum_type is used for
// variances and distances.
70 typedef typename __deref::scalar_type scalar_type;
71 typedef typename __deref::accum_type accum_type;
75 int dim; // split dimension; >=0 for nodes, -1 for leaves
76 __valuetype value; // if leaf, value of leaf
// for leaves, right links to the next leaf of the same bin (-1 ends the bin)
77 int left, right; // node indices of left and right branches
78 scalar_type boundary; // left if deref(value,dim)<=boundary, otherwise right
80 typedef std::vector < node > node_array;
82 __deref deref; // requires operator() (__valuetype lhs,int dim)
84 node_array nodes; // node storage
85 int point_dim; // dimension of points (the k in kd-tree)
86 int root_node; // index of root node, -1 if empty tree
88 // for given set of point indices, compute dimension of highest variance
// The range [first,last) must be non-empty (asserted); ctor maps each element
// to a __valuetype from which deref reads coordinate j.
89 template < class __instype, class __valuector >
90 int dimension_of_highest_variance(__instype * first, __instype * last,
92 assert(last - first > 0);
// start below any possible variance so the first dimension is always taken
94 accum_type maxvar = -std::numeric_limits < accum_type >::max();
96 for (int j = 0; j < point_dim; ++j) {
// pass 1: mean of coordinate j over the range
98 for (__instype * k = first; k < last; ++k)
99 mean += deref(ctor(*k), j);
100 mean /= last - first;
// pass 2: deviations of coordinate j from that mean
102 for (__instype * k = first; k < last; ++k) {
103 accum_type diff = accum_type(deref(ctor(*k), j)) - mean;
// the first dimension examined (maxj still -1) must beat the initial maxvar
108 assert(maxj != -1 || var >= maxvar);
119 // given point indices and dimension, find index of median; (almost) modifies [first,last)
120 // such that points_in[first,median]<=point[median], points_in(median,last)>point[median].
121 // implemented as partial quicksort; expected linear perf.
// "The median" is the element that ends up at position first + (last-first)/2
// (the upper median for even-sized ranges); the 5-argument overload does the
// reordering, in place. The range must be non-empty (asserted). Presumably
// returns that position k (the return statement is not present here).
122 template < class __instype, class __valuector >
123 __instype * median_partition(__instype * first, __instype * last,
124 int dim, __valuector ctor) {
125 assert(last - first > 0);
126 __instype *k = first + (last - first) / 2;
127 median_partition(first, last, k, dim, ctor);
// Partition predicate for median_partition: true when lhs's coordinate along
// dim is <= the pivot's coordinate. The pivot is held by const reference, so
// the referenced element must not move while the predicate is in use
// (median_partition keeps it at last[-1], outside the partitioned range).
// Because of the reference member, no copy-assignment operator can be
// generated for this type.
131 template < class __instype, class __valuector >
133 const __instype & pivot;
137 median_pr(const __instype & _pivot, int _dim, __deref _deref, __valuector _ctor)
138 : pivot(_pivot), dim(_dim), deref(_deref), ctor(_ctor) {
140 bool operator() (const __instype & lhs) const {
141 return deref(ctor(lhs), dim) <= deref(ctor(pivot), dim);
// Quickselect on coordinate dim: reorders [first,last) in place so that *k
// ends up with elements whose coordinate is <= its own before it and > after
// it (per median_pr).
// Each step uses the middle element as pivot: it is parked at last[-1], the
// rest is partitioned with median_pr, the pivot is swapped into its final
// slot `middle`, and only the side that still contains k is processed.
// NOTE(review): because the predicate is <=, all elements equal to the pivot
// land on the left. When many elements share the k-th coordinate, each step
// may shrink the range by a single element, which gives quadratic time and
// recursion depth linear in the number of duplicates (a stack-depth risk for
// large inputs with repeated coordinates).
145 template < class __instype, class __valuector >
146 void median_partition(__instype * first, __instype * last,
147 __instype * k, int dim, __valuector ctor) {
148 int pivot = (int)((last - first) / 2);
150 std::swap(first[pivot], last[-1]);
151 __instype *middle = std::partition(first, last - 1,
152 median_pr < __instype, __valuector >
153 (last[-1], dim, deref, ctor));
154 std::swap(*middle, last[-1]);
// continue in the right part (presumably when middle < k) ...
157 median_partition(middle + 1, last, k, dim, ctor);
// ... or in the left part (presumably when middle > k)
159 median_partition(first, middle, k, dim, ctor);
162 // insert given points into the tree; return created node
// Builds a subtree over [first,last) and returns the index of its root node.
// ctor converts each input element into the __valuetype stored in the tree,
// and [first,last) is reordered in place by median_partition.
// A subtree is either:
//  - an internal node that splits on the dimension of highest variance at
//    the median's coordinate, or
//  - when every element from the median to the end shares the median's
//    coordinate (split reaches last), a leaf "bin" with one leaf node per
//    point, created back to front (presumably chained through `right`).
//    This also covers every range whose median is its last element (e.g.
//    two points), so a bin may hold points with different coordinates.
163 template < class __instype, class __valuector >
164 int insert(__instype * first, __instype * last, __valuector ctor) {
169 int dim = dimension_of_highest_variance(first, last, ctor);
170 __instype *median = median_partition(first, last, dim, ctor);
// advance split past every element whose coordinate equals the median's, so
// all points equal to the boundary go left, matching the "<= boundary" rule
172 __instype *split = median;
173 for (; split != last && deref(ctor(*split), dim) ==
174 deref(ctor(*median), dim); ++split);
176 if (split == last) { // leaf
// NOTE(review): this loop ends by decrementing split to first - 1, which forms
// a pointer before the start of the array. That is undefined behaviour in C++
// even though it is never dereferenced. A counted loop, or
// `while (split != first) { --split; ... }`, would avoid it.
178 for (--split; split >= first; --split) {
179 int i = (int)nodes.size();
180 node & n = *nodes.insert(nodes.end(), node());
182 n.value = ctor(*split);
190 int i = (int)nodes.size();
191 // note that recursive insert may invalidate this ref
192 node & n = *nodes.insert(nodes.end(), node());
195 n.boundary = deref(ctor(*median), dim);
// after each recursive call, write through nodes[i] (not n): the vector may
// have reallocated
197 int left = insert(first, split, ctor);
198 nodes[i].left = left;
199 int right = insert(split, last, ctor);
200 nodes[i].right = right;
207 // run to leaf; linear search for p;
208 // if found, remove paths to empty leaves on unwind
// *i is the index slot (root_node, or a parent's left/right) that refers to
// the current subtree. Unlinking presumably works by overwriting such slots,
// which is why pointers to fields inside `nodes` are passed down. Nothing
// visible on this path inserts into `nodes` (an insert could reallocate it
// and invalidate those pointers).
// No erase from `nodes` is visible, so removed nodes are presumably only
// unlinked and their storage is not reclaimed.
209 bool remove(int *i, const __valuetype & p) {
212 node & n = nodes[*i];
215 if (n.dim >= 0) { // node
// descend on the same side that insert() would have placed p
216 if (deref(p, n.dim) <= n.boundary) // left
217 r = remove(&n.left, p);
219 r = remove(&n.right, p);
221 // if terminal, remove this node
222 if (n.left == -1 && n.right == -1)
// leaf bin: p not matched at this leaf, so continue along the bin's chain
231 return remove(&n.right, p);
// Value "constructor" used when the input array already holds __valuetype
// (see the set_data/constructor overloads taking __valuetype*). It takes and
// returns const references; presumably it returns rhs unchanged (the body is
// not present here).
236 struct identity_ctor {
237 const __valuetype & operator() (const __valuetype & rhs) const {
242 // initialize an empty tree
// NOTE(review): point_dim is not in this initializer list; it is only assigned
// in set_data(). Confirm nothing reads it (e.g. distance()) before then, or
// initialize it to 0 here.
243 CvKDTree(__deref _deref = __deref())
244 : deref(_deref), root_node(-1) {
246 // given points, initialize a balanced tree
// Builds the tree from an array of __valuetype with point_dim coordinates
// each. The tree stores copies of the values, but building reorders
// [first,last) in place.
247 CvKDTree(__valuetype * first, __valuetype * last, int _point_dim,
248 __deref _deref = __deref())
250 set_data(first, last, _point_dim, identity_ctor());
252 // given points, initialize a balanced tree
// Generic form: input elements are of type __instype, and ctor converts each
// one into the __valuetype stored in the tree. Building reorders
// [first,last) in place.
253 template < class __instype, class __valuector >
254 CvKDTree(__instype * first, __instype * last, int _point_dim,
255 __valuector ctor, __deref _deref = __deref())
257 set_data(first, last, _point_dim, ctor);
// Replaces the coordinate accessor (body not present here).
// NOTE(review): existing split boundaries were computed with the accessor in
// use at build time, so this is presumably only safe before set_data() or
// with an equivalent accessor.
260 void set_deref(__deref _deref) {
// Builds the tree from an array of __valuetype (reordered in place), using
// identity_ctor as the value converter.
264 void set_data(__valuetype * first, __valuetype * last, int _point_dim) {
265 set_data(first, last, _point_dim, identity_ctor());
// Builds the tree from [first,last), converting each element with ctor. The
// input is reordered in place. Records point_dim for the per-coordinate
// loops used during building and searching.
// NOTE(review): reserve(last - first) is only a lower bound. insert() creates
// one leaf per point plus the internal nodes, so `nodes` can still reallocate.
267 template < class __instype, class __valuector >
268 void set_data(__instype * first, __instype * last, int _point_dim,
270 point_dim = _point_dim;
272 nodes.reserve(last - first);
273 root_node = insert(first, last, ctor);
280 // remove the given point
// Passes root_node's address so the recursive helper can relink or clear the
// root itself, and returns that helper's result.
281 bool remove(const __valuetype & p) {
282 return remove(&root_node, p);
// Debug helper: writes node i to std::cout and recurses into its children
// with indent + 3. Internal nodes print their links, split dimension and
// boundary; leaves print their value, so __valuetype needs an operator<< if
// this member is ever instantiated.
288 void print(int i, int indent = 0) const {
291 for (int j = 0; j < indent; ++j)
293 const node & n = nodes[i];
295 std::cout << "node " << i << ", left " << nodes[i].left << ", right " <<
296 nodes[i].right << ", dim " << nodes[i].dim << ", boundary " <<
297 nodes[i].boundary << std::endl;
298 print(n.left, indent + 3);
299 print(n.right, indent + 3);
301 std::cout << "leaf " << i << ", value = " << nodes[i].value << std::endl;
304 ////////////////////////////////////////////////////////////////////////////////////////
// Best-bin-first (BBF) approximate nearest-neighbour search.
307 struct bbf_nn { // info on found neighbors (approx k nearest)
308 const __valuetype *p; // nearest neighbor
309 accum_type dist; // distance from *p to query point d
// p points at a value stored inside a tree node, so it is only valid while
// the tree's node storage is not modified (e.g. by set_data()).
310 bbf_nn(const __valuetype & _p, accum_type _dist)
311 : p(&_p), dist(_dist) {
// ordered by distance, so the std heap algorithms keep the farthest
// candidate at the front of a bbf_nn_pqueue
313 bool operator<(const bbf_nn & rhs) const {
314 return dist < rhs.dist;
317 typedef std::vector < bbf_nn > bbf_nn_pqueue;
319 struct bbf_node { // info on branches not taken
320 int node; // corresponding node
// as queued by bbf_branch, dist is only the gap along the split dimension to
// the boundary, which is a lower bound on the true distance to that branch
321 accum_type dist; // minimum distance from bounds to query point
322 bbf_node(int _node, accum_type _dist)
323 : node(_node), dist(_dist) {
// inverted comparison: with the std heap algorithms this puts the branch with
// the *smallest* dist at the front, i.e. a min-priority queue
325 bool operator<(const bbf_node & rhs) const {
326 return dist > rhs.dist;
329 typedef std::vector < bbf_node > bbf_pqueue;
// Scratch queue reused by find_nn_bbf(). That method is const but mutates
// this member, so concurrent searches on the same tree object race.
330 mutable bbf_pqueue tmp_pq;
332 // called for branches not taken, as bbf walks to leaf;
333 // construct bbf_node given minimum distance to bounds of alternate branch
// NOTE(review): callers pass a difference computed in the common type of
// scalar_type and the query's coordinate type, and it is converted to
// scalar_type here. That truncates the priority when scalar_type is integral
// and the query is floating point.
334 void pq_alternate(int alt_n, bbf_pqueue & pq, scalar_type dist) const {
338 // add bbf_node for alternate branch in priority queue
339 pq.push_back(bbf_node(alt_n, dist));
340 push_heap(pq.begin(), pq.end());
343 // called by bbf to walk to leaf;
344 // takes one step down the tree towards query point d
// Queues the sibling branch with its gap to the split boundary, then
// presumably returns the child on d's side (the return statements are not
// present here). d is indexed by split dimension, so it needs point_dim
// entries.
345 template < class __desctype >
346 int bbf_branch(int i, const __desctype * d, bbf_pqueue & pq) const {
347 const node & n = nodes[i];
348 // push bbf_node with bounds of alternate branch, then branch
349 if (d[n.dim] <= n.boundary) { // left
350 pq_alternate(n.right, pq, n.boundary - d[n.dim]);
353 pq_alternate(n.left, pq, d[n.dim] - n.boundary);
358 // compute euclidean distance between two points
// d is a raw coordinate array with point_dim entries; p's coordinates are read
// through deref. Differences are formed in accum_type, and the sqrt result is
// cast back to accum_type (so it truncates if accum_type is integral).
359 template < class __desctype >
360 accum_type distance(const __desctype * d, const __valuetype & p) const {
362 for (int j = 0; j < point_dim; ++j) {
363 accum_type diff = accum_type(d[j]) - accum_type(deref(p, j));
365 } return (accum_type) sqrt(dist);
368 // called per candidate nearest neighbor; constructs new bbf_nn for
369 // candidate and adds it to priority queue of all candidates; if
370 // queue len exceeds k, drops the point furthest from query point d.
// nn_pq is a max-heap on dist, so nn_pq[0] is the worst of the current k
// best. A new candidate replaces it only if strictly closer.
// NOTE(review): requires k > 0. With k <= 0 and an empty nn_pq, the else-if
// branch reads nn_pq[0] out of bounds.
371 template < class __desctype >
372 void bbf_new_nn(bbf_nn_pqueue & nn_pq, int k,
373 const __desctype * d, const __valuetype & p) const {
374 bbf_nn nn(p, distance(d, p));
375 if ((int) nn_pq.size() < k) {
377 push_heap(nn_pq.begin(), nn_pq.end());
378 } else if (nn_pq[0].dist > nn.dist) {
// evict the farthest: pop_heap moves it to the back, overwrite it, re-heap
379 pop_heap(nn_pq.begin(), nn_pq.end());
380 nn_pq.end()[-1] = nn;
381 push_heap(nn_pq.begin(), nn_pq.end());
383 assert(nn_pq.size() < 2 || nn_pq[0].dist >= nn_pq[1].dist);
387 // finds (with high probability) the k nearest neighbors of d,
388 // searching at most emax leaves/bins.
389 // ret_nn_pq is an array containing the (at most) k nearest neighbors
390 // (see bbf_nn structure def above).
// As far as the visible code shows, ret_nn_pq is built as a max-heap on
// distance (farthest first), not sorted. Apply std::sort_heap if ascending
// order is needed; confirm no sort happens before return.
// Returns ret_nn_pq.size(). Not safe to call concurrently on the same tree,
// because it uses the mutable scratch queue tmp_pq.
391 template < class __desctype >
392 int find_nn_bbf(const __desctype * d,
394 bbf_nn_pqueue & ret_nn_pq) const {
401 // add root_node to bbf_node priority queue;
402 // iterate while queue non-empty and emax>0
404 tmp_pq.push_back(bbf_node(root_node, 0));
405 while (tmp_pq.size() && emax > 0) {
407 // from node nearest query point d, run to leaf
408 pop_heap(tmp_pq.begin(), tmp_pq.end());
409 bbf_node bbf(tmp_pq.end()[-1]);
410 tmp_pq.erase(tmp_pq.end() - 1);
414 i != -1 && nodes[i].dim >= 0;
415 i = bbf_branch(i, d, tmp_pq));
419 // add points in leaf/bin to ret_nn_pq
// a bin is a chain of leaves linked through right and ending at -1
421 bbf_new_nn(ret_nn_pq, k, d, nodes[i].value);
422 } while (-1 != (i = nodes[i].right));
429 return (int)ret_nn_pq.size();
432 ////////////////////////////////////////////////////////////////////////////////////////
433 // orthogonal range search
// Recursive helper: visits only subtrees whose side of each split can
// intersect the box given by bounds_min/bounds_max (arrays indexed by split
// dimension), and appends leaf values to inbounds.
// NOTE(review): no per-point bounds test is visible before the push_back
// below, so every value of a visited leaf bin is appended even if it lies
// outside the box. Confirm, and filter each value against the bounds if so.
435 void find_ortho_range(int i, scalar_type * bounds_min,
436 scalar_type * bounds_max,
437 std::vector < __valuetype > &inbounds) const {
440 const node & n = nodes[i];
441 if (n.dim >= 0) { // node
// left subtree holds coordinates <= boundary, right holds > boundary
442 if (bounds_min[n.dim] <= n.boundary)
443 find_ortho_range(n.left, bounds_min, bounds_max, inbounds);
444 if (bounds_max[n.dim] > n.boundary)
445 find_ortho_range(n.right, bounds_min, bounds_max, inbounds);
448 inbounds.push_back(nodes[i].value);
449 } while (-1 != (i = nodes[i].right));
453 // return all points that lie within the given bounds; inbounds is cleared
// bounds_min/bounds_max hold one limit per dimension (point_dim entries), and
// matching values are copied into inbounds. Returns inbounds.size().
// NOTE(review): the clear() promised above is not visible here. The recursive
// helper appends every value of each visited leaf bin with no visible
// per-point bounds test, so results may include points outside the box.
454 int find_ortho_range(scalar_type * bounds_min,
455 scalar_type * bounds_max,
456 std::vector < __valuetype > &inbounds) const {
458 find_ortho_range(root_node, bounds_min, bounds_max, inbounds);
459 return (int)inbounds.size();
463 #endif // __cv_kdtree_h__