Update to 2.0.0 tree from current Fremantle build
[opencv] / src / cv / cvspilltree.cpp
1 /* Original code has been submitted by Liu Liu.
2    ----------------------------------------------------------------------------------
3    * Spill-Tree for Approximate KNN Search
4    * Author: Liu Liu
5    * mailto: liuliu.1987+opencv@gmail.com
6    * Refer to Paper:
7    * An Investigation of Practical Approximate Nearest Neighbor Algorithms
8    * cvMergeSpillTree TBD
9    *
10    * Redistribution and use in source and binary forms, with or
11    * without modification, are permitted provided that the following
12    * conditions are met:
13    *    Redistributions of source code must retain the above
14    *    copyright notice, this list of conditions and the following
15    *    disclaimer.
16    *    Redistributions in binary form must reproduce the above
17    *    copyright notice, this list of conditions and the following
18    *    disclaimer in the documentation and/or other materials
19    *    provided with the distribution.
20    *    The name of Contributor may not be used to endorse or
21    *    promote products derived from this software without
22    *    specific prior written permission.
23    *
24    * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
25    * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
26    * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27    * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28    * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
29    * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30    * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31    * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
32    * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
33    * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
34    * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
35    * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
36    * OF SUCH DAMAGE.
37    */
38
39 #include "_cv.h"
40 #include "_cvfeaturetree.h"
41
42 struct CvSpillTreeNode
43 {
44   bool leaf; // is leaf or not (leaf is the point that have no more child)
45   bool spill; // is not a non-overlapping point (defeatist search)
46   CvSpillTreeNode* lc; // left child (<)
47   CvSpillTreeNode* rc; // right child (>)
48   int cc; // child count
49   CvMat* u; // projection vector
50   CvMat* center; // center
51   int i; // original index
52   double r; // radius of remaining feature point
53   double ub; // upper bound
54   double lb; // lower bound
55   double mp; // mean point
56   double p; // projection value
57 };
58
59 struct CvSpillTree
60 {
61   CvSpillTreeNode* root;
62   CvMat** refmat; // leaf ref matrix
63   bool* cache; // visited or not
64   int total; // total leaves
65   int naive; // under this value, we perform naive search
66   int type; // mat type
67   double rho; // under this value, it is a spill tree
68   double tau; // the overlapping buffer ratio
69 };
70
71 // find the farthest node in the "list" from "node"
72 static inline CvSpillTreeNode*
73 icvFarthestNode( CvSpillTreeNode* node,
74                  CvSpillTreeNode* list,
75                  int total )
76 {
77   double farthest = -1.;
78   CvSpillTreeNode* result = NULL;
79   for ( int i = 0; i < total; i++ )
80     {
81       double norm = cvNorm( node->center, list->center );
82       if ( norm > farthest )
83         {
84           farthest = norm;
85           result = list;
86         }
87       list = list->rc;
88     }
89   return result;
90 }
91
92 // clone a new tree node
93 static inline CvSpillTreeNode*
94 icvCloneSpillTreeNode( CvSpillTreeNode* node )
95 {
96   CvSpillTreeNode* result = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
97   memcpy( result, node, sizeof(CvSpillTreeNode) );
98   return result;
99 }
100
101 // append the link-list of a tree node
102 static inline void
103 icvAppendSpillTreeNode( CvSpillTreeNode* node,
104                         CvSpillTreeNode* append )
105 {
106   if ( node->lc == NULL )
107     {
108       node->lc = node->rc = append;
109       node->lc->lc = node->rc->rc = NULL;
110     } else {
111       append->lc = node->rc;
112       append->rc = NULL;
113       node->rc->rc = append;
114       node->rc = append;
115     }
116   node->cc++;
117 }
118
119 #define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))
120
121 static void
122 icvDFSInitSpillTreeNode( const CvSpillTree* tr,
123                          const int d,
124                          CvSpillTreeNode* node )
125 {
126   if ( node->cc <= tr->naive )
127     {
128       // already get to a leaf, terminate the recursion.
129       node->leaf = true;
130       node->spill = false;
131       return;
132     }
133
134   // random select a node, then find a farthest node from this one, then find a farthest from that one...
135   // to approximate the farthest node-pair
136   static CvRNG rng_state = cvRNG(0xdeadbeef);
137   int rn = cvRandInt( &rng_state ) % node->cc;
138   CvSpillTreeNode* lnode = NULL;
139   CvSpillTreeNode* rnode = node->lc;
140   for ( int i = 0; i < rn; i++ )
141     rnode = rnode->rc;
142   lnode = icvFarthestNode( rnode, node->lc, node->cc );
143   rnode = icvFarthestNode( lnode, node->lc, node->cc );
144
145   // u is the projection vector
146   node->u = cvCreateMat( 1, d, tr->type );
147   cvSub( lnode->center, rnode->center, node->u );
148   cvNormalize( node->u, node->u );
149         
150   // find the center of node in hyperspace
151   node->center = cvCreateMat( 1, d, tr->type );
152   cvZero( node->center );
153   CvSpillTreeNode* it = node->lc;
154   for ( int i = 0; i < node->cc; i++ )
155     {
156       cvAdd( it->center, node->center, node->center );
157       it = it->rc;
158     }
159   cvConvertScale( node->center, node->center, 1./node->cc );
160
161   // project every node to "u", and find the mean point "mp"
162   it = node->lc;
163   node->r = -1.;
164   node->mp = 0;
165   for ( int i = 0; i < node->cc; i++ )
166     {
167       node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
168       double norm = cvNorm( node->center, it->center );
169       if ( norm > node->r )
170         node->r = norm;
171       it = it->rc;
172     }
173   node->mp = node->mp / node->cc;
174
175   // overlapping buffer and upper bound, lower bound
176   double ob = (lnode->p-rnode->p)*tr->tau*.5;
177   node->ub = node->mp+ob;
178   node->lb = node->mp-ob;
179   int sl = 0, l = 0;
180   int sr = 0, r = 0;
181   it = node->lc;
182   for ( int i = 0; i < node->cc; i++ )
183     {
184       if ( it->p <= node->ub )
185         sl++;
186       if ( it->p >= node->lb )
187         sr++;
188       if ( it->p < node->mp )
189         l++;
190       else
191         r++;
192       it = it->rc;
193     }
194   // precision problem, return the node as it is.
195   if (( l == 0 )||( r == 0 ))
196     {
197       cvReleaseMat( &(node->u) );
198       cvReleaseMat( &(node->center) );
199       node->leaf = true;
200       node->spill = false;
201       return;
202     }
203   CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
204   memset(lc, 0, sizeof(CvSpillTreeNode));
205   CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
206   memset(rc, 0, sizeof(CvSpillTreeNode));
207   lc->lc = lc->rc = rc->lc = rc->rc = NULL;
208   lc->cc = rc->cc = 0;
209   int undo = cvRound(node->cc*tr->rho);
210   if (( sl >= undo )||( sr >= undo ))
211     {
212       // it is not a spill point (defeatist search disabled)
213       it = node->lc;
214       for ( int i = 0; i < node->cc; i++ )
215         {
216           CvSpillTreeNode* next = it->rc;
217           if ( it->p < node->mp )
218             icvAppendSpillTreeNode( lc, it );
219           else
220             icvAppendSpillTreeNode( rc, it );
221           it = next;
222         }
223       node->spill = false;
224     } else {
225       // a spill point
226       it = node->lc;
227       for ( int i = 0; i < node->cc; i++ )
228         {
229           CvSpillTreeNode* next = it->rc;
230           if ( it->p < node->lb )
231             icvAppendSpillTreeNode( lc, it );
232           else if ( it->p > node->ub )
233             icvAppendSpillTreeNode( rc, it );
234           else {
235             CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
236             icvAppendSpillTreeNode( lc, it );
237             icvAppendSpillTreeNode( rc, cit );
238           }
239           it = next;
240         }
241       node->spill = true;
242     }
243   node->lc = lc;
244   node->rc = rc;
245
246   // recursion process
247   icvDFSInitSpillTreeNode( tr, d, node->lc );
248   icvDFSInitSpillTreeNode( tr, d, node->rc );
249 }
250
251 static CvSpillTree*
252 icvCreateSpillTree( const CvMat* raw_data,
253                     const int naive,
254                     const double rho,
255                     const double tau )
256 {
257   int n = raw_data->rows;
258   int d = raw_data->cols;
259
260   CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
261   tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
262   memset(tr->root, 0, sizeof(CvSpillTreeNode));
263   tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
264   tr->cache = (bool*)cvAlloc( sizeof(bool)*n );
265   tr->total = n;
266   tr->naive = naive;
267   tr->rho = rho;
268   tr->tau = tau;
269   tr->type = raw_data->type;
270
271   // tie a link-list to the root node
272   tr->root->lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
273   memset(tr->root->lc, 0, sizeof(CvSpillTreeNode));
274   tr->root->lc->center = cvCreateMatHeader( 1, d, tr->type );
275   cvSetData( tr->root->lc->center, _dispatch_mat_ptr(raw_data, 0), raw_data->step );
276   tr->refmat[0] = tr->root->lc->center;
277   tr->root->lc->lc = NULL;
278   tr->root->lc->leaf = true;
279   tr->root->lc->i = 0;
280   CvSpillTreeNode* node = tr->root->lc;
281   for ( int i = 1; i < n; i++ )
282     {
283       CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
284       memset(newnode, 0, sizeof(CvSpillTreeNode));
285       newnode->center = cvCreateMatHeader( 1, d, tr->type );
286       cvSetData( newnode->center, _dispatch_mat_ptr(raw_data, i*d), raw_data->step );
287       tr->refmat[i] = newnode->center;
288       newnode->lc = node;
289       newnode->i = i;
290       newnode->leaf = true;
291       newnode->rc = NULL;
292       node->rc = newnode;
293       node = newnode;
294     }
295   tr->root->rc = node;
296   tr->root->cc = n;
297   icvDFSInitSpillTreeNode( tr, d, tr->root );
298   return tr;
299 }
300
301 static void
302 icvSpillTreeNodeHeapify( CvSpillTreeNode** heap,
303                          int i,
304                          const int k )
305 {
306   if ( heap[i] == NULL )
307     return;
308   int l, r, largest = i;
309   CvSpillTreeNode* inp;
310   do {
311     i = largest;
312     r = (i+1)<<1;
313     l = r-1;
314     if (( l < k )&&( heap[l] == NULL ))
315       largest = l;
316     else if (( r < k )&&( heap[r] == NULL ))
317       largest = r;
318         else {
319       if (( l < k )&&( heap[l]->mp > heap[i]->mp ))
320         largest = l;
321       if (( r < k )&&( heap[r]->mp > heap[largest]->mp ))
322         largest = r;
323     }
324     if ( largest != i )
325       CV_SWAP( heap[largest], heap[i], inp );
326   } while ( largest != i );
327 }
328
329 static void
330 icvSpillTreeDFSearch( CvSpillTree* tr,
331                       CvSpillTreeNode* node,
332                       CvSpillTreeNode** heap,
333                       int* es,
334                       const CvMat* desc,
335                       const int k,
336                       const int emax )
337 {
338   if ((emax > 0)&&( *es >= emax ))
339     return;
340   double dist, p=0;
341   while ( node->spill )
342     {
343       // defeatist search
344       if ( !node->leaf )
345         p = cvDotProduct( node->u, desc );
346       if ( p < node->lb && node->lc->cc >= k ) // check the number of children larger than k otherwise you'll skip over better neighbor
347         node = node->lc;
348       else if ( p > node->ub && node->rc->cc >= k )
349         node = node->rc;
350       else
351         break;
352       if ( NULL == node )
353         return;
354     }
355   if ( node->leaf )
356     {
357       // a leaf, naive search
358       CvSpillTreeNode* it = node->lc;
359       for ( int i = 0; i < node->cc; i++ )
360         {
361           if ( !tr->cache[it->i] )
362           {
363             it->mp = cvNorm( it->center, desc );
364             tr->cache[it->i] = true;
365             if (( heap[0] == NULL)||( it->mp < heap[0]->mp ))
366               {
367                 heap[0] = it;
368                 icvSpillTreeNodeHeapify( heap, 0, k );
369                 (*es)++;
370               }
371           }
372           it = it->rc;
373         }
374       return;
375     }
376   dist = cvNorm( node->center, desc );
377   // impossible case, skip
378   if (( heap[0] != NULL )&&( dist-node->r > heap[0]->mp ))
379     return;
380   p = cvDotProduct( node->u, desc );
381   // guided dfs
382   if ( p < node->mp )
383     {
384       icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax );
385       icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax );
386     } else {
387       icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax );
388       icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax );
389     }
390 }
391
392 static void
393 icvFindSpillTreeFeatures( CvSpillTree* tr,
394                           const CvMat* desc,
395                           CvMat* results,
396                           CvMat* dist,
397                           const int k,
398                           const int emax )
399 {
400   assert( desc->type == tr->type );
401   CvSpillTreeNode** heap = (CvSpillTreeNode**)cvAlloc( k*sizeof(heap[0]) );
402   for ( int j = 0; j < desc->rows; j++ )
403     {
404       CvMat _desc = cvMat( 1, desc->cols, desc->type, _dispatch_mat_ptr(desc, j*desc->cols) );
405       for ( int i = 0; i < k; i++ )
406         heap[i] = NULL;
407       memset( tr->cache, 0, sizeof(bool)*tr->total );
408       int es = 0;
409       icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax );
410       CvSpillTreeNode* inp;
411       for ( int i = k-1; i > 0; i-- )
412         {
413           CV_SWAP( heap[i], heap[0], inp );
414           icvSpillTreeNodeHeapify( heap, 0, i );
415         }
416       int* rs = results->data.i+j*results->cols;
417       double* dt = dist->data.db+j*dist->cols;
418       for ( int i = 0; i < k; i++, rs++, dt++ )
419         if ( heap[i] != NULL )
420           {
421             *rs = heap[i]->i;
422             *dt = heap[i]->mp;
423           } else
424             *rs = -1;
425     }
426   cvFree( &heap );
427 }
428
429 static void
430 icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
431 {
432   if ( node->leaf )
433     {
434       CvSpillTreeNode* it = node->lc;
435       for ( int i = 0; i < node->cc; i++ )
436         {
437           CvSpillTreeNode* s = it;
438           it = it->rc;
439           cvFree( &s );
440         }
441     } else {
442       cvReleaseMat( &node->u );
443       cvReleaseMat( &node->center );
444       icvDFSReleaseSpillTreeNode( node->lc );
445       icvDFSReleaseSpillTreeNode( node->rc );
446     }
447   cvFree( &node );
448 }
449
450 static void
451 icvReleaseSpillTree( CvSpillTree** tr )
452 {
453   for ( int i = 0; i < (*tr)->total; i++ )
454     cvReleaseMat( &((*tr)->refmat[i]) );
455   cvFree( &((*tr)->refmat) );
456   cvFree( &((*tr)->cache) );
457   icvDFSReleaseSpillTreeNode( (*tr)->root );
458   cvFree( tr );
459 }
460
461 class CvSpillTreeWrap : public CvFeatureTree {
462   CvSpillTree* tr;
463 public:
464   CvSpillTreeWrap(const CvMat* raw_data,
465                   const int naive,
466                   const double rho,
467                   const double tau) {
468     tr = icvCreateSpillTree(raw_data, naive, rho, tau);
469   }
470   ~CvSpillTreeWrap() {
471     icvReleaseSpillTree(&tr);
472   }
473
474   void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
475     icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
476   }
477 };
478
479 CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
480                                   const int naive,
481                                   const double rho,
482                                   const double tau ) {
483   return new CvSpillTreeWrap(raw_data, naive, rho, tau);
484 }