1 /* Original code has been submitted by Liu Liu.
2 ----------------------------------------------------------------------------------
3 * Spill-Tree for Approximate KNN Search
5 * mailto: liuliu.1987+opencv@gmail.com
7 * An Investigation of Practical Approximate Nearest Neighbor Algorithms
10 * Redistribution and use in source and binary forms, with or
11 * without modification, are permitted provided that the following
13 * Redistributions of source code must retain the above
14 * copyright notice, this list of conditions and the following
16 * Redistributions in binary form must reproduce the above
17 * copyright notice, this list of conditions and the following
18 * disclaimer in the documentation and/or other materials
19 * provided with the distribution.
20 * The name of Contributor may not be used to endorse or
21 * promote products derived from this software without
22 * specific prior written permission.
24 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
25 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
26 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
29 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
32 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
33 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
34 * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
35 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
40 #include "_cvfeaturetree.h"
// Node record shared by every role in this file: internal tree nodes, the
// per-row point nodes created by icvCreateSpillTree, and the entries of the
// linked lists hanging off a node. In list form `lc` of the owning node is
// the head, `rc` of the owning node is the tail, and each entry's `rc` is
// followed as the "next" link.
// NOTE(review): this listing has gaps (the embedded source line numbers
// jump), so the braces of this declaration are not visible here.
42 struct CvSpillTreeNode
44 bool leaf; // is leaf or not (a leaf is a single point that has no children)
// Tested by the descent loop at the top of icvSpillTreeDFSearch.
// NOTE(review): the lines that assign this flag are not visible in this listing.
45 bool spill; // is not a non-overlapping point (defeatist search)
46 CvSpillTreeNode* lc; // left child (<); head of the list / link to the previous entry in list form
47 CvSpillTreeNode* rc; // right child (>); tail of the list / link to the next entry in list form
48 int cc; // child count (number of list entries iterated under this node)
49 CvMat* u; // projection vector (created as 1 x d and normalized during tree construction)
50 CvMat* center; // center
51 int i; // original index; also used as the index into the tree's visited cache
52 double r; // radius of remaining feature point (compared against distances from `center`)
53 double ub; // upper bound (mp plus the overlap buffer)
54 double lb; // lower bound (mp minus the overlap buffer)
55 double mp; // mean point; point nodes reuse this field for the distance to the query during search
56 double p; // projection value (dot product of `center` with the parent's `u`)
// Fields of the tree handle that the functions in this file receive as `tr`.
// NOTE(review): the enclosing struct's opening and closing lines are not
// visible in this listing, and `tr->type` is read and written elsewhere in
// the file although no `type` field appears among the lines shown here.
61 CvSpillTreeNode* root;
62 CvMat** refmat; // leaf ref matrix: one matrix header per input row, released in icvReleaseSpillTree
63 bool* cache; // visited or not; indexed by a point's original index and cleared before each query row
64 int total; // total leaves
65 int naive; // under this value, we perform naive search (nodes with cc <= naive are not split further)
67 double rho; // under this value, it is a spill tree (scaled by a node's child count to form the threshold)
68 double tau; // the overlapping buffer ratio
71 // find the farthest node in the "list" from "node"
// Distance is cvNorm between the two nodes' `center` matrices and the loop
// runs `total` times. `farthest` starts at -1 so the first measured norm
// (never negative) always passes the comparison below.
// NOTE(review): the lines that record the winner, step `list` forward and
// return `result` are not visible in this listing; if `total` is 0 the
// initial value of `result` is NULL.
72 static inline CvSpillTreeNode*
73 icvFarthestNode( CvSpillTreeNode* node,
74 CvSpillTreeNode* list,
77 double farthest = -1.;
78 CvSpillTreeNode* result = NULL;
79 for ( int i = 0; i < total; i++ )
81 double norm = cvNorm( node->center, list->center );
82 if ( norm > farthest )
92 // clone a new tree node
// The clone is a fresh cvAlloc'd record filled by a byte-wise memcpy, i.e. a
// shallow copy: pointer members (`lc`, `rc`, `u`, `center`) are copied as
// pointers, so the clone refers to the same `center` matrix as `node`.
// NOTE(review): the return statement is not visible in this listing.
93 static inline CvSpillTreeNode*
94 icvCloneSpillTreeNode( CvSpillTreeNode* node )
96 CvSpillTreeNode* result = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
97 memcpy( result, node, sizeof(CvSpillTreeNode) );
101 // append the link-list of a tree node
// `node` owns the list (`node->lc` is its head, `node->rc` its tail) and
// `append` is the entry being added at the tail.
// NOTE(review): this listing has gaps; the `else` separating the two cases
// below and any remaining bookkeeping (tail / count updates) are not visible.
103 icvAppendSpillTreeNode( CvSpillTreeNode* node,
104 CvSpillTreeNode* append )
// Empty list: `append` becomes both head and tail.
106 if ( node->lc == NULL )
108 node->lc = node->rc = append;
// Head and tail are the same entry here, so this clears both links of `append`.
109 node->lc->lc = node->rc->rc = NULL;
// Non-empty list: back-link `append` to the current tail ...
111 append->lc = node->rc;
// ... and make the current tail point forward to `append`.
113 node->rc->rc = append;
// Returns a void* to the element `step` elements past the start of matrix
// `x`'s data, selecting the float pointer for CV_32F depth and the double
// pointer for CV_64F depth. `step` is an element offset, not a byte offset
// (it is added to a typed pointer). Any other depth yields a null pointer.
// Both `x` and `step` may be evaluated more than once.
119 #define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))
// Recursively splits `node`, whose points are held as a linked list starting
// at `node->lc` with `node->cc` entries, into two child nodes.
//   tr   - tree handle; `naive`, `type`, `tau` and `rho` are read from it
//   d    - dimensionality (second argument at the call sites; its parameter
//          line is not visible in this listing)
//   node - node to split
// NOTE(review): this listing has gaps (embedded line numbers jump), so the
// braces, several loop bodies, and the declarations/updates of the counters
// `l`, `r`, `sl`, `sr` are not visible; comments below describe only what the
// visible lines establish.
122 icvDFSInitSpillTreeNode( const CvSpillTree* tr,
124 CvSpillTreeNode* node )
// Small nodes are left as flat lists for naive search.
126 if ( node->cc <= tr->naive )
128 // already get to a leaf, terminate the recursion.
134 // random select a node, then find a farthest node from this one, then find a farthest from that one...
135 // to approximate the farthest node-pair
// NOTE(review): function-local static RNG state with a fixed seed; it is
// advanced on every call and no locking is visible here.
136 static CvRNG rng_state = cvRNG(0xdeadbeef);
137 int rn = cvRandInt( &rng_state ) % node->cc;
138 CvSpillTreeNode* lnode = NULL;
139 CvSpillTreeNode* rnode = node->lc;
// Loop of `rn` iterations starting from the list head; its body is not
// visible (presumably it advances `rnode` to the randomly chosen entry).
140 for ( int i = 0; i < rn; i++ )
// Two farthest-point hops: lnode is farthest from the start point, then
// rnode is farthest from lnode.
142 lnode = icvFarthestNode( rnode, node->lc, node->cc );
143 rnode = icvFarthestNode( lnode, node->lc, node->cc );
145 // u is the projection vector
// u = normalize(lnode->center - rnode->center), stored as a 1 x d matrix.
146 node->u = cvCreateMat( 1, d, tr->type );
147 cvSub( lnode->center, rnode->center, node->u );
148 cvNormalize( node->u, node->u );
150 // find the center of node in hyperspace
// Sum of the entries' centers, then scaled by 1/cc (arithmetic mean).
151 node->center = cvCreateMat( 1, d, tr->type );
152 cvZero( node->center );
153 CvSpillTreeNode* it = node->lc;
154 for ( int i = 0; i < node->cc; i++ )
156 cvAdd( it->center, node->center, node->center );
159 cvConvertScale( node->center, node->center, 1./node->cc );
161 // project every node to "u", and find the mean point "mp"
165 for ( int i = 0; i < node->cc; i++ )
// Stores each entry's projection in it->p while accumulating the sum in mp.
167 node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
// Distance from the node's center, compared against the radius node->r
// (the update performed when the test passes is not visible).
168 double norm = cvNorm( node->center, it->center );
169 if ( norm > node->r )
173 node->mp = node->mp / node->cc;
175 // overlapping buffer and upper bound, lower bound
// Half-width of the overlap band: tau/2 times the gap between the
// projections of the two extreme points.
176 double ob = (lnode->p-rnode->p)*tr->tau*.5;
177 node->ub = node->mp+ob;
178 node->lb = node->mp-ob;
// Counting pass over the entries; what each test increments is not visible
// (presumably sl / sr for the band tests and l / r for the test against mp).
182 for ( int i = 0; i < node->cc; i++ )
184 if ( it->p <= node->ub )
186 if ( it->p >= node->lb )
188 if ( it->p < node->mp )
194 // precision problem, return the node as it is.
// When one side would be empty the matrices created above are released again.
195 if (( l == 0 )||( r == 0 ))
197 cvReleaseMat( &(node->u) );
198 cvReleaseMat( &(node->center) );
// Allocate two zero-filled child nodes with empty lists.
203 CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
204 memset(lc, 0, sizeof(CvSpillTreeNode));
205 CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
206 memset(rc, 0, sizeof(CvSpillTreeNode));
207 lc->lc = lc->rc = rc->lc = rc->rc = NULL;
// Threshold: rho times the number of entries, rounded.
209 int undo = cvRound(node->cc*tr->rho);
210 if (( sl >= undo )||( sr >= undo ))
212 // it is not a spill point (defeatist search disabled)
// Non-overlapping split: every entry goes to exactly one side, by its
// projection relative to mp. `next` is saved first because appending
// rewrites the entry's links.
214 for ( int i = 0; i < node->cc; i++ )
216 CvSpillTreeNode* next = it->rc;
217 if ( it->p < node->mp )
218 icvAppendSpillTreeNode( lc, it );
220 icvAppendSpillTreeNode( rc, it );
// Overlapping split: entries below lb go left, entries above ub go right,
// and the remaining entries go to both sides (the original entry to the
// left list, a shallow clone to the right list).
227 for ( int i = 0; i < node->cc; i++ )
229 CvSpillTreeNode* next = it->rc;
230 if ( it->p < node->lb )
231 icvAppendSpillTreeNode( lc, it );
232 else if ( it->p > node->ub )
233 icvAppendSpillTreeNode( rc, it );
235 CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
236 icvAppendSpillTreeNode( lc, it );
237 icvAppendSpillTreeNode( rc, cit );
// Recurse into both children (the lines attaching lc / rc to `node` are not
// visible in this listing).
247 icvDFSInitSpillTreeNode( tr, d, node->lc );
248 icvDFSInitSpillTreeNode( tr, d, node->rc );
// Builds a spill tree over the rows of `raw_data` (n = rows, d = cols).
// Each row becomes a leaf node whose `center` is a matrix header pointing
// into raw_data's buffer (cvCreateMatHeader + cvSetData), so the row data is
// referenced rather than copied. The leaves are chained under the root and
// the root is then split recursively by icvDFSInitSpillTreeNode.
// The call site in this file passes (raw_data, naive, rho, tau).
// NOTE(review): this listing has gaps; the remaining parameters, the stores
// into tr->total / naive / rho / tau, the per-node index assignment, the list
// linking inside the loop and the return statement are not visible.
// NOTE(review): row i is located at element offset i*d, which assumes the
// rows are packed back to back, while raw_data->step is what is passed to
// cvSetData - confirm callers never pass a matrix with padded rows.
// NOTE(review): _dispatch_mat_ptr yields a null pointer for depths other
// than CV_32F / CV_64F; no check of raw_data's depth is visible here.
252 icvCreateSpillTree( const CvMat* raw_data,
257 int n = raw_data->rows;
258 int d = raw_data->cols;
260 CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
261 tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
262 memset(tr->root, 0, sizeof(CvSpillTreeNode));
// One header pointer and one visited flag per input row.
263 tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
264 tr->cache = (bool*)cvAlloc( sizeof(bool)*n );
269 tr->type = raw_data->type;
271 // tie a link-list to the root node
// First list entry: a leaf whose center is a header over row 0.
272 tr->root->lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
273 memset(tr->root->lc, 0, sizeof(CvSpillTreeNode));
274 tr->root->lc->center = cvCreateMatHeader( 1, d, tr->type );
275 cvSetData( tr->root->lc->center, _dispatch_mat_ptr(raw_data, 0), raw_data->step );
276 tr->refmat[0] = tr->root->lc->center;
277 tr->root->lc->lc = NULL;
278 tr->root->lc->leaf = true;
// Remaining rows 1..n-1, each as a new zero-filled leaf node.
280 CvSpillTreeNode* node = tr->root->lc;
281 for ( int i = 1; i < n; i++ )
283 CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
284 memset(newnode, 0, sizeof(CvSpillTreeNode));
285 newnode->center = cvCreateMatHeader( 1, d, tr->type );
286 cvSetData( newnode->center, _dispatch_mat_ptr(raw_data, i*d), raw_data->step );
287 tr->refmat[i] = newnode->center;
290 newnode->leaf = true;
// Recursively partition the list hanging off the root.
297 icvDFSInitSpillTreeNode( tr, d, tr->root );
// Sift-down step for the candidate heap used during search. Entries are
// node pointers ordered by their `mp` field, with the larger `mp` moved
// toward the slot being repaired, so the top of the heap holds the largest
// `mp` among non-NULL entries. NULL marks an unused slot.
// Call sites in this file pass (heap, start index, heap size) and the
// visible body refers to them as `i` and `k`.
// NOTE(review): this listing has gaps; the remaining parameter lines, the
// computation of the child indices `l` / `r`, the actions taken for NULL
// entries and the updates of `largest` / `i` are not visible.
302 icvSpillTreeNodeHeapify( CvSpillTreeNode** heap,
// An empty slot at the starting position is handled before the loop.
306 if ( heap[i] == NULL )
308 int l, r, largest = i;
309 CvSpillTreeNode* inp;
// NULL children inside the heap range are tested before any `mp` is read.
314 if (( l < k )&&( heap[l] == NULL ))
316 else if (( r < k )&&( heap[r] == NULL ))
// Pick whichever in-range child has an `mp` larger than the current entry.
319 if (( l < k )&&( heap[l]->mp > heap[i]->mp ))
321 if (( r < k )&&( heap[r]->mp > heap[largest]->mp ))
// `inp` is the temporary used by CV_SWAP.
325 CV_SWAP( heap[largest], heap[i], inp );
326 } while ( largest != i );
// Depth-first search of the subtree at `node` for neighbours of `desc`,
// collecting candidates in `heap` (k slots of node pointers ordered by `mp`,
// which holds each candidate's distance to `desc`).
// The recursive calls pass (tr, node, heap, es, desc, k, emax).
// NOTE(review): this listing has gaps; the remaining parameter lines, the
// declarations of `p` / `dist`, the bodies of most branches and the returns
// are not visible, so the comments below cover only the visible lines.
330 icvSpillTreeDFSearch( CvSpillTree* tr,
331 CvSpillTreeNode* node,
332 CvSpillTreeNode** heap,
// Budget check: applies only when emax > 0, comparing the counter *es with emax.
338 if ((emax > 0)&&( *es >= emax ))
// Descent loop over nodes flagged `spill`: the query's projection onto the
// node's `u` is compared with the overlap bounds; the actions taken in each
// branch are not visible (presumably a move into a single child).
341 while ( node->spill )
345 p = cvDotProduct( node->u, desc );
346 if ( p < node->lb && node->lc->cc >= k ) // check the number of children larger than k otherwise you'll skip over better neighbor
348 else if ( p > node->ub && node->rc->cc >= k )
357 // a leaf, naive search
// Linear scan of the node's list. The visited cache, indexed by the point's
// original index, ensures each point's distance is computed once per query.
358 CvSpillTreeNode* it = node->lc;
359 for ( int i = 0; i < node->cc; i++ )
361 if ( !tr->cache[it->i] )
// The point node's `mp` field is reused to hold its distance to the query.
363 it->mp = cvNorm( it->center, desc );
364 tr->cache[it->i] = true;
// Candidate qualifies if the top slot is unused or it beats the top's distance.
365 if (( heap[0] == NULL)||( it->mp < heap[0]->mp ))
368 icvSpillTreeNodeHeapify( heap, 0, k );
376 dist = cvNorm( node->center, desc );
377 // impossible case, skip
// The lower bound on any distance within radius r of the center already
// exceeds the distance stored at the top of the heap.
378 if (( heap[0] != NULL )&&( dist-node->r > heap[0]->mp ))
380 p = cvDotProduct( node->u, desc );
// Both children are visited, in one of two orders; the condition choosing
// the order is not visible (presumably based on `p`).
384 icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax );
385 icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax );
387 icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax );
388 icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax );
// Runs one search per row of `desc` and writes per-row output into
// `results` and `dist`.
// The call site in this file passes (tr, desc, results, dist, k, emax).
//   results - accessed through data.i, i.e. as int elements, row j at
//             offset j*results->cols
//   dist    - accessed through data.db, i.e. as double elements, row j at
//             offset j*dist->cols
// NOTE(review): this listing has gaps; the remaining parameter lines, the
// heap-slot initialisation inside the first inner loop, the declaration of
// `es`, the values written for NULL / non-NULL heap slots and the release of
// `heap` are not visible.
// NOTE(review): the type check is an assert only, so it disappears when
// NDEBUG is defined; row j of `desc` is located at element offset
// j*desc->cols, which assumes packed rows - confirm against callers.
393 icvFindSpillTreeFeatures( CvSpillTree* tr,
400 assert( desc->type == tr->type );
// k candidate slots, reused for every query row.
401 CvSpillTreeNode** heap = (CvSpillTreeNode**)cvAlloc( k*sizeof(heap[0]) );
402 for ( int j = 0; j < desc->rows; j++ )
// Header over row j of `desc`; no data is copied.
404 CvMat _desc = cvMat( 1, desc->cols, desc->type, _dispatch_mat_ptr(desc, j*desc->cols) );
405 for ( int i = 0; i < k; i++ )
// Clear the visited flags before each query.
407 memset( tr->cache, 0, sizeof(bool)*tr->total );
409 icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax );
// Heap-sort pass: repeatedly swap the top with the last slot of a shrinking
// range and re-heapify that range.
410 CvSpillTreeNode* inp;
411 for ( int i = k-1; i > 0; i-- )
413 CV_SWAP( heap[i], heap[0], inp );
414 icvSpillTreeNodeHeapify( heap, 0, i );
416 int* rs = results->data.i+j*results->cols;
417 double* dt = dist->data.db+j*dist->cols;
// One output pair per heap slot; unused (NULL) slots are handled separately.
418 for ( int i = 0; i < k; i++, rs++, dt++ )
419 if ( heap[i] != NULL )
// Recursively releases the subtree rooted at `node`.
// Two code paths are visible: a walk over the node's linked list
// (`node->lc`, `node->cc` entries), and a path that releases the node's `u`
// and `center` matrices and recurses into both children.
// NOTE(review): this listing has gaps; the condition selecting between the
// two paths, the body of the list walk (what is freed per entry) and the
// release of `node` itself are not visible.
430 icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
434 CvSpillTreeNode* it = node->lc;
435 for ( int i = 0; i < node->cc; i++ )
// Keep a handle on the current entry while the iterator moves on.
437 CvSpillTreeNode* s = it;
442 cvReleaseMat( &node->u );
443 cvReleaseMat( &node->center );
444 icvDFSReleaseSpillTreeNode( node->lc );
445 icvDFSReleaseSpillTreeNode( node->rc );
// Tears down a tree created by icvCreateSpillTree: releases the `total`
// matrix headers stored in `refmat`, frees the `refmat` and `cache` arrays,
// then releases the node hierarchy starting at the root.
// `*tr` is dereferenced without a NULL check.
// NOTE(review): this listing has gaps; whether the CvSpillTree record itself
// is freed (and `*tr` cleared) is not visible here.
451 icvReleaseSpillTree( CvSpillTree** tr )
453 for ( int i = 0; i < (*tr)->total; i++ )
454 cvReleaseMat( &((*tr)->refmat[i]) );
455 cvFree( &((*tr)->refmat) );
456 cvFree( &((*tr)->cache) );
457 icvDFSReleaseSpillTreeNode( (*tr)->root );
// CvFeatureTree implementation backed by the spill tree in this file: it
// holds a tree handle `tr`, built by icvCreateSpillTree, and forwards
// FindFeatures to icvFindSpillTreeFeatures.
// NOTE(review): this listing has gaps; the member declaration of `tr`, the
// access specifiers, the rest of the constructor's parameter list and the
// destructor header are not visible.
461 class CvSpillTreeWrap : public CvFeatureTree {
464 CvSpillTreeWrap(const CvMat* raw_data,
468 tr = icvCreateSpillTree(raw_data, naive, rho, tau);
// Releases the tree (presumably from the destructor, whose header is not visible).
471 icvReleaseSpillTree(&tr);
// Note the argument order change: (desc, k, emax, results, dist) here versus
// (tr, desc, results, dist, k, emax) in the underlying function.
474 void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
475 icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
// Factory: returns a new heap-allocated CvSpillTreeWrap, built from
// `raw_data` and the `naive`, `rho` and `tau` settings, as a CvFeatureTree*.
// NOTE(review): the remaining parameter lines (and any default values) are
// not visible in this listing; the object comes from `new`, so presumably
// the caller is responsible for destroying it - confirm against the
// CvFeatureTree release path.
479 CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
483 return new CvSpillTreeWrap(raw_data, naive, rho, tau);