Move the sources to trunk
[opencv] / apps / haartraining / src / cvboost.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #ifdef HAVE_CONFIG_H
43     #include <cvconfig.h>
44 #endif
45
46 #ifdef HAVE_MALLOC_H
47     #include <malloc.h>
48 #endif
49
50 #include <stdio.h>
51 #include <memory.h>
52 #include <float.h>
53 #include <math.h>
54
55 #include <time.h>
56 #include <limits.h>
57
58 #include <_cvcommon.h>
59 #include <cvclassifier.h>
60
61 #ifdef _OPENMP
62 #include <omp.h>
63 #endif /* _OPENMP */
64
/* Empty marker written in front of the functions defined in this file
 * (e.g. "CV_BOOST_IMPL void cvGetSortedIndices(...)"); it expands to nothing. */
#define CV_BOOST_IMPL
66
67 typedef struct CvValArray
68 {
69     uchar* data;
70     size_t step;
71 } CvValArray;
72
73 #define CMP_VALUES( idx1, idx2 )                                 \
74     ( *( (float*) (aux->data + ((int) (idx1)) * aux->step ) ) <  \
75       *( (float*) (aux->data + ((int) (idx2)) * aux->step ) ) )
76
77 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_16s, short, CMP_VALUES, CvValArray* )
78
79 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_32s, int,   CMP_VALUES, CvValArray* )
80
81 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_32f, float, CMP_VALUES, CvValArray* )
82
83 CV_BOOST_IMPL
84 void cvGetSortedIndices( CvMat* val, CvMat* idx, int sortcols )
85 {
86     int idxtype = 0;
87     uchar* data = NULL;
88     size_t istep = 0;
89     size_t jstep = 0;
90
91     int i = 0;
92     int j = 0;
93
94     CvValArray va;
95
96     assert( idx != NULL );
97     assert( val != NULL );
98
99     idxtype = CV_MAT_TYPE( idx->type );
100     assert( idxtype == CV_16SC1 || idxtype == CV_32SC1 || idxtype == CV_32FC1 );
101     assert( CV_MAT_TYPE( val->type ) == CV_32FC1 );
102     if( sortcols )
103     {
104         assert( idx->rows == val->cols );
105         assert( idx->cols == val->rows );
106         istep = CV_ELEM_SIZE( val->type );
107         jstep = val->step;
108     }
109     else
110     {
111         assert( idx->rows == val->rows );
112         assert( idx->cols == val->cols );
113         istep = val->step;
114         jstep = CV_ELEM_SIZE( val->type );
115     }
116
117     va.data = val->data.ptr;
118     va.step = jstep;
119     switch( idxtype )
120     {
121         case CV_16SC1:
122             for( i = 0; i < idx->rows; i++ )
123             {
124                 for( j = 0; j < idx->cols; j++ )
125                 {
126                     CV_MAT_ELEM( *idx, short, i, j ) = (short) j;
127                 }
128                 icvSortIndexedValArray_16s( (short*) (idx->data.ptr + i * idx->step),
129                                             idx->cols, &va );
130                 va.data += istep;
131             }
132             break;
133
134         case CV_32SC1:
135             for( i = 0; i < idx->rows; i++ )
136             {
137                 for( j = 0; j < idx->cols; j++ )
138                 {
139                     CV_MAT_ELEM( *idx, int, i, j ) = j;
140                 }
141                 icvSortIndexedValArray_32s( (int*) (idx->data.ptr + i * idx->step),
142                                             idx->cols, &va );
143                 va.data += istep;
144             }
145             break;
146
147         case CV_32FC1:
148             for( i = 0; i < idx->rows; i++ )
149             {
150                 for( j = 0; j < idx->cols; j++ )
151                 {
152                     CV_MAT_ELEM( *idx, float, i, j ) = (float) j;
153                 }
154                 icvSortIndexedValArray_32f( (float*) (idx->data.ptr + i * idx->step),
155                                             idx->cols, &va );
156                 va.data += istep;
157             }
158             break;
159
160         default:
161             assert( 0 );
162             break;
163     }
164 }
165
166 CV_BOOST_IMPL
167 void cvReleaseStumpClassifier( CvClassifier** classifier )
168 {
169     cvFree( classifier );
170     *classifier = 0;
171 }
172
173 CV_BOOST_IMPL
174 float cvEvalStumpClassifier( CvClassifier* classifier, CvMat* sample )
175 {
176     assert( classifier != NULL );
177     assert( sample != NULL );
178     assert( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
179     
180     if( (CV_MAT_ELEM( (*sample), float, 0,
181             ((CvStumpClassifier*) classifier)->compidx )) <
182         ((CvStumpClassifier*) classifier)->threshold ) 
183     {
184         return ((CvStumpClassifier*) classifier)->left;
185     }
186     else
187     {
188         return ((CvStumpClassifier*) classifier)->right;
189     }
190
191     return 0.0F;
192 }
193
/*
 * ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )
 *
 * Defines `int icvFindStumpThreshold_##suffix(...)`. The generated function
 * scans the values of one feature in ascending order and keeps the split
 * threshold with the lowest total error seen so far.
 *
 *   data, datastep         - float feature value of sample k at data + k*datastep
 *   wdata, wstep           - float weight of sample k
 *   ydata, ystep           - float response of sample k
 *   idxdata, idxstep, num  - `num` sample indices stored as C type `type`; they
 *                            must be ordered by ascending feature value (this
 *                            is only checked by an assert)
 *   lerror, rerror         - in: left/right error of the best split so far;
 *                            out: overwritten when a better split is found
 *   threshold, left, right - out: threshold and left/right outputs of the
 *                            better split (left untouched if none is found)
 *   sumw, sumwy, sumwyy    - totals of w, w*y and w*y*y over the indexed
 *                            samples; computed here when *sumw == FLT_MAX,
 *                            otherwise used as given
 *
 * `error` is a statement list pasted into the loop body. It can see:
 *   - the left-side sums (wl, wyl, wyyl) and the right-side sums (wr, wyr);
 *   - the weighted mean responses curleft/curright;
 *   - the scratch variables wposl/wposr.
 * It must set curlerror and currerror. It may also replace curleft/curright
 * with the outputs to store.
 *
 * Returns 1 if a split with curlerror + currerror < *lerror + *rerror was
 * found, 0 otherwise.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                        \
        uchar* data, size_t datastep,                                                    \
        uchar* wdata, size_t wstep,                                                      \
        uchar* ydata, size_t ystep,                                                      \
        uchar* idxdata, size_t idxstep, int num,                                         \
        float* lerror,                                                                   \
        float* rerror,                                                                   \
        float* threshold, float* left, float* right,                                     \
        float* sumw, float* sumwy, float* sumwyy )                                       \
{                                                                                        \
    int found = 0;                                                                       \
    float wyl  = 0.0F;                                                                   \
    float wl   = 0.0F;                                                                   \
    float wyyl = 0.0F;                                                                   \
    float wyr  = 0.0F;                                                                   \
    float wr   = 0.0F;                                                                   \
                                                                                         \
    float curleft  = 0.0F;                                                               \
    float curright = 0.0F;                                                               \
    float* prevval = NULL;                                                               \
    float* curval  = NULL;                                                               \
    float curlerror = 0.0F;                                                              \
    float currerror = 0.0F;                                                              \
    float wposl;                                                                         \
    float wposr;                                                                         \
                                                                                         \
    int i = 0;                                                                           \
    int idx = 0;                                                                         \
                                                                                         \
    /* wposl/wposr are used only by some `error` snippets; initialize them anyway */     \
    wposl = wposr = 0.0F;                                                                \
    if( *sumw == FLT_MAX )                                                               \
    {                                                                                    \
        /* calculate sums; FLT_MAX in *sumw means "not computed yet" */                  \
        float *y = NULL;                                                                 \
        float *w = NULL;                                                                 \
        float wy = 0.0F;                                                                 \
                                                                                         \
        *sumw   = 0.0F;                                                                  \
        *sumwy  = 0.0F;                                                                  \
        *sumwyy = 0.0F;                                                                  \
        for( i = 0; i < num; i++ )                                                       \
        {                                                                                \
            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
            w = (float*) (wdata + idx * wstep);                                          \
            *sumw += *w;                                                                 \
            y = (float*) (ydata + idx * ystep);                                          \
            wy = (*w) * (*y);                                                            \
            *sumwy += wy;                                                                \
            *sumwyy += wy * (*y);                                                        \
        }                                                                                \
    }                                                                                    \
                                                                                         \
    /* candidate split before position i: positions < i go left, the rest right */     \
    for( i = 0; i < num; i++ )                                                           \
    {                                                                                    \
        idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
        curval = (float*) (data + idx * datastep);                                       \
        /* debug check: indices must be ordered by ascending value */                    \
        if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
                                                                                         \
        /* right-side sums are the totals minus the left-side sums */                    \
        wyr  = *sumwy - wyl;                                                             \
        wr   = *sumw  - wl;                                                              \
                                                                                         \
        /* weighted mean response per side (0 for a side without weight) */             \
        if( wl > 0.0 ) curleft = wyl / wl;                                               \
        else curleft = 0.0F;                                                             \
                                                                                         \
        if( wr > 0.0 ) curright = wyr / wr;                                              \
        else curright = 0.0F;                                                            \
                                                                                         \
        error                                                                            \
                                                                                         \
        if( curlerror + currerror < (*lerror) + (*rerror) )                              \
        {                                                                                \
            (*lerror) = curlerror;                                                       \
            (*rerror) = currerror;                                                       \
            *threshold = *curval;                                                        \
            /* threshold halfway between the previous distinct value and this one */     \
            if( i > 0 ) {                                                                \
                *threshold = 0.5F * (*threshold + *prevval);                             \
            }                                                                            \
            *left  = curleft;                                                            \
            *right = curright;                                                           \
            found = 1;                                                                   \
        }                                                                                \
                                                                                         \
        /* move every sample sharing the current value to the left side, so that */      \
        /* equal values are never split apart */                                         \
        do                                                                               \
        {                                                                                \
            wl  += *((float*) (wdata + idx * wstep));                                    \
            wyl += (*((float*) (wdata + idx * wstep)))                                   \
                * (*((float*) (ydata + idx * ystep)));                                   \
            wyyl += *((float*) (wdata + idx * wstep))                                    \
                * (*((float*) (ydata + idx * ystep)))                                    \
                * (*((float*) (ydata + idx * ystep)));                                   \
        }                                                                                \
        while( (++i) < num &&                                                            \
            ( *((float*) (data + (idx =                                                  \
                (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
                == *curval ) );                                                          \
        /* i is at the first new value (or num); the for-loop increment restores it */   \
        --i;                                                                             \
        prevval = curval;                                                                \
    } /* for each value */                                                               \
                                                                                         \
    return found;                                                                        \
}
297
/* Misclassification error:  err = MIN( wpos, wneg ) per side.
 *
 * The positive weight of a side is taken as 0.5 * (w + w*y), which assumes
 * responses are +/-1 labels (TODO confirm with callers). Each side's output
 * becomes the fraction of positive weight, 0.5 * (1 + mean response).
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_MISC( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( misc_##suffix, type,                                   \
        /* left side */                                                                  \
        wposl = 0.5F * ( wl + wyl );                                                     \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curlerror = MIN( wposl, wl - wposl );                                            \
        /* right side */                                                                 \
        wposr = 0.5F * ( wr + wyr );                                                     \
        curright = 0.5F * ( 1.0F + curright );                                           \
        currerror = MIN( wposr, wr - wposr );                                            \
    )
310
/* Gini error:  err = 2 * wpos * wneg / (wpos + wneg) per side.
 *
 * With p = wpos / (wpos + wneg) (the converted side output) this equals
 * 2 * wpos * (1 - p). Positive weight is taken as 0.5 * (w + w*y), which
 * assumes responses are +/-1 labels (TODO confirm with callers).
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_GINI( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( gini_##suffix, type,                                   \
        /* left side */                                                                  \
        wposl = 0.5F * ( wl + wyl );                                                     \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curlerror = 2.0F * wposl * ( 1.0F - curleft );                                   \
        /* right side */                                                                 \
        wposr = 0.5F * ( wr + wyr );                                                     \
        curright = 0.5F * ( 1.0F + curright );                                           \
        currerror = 2.0F * wposr * ( 1.0F - curright );                                  \
    )
323
/* Guard for the entropy criterion: a log term is added only when its
 * probability argument exceeds this bound, so logf() is never given 0. */
#define CV_ENTROPY_THRESHOLD FLT_MIN

/* Entropy error per side:
 *   err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg))
 * Positive weight is taken as 0.5 * (w + w*y), which assumes responses are
 * +/-1 labels (TODO confirm with callers).
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type )                             \
    ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type,                                \
        /* left side */                                                                  \
        wposl = 0.5F * ( wl + wyl );                                                     \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curlerror = 0.0F;                                                                \
        if( curleft > CV_ENTROPY_THRESHOLD )                                             \
            curlerror -= wposl * logf( curleft );                                        \
        if( curleft < 1.0F - CV_ENTROPY_THRESHOLD )                                      \
            curlerror -= (wl - wposl) * logf( 1.0F - curleft );                          \
        /* right side */                                                                 \
        wposr = 0.5F * ( wr + wyr );                                                     \
        curright = 0.5F * ( 1.0F + curright );                                           \
        currerror = 0.0F;                                                                \
        if( curright > CV_ENTROPY_THRESHOLD )                                            \
            currerror -= wposr * logf( curright );                                       \
        if( curright < 1.0F - CV_ENTROPY_THRESHOLD )                                     \
            currerror -= (wr - wposr) * logf( 1.0F - curright );                         \
    )
346
/* Least sum of squares error per side:
 *   err = sum( w * (y - v)^2 ),  v = the side's output (mean response),
 * expanded as sum(w*y*y) + v*v*sum(w) - 2*v*sum(w*y), so only the running
 * sums are needed. The right-side sum(w*y*y) is the total minus the left one.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                  \
    ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                     \
        /* right side */                                                                 \
        currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
        /* left side */                                                                  \
        curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                \
    )
355
356 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )
357
358 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32s, int )
359
360 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32f, float )
361
362
363 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 16s, short )
364
365 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32s, int )
366
367 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32f, float )
368
369
370 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 16s, short )
371
372 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32s, int )
373
374 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32f, float )
375
376
377 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 16s, short )
378
379 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32s, int )
380
381 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32f, float )
382
383 typedef int (*CvFindThresholdFunc)( uchar* data, size_t datastep,
384                                     uchar* wdata, size_t wstep,
385                                     uchar* ydata, size_t ystep,
386                                     uchar* idxdata, size_t idxstep, int num,
387                                     float* lerror,
388                                     float* rerror,
389                                     float* threshold, float* left, float* right,
390                                     float* sumw, float* sumwy, float* sumwyy );
391
392 CvFindThresholdFunc findStumpThreshold_16s[4] = {
393         icvFindStumpThreshold_misc_16s,
394         icvFindStumpThreshold_gini_16s,
395         icvFindStumpThreshold_entropy_16s,
396         icvFindStumpThreshold_sq_16s
397     };
398
399 CvFindThresholdFunc findStumpThreshold_32s[4] = {
400         icvFindStumpThreshold_misc_32s,
401         icvFindStumpThreshold_gini_32s,
402         icvFindStumpThreshold_entropy_32s,
403         icvFindStumpThreshold_sq_32s
404     };
405
406 CvFindThresholdFunc findStumpThreshold_32f[4] = {
407         icvFindStumpThreshold_misc_32f,
408         icvFindStumpThreshold_gini_32f,
409         icvFindStumpThreshold_entropy_32f,
410         icvFindStumpThreshold_sq_32f
411     };
412
413 CV_BOOST_IMPL
414 CvClassifier* cvCreateStumpClassifier( CvMat* trainData,
415                       int flags,
416                       CvMat* trainClasses,
417                       CvMat* typeMask,
418                       CvMat* missedMeasurementsMask,
419                       CvMat* compIdx,
420                       CvMat* sampleIdx,
421                       CvMat* weights,
422                       CvClassifierTrainParams* trainParams
423                     )
424 {
425     CvStumpClassifier* stump = NULL;
426     int m = 0; /* number of samples */
427     int n = 0; /* number of components */
428     uchar* data = NULL;
429     int cstep   = 0;
430     int sstep   = 0;
431     uchar* ydata = NULL;
432     int ystep    = 0;
433     uchar* idxdata = NULL;
434     int idxstep    = 0;
435     int l = 0; /* number of indices */     
436     uchar* wdata = NULL;
437     int wstep    = 0;
438
439     int* idx = NULL;
440     int i = 0;
441     
442     float sumw   = FLT_MAX;
443     float sumw1  = FLT_MAX;
444     float sumw0  = FLT_MAX;
445     float sumwy  = FLT_MAX;
446     float sumwyy = FLT_MAX;
447
448     assert( trainData != NULL );
449     assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
450     assert( trainClasses != NULL );
451     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
452     assert( missedMeasurementsMask == NULL );
453     assert( compIdx == NULL );
454     assert( weights != NULL );
455     assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
456     assert( trainParams != NULL );
457
458     data = trainData->data.ptr;
459     if( CV_IS_ROW_SAMPLE( flags ) )
460     {
461         cstep = CV_ELEM_SIZE( trainData->type );
462         sstep = trainData->step;
463         m = trainData->rows;
464         n = trainData->cols;
465     }
466     else
467     {
468         sstep = CV_ELEM_SIZE( trainData->type );
469         cstep = trainData->step;
470         m = trainData->cols;
471         n = trainData->rows;
472     }
473
474     ydata = trainClasses->data.ptr;
475     if( trainClasses->rows == 1 )
476     {
477         assert( trainClasses->cols == m );
478         ystep = CV_ELEM_SIZE( trainClasses->type );
479     }
480     else
481     {
482         assert( trainClasses->rows == m );
483         ystep = trainClasses->step;
484     }
485
486     wdata = weights->data.ptr;
487     if( weights->rows == 1 )
488     {
489         assert( weights->cols == m );
490         wstep = CV_ELEM_SIZE( weights->type );
491     }
492     else
493     {
494         assert( weights->rows == m );
495         wstep = weights->step;
496     }
497
498     l = m;
499     if( sampleIdx != NULL )
500     {
501         assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
502
503         idxdata = sampleIdx->data.ptr;
504         if( sampleIdx->rows == 1 )
505         {
506             l = sampleIdx->cols;
507             idxstep = CV_ELEM_SIZE( sampleIdx->type );
508         }
509         else
510         {
511             l = sampleIdx->rows;
512             idxstep = sampleIdx->step;
513         }
514         assert( l <= m );
515     }
516
517     idx = (int*) cvAlloc( l * sizeof( int ) );
518     stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
519
520     /* START */
521     memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
522
523     stump->eval = cvEvalStumpClassifier;
524     stump->tune = NULL;
525     stump->save = NULL;
526     stump->release = cvReleaseStumpClassifier;
527
528     stump->lerror = FLT_MAX;
529     stump->rerror = FLT_MAX;
530     stump->left  = 0.0F;
531     stump->right = 0.0F;
532
533     /* copy indices */
534     if( sampleIdx != NULL )
535     {
536         for( i = 0; i < l; i++ )
537         {
538             idx[i] = (int) *((float*) (idxdata + i*idxstep));
539         }
540     }
541     else
542     {
543         for( i = 0; i < l; i++ )
544         {
545             idx[i] = i;
546         }
547     }
548
549     for( i = 0; i < n; i++ )
550     {
551         CvValArray va;
552
553         va.data = data + i * ((size_t) cstep);
554         va.step = sstep;
555         icvSortIndexedValArray_32s( idx, l, &va );
556         if( findStumpThreshold_32s[(int) ((CvStumpTrainParams*) trainParams)->error]
557               ( data + i * ((size_t) cstep), sstep,
558                 wdata, wstep, ydata, ystep, (uchar*) idx, sizeof( int ), l,
559                 &(stump->lerror), &(stump->rerror),
560                 &(stump->threshold), &(stump->left), &(stump->right), 
561                 &sumw, &sumwy, &sumwyy ) )
562         {
563             stump->compidx = i;
564         }
565     } /* for each component */
566
567     /* END */
568
569     cvFree( &idx );
570
571     if( ((CvStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
572     {
573         stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
574         stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
575     }
576
577     return (CvClassifier*) stump;
578 }
579
580 /*
581  * cvCreateMTStumpClassifier
582  *
583  * Multithreaded stump classifier constructor
584  * Includes huge train data support through callback function
585  */
586 CV_BOOST_IMPL
587 CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,
588                       int flags,
589                       CvMat* trainClasses,
590                       CvMat* typeMask,
591                       CvMat* missedMeasurementsMask,
592                       CvMat* compIdx,
593                       CvMat* sampleIdx,
594                       CvMat* weights,
595                       CvClassifierTrainParams* trainParams )
596 {
597     CvStumpClassifier* stump = NULL;
598     int m = 0; /* number of samples */
599     int n = 0; /* number of components */
600     uchar* data = NULL;
601     size_t cstep   = 0;
602     size_t sstep   = 0;
603     int    datan   = 0; /* num components */
604     uchar* ydata = NULL;
605     size_t ystep = 0;
606     uchar* idxdata = NULL;
607     size_t idxstep = 0;
608     int    l = 0; /* number of indices */     
609     uchar* wdata = NULL;
610     size_t wstep = 0;
611
612     uchar* sorteddata = NULL;
613     int    sortedtype    = 0;
614     size_t sortedcstep   = 0; /* component step */
615     size_t sortedsstep   = 0; /* sample step */
616     int    sortedn       = 0; /* num components */
617     int    sortedm       = 0; /* num samples */
618
619     char* filter = NULL;
620     int i = 0;
621     
622     int compidx = 0;
623     int stumperror;
624     int portion;
625
626     /* private variables */
627     CvMat mat;
628     CvValArray va;
629     float lerror;
630     float rerror;
631     float left;
632     float right;
633     float threshold;
634     int optcompidx;
635
636     float sumw;
637     float sumwy;
638     float sumwyy;
639
640     int t_compidx;
641     int t_n;
642     
643     int ti;
644     int tj;
645     int tk;
646
647     uchar* t_data;
648     size_t t_cstep;
649     size_t t_sstep;
650
651     size_t matcstep;
652     size_t matsstep;
653
654     int* t_idx;
655     /* end private variables */
656
657     assert( trainParams != NULL );
658     assert( trainClasses != NULL );
659     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
660     assert( missedMeasurementsMask == NULL );
661     assert( compIdx == NULL );
662
663     stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
664
665     ydata = trainClasses->data.ptr;
666     if( trainClasses->rows == 1 )
667     {
668         m = trainClasses->cols;
669         ystep = CV_ELEM_SIZE( trainClasses->type );
670     }
671     else
672     {
673         m = trainClasses->rows;
674         ystep = trainClasses->step;
675     }
676
677     wdata = weights->data.ptr;
678     if( weights->rows == 1 )
679     {
680         assert( weights->cols == m );
681         wstep = CV_ELEM_SIZE( weights->type );
682     }
683     else
684     {
685         assert( weights->rows == m );
686         wstep = weights->step;
687     }
688
689     if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
690     {
691         sortedtype =
692             CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
693         assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
694                 || sortedtype == CV_32FC1 );
695         sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
696         sortedsstep = CV_ELEM_SIZE( sortedtype );
697         sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
698         sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
699         sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
700     }
701
702     if( trainData == NULL )
703     {
704         assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
705         n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
706         assert( n > 0 );
707     }
708     else
709     {
710         assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
711         data = trainData->data.ptr;
712         if( CV_IS_ROW_SAMPLE( flags ) )
713         {
714             cstep = CV_ELEM_SIZE( trainData->type );
715             sstep = trainData->step;
716             assert( m == trainData->rows );
717             datan = n = trainData->cols;
718         }
719         else
720         {
721             sstep = CV_ELEM_SIZE( trainData->type );
722             cstep = trainData->step;
723             assert( m == trainData->cols );
724             datan = n = trainData->rows;
725         }
726         if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
727         {
728             n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
729         }        
730     }
731     assert( datan <= n );
732
733     if( sampleIdx != NULL )
734     {
735         assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
736         idxdata = sampleIdx->data.ptr;
737         idxstep = ( sampleIdx->rows == 1 )
738             ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
739         l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;
740
741         if( sorteddata != NULL )
742         {
743             filter = (char*) cvAlloc( sizeof( char ) * m );
744             memset( (void*) filter, 0, sizeof( char ) * m );
745             for( i = 0; i < l; i++ )
746             {
747                 filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;
748             }
749         }
750     }
751     else
752     {
753         l = m;
754     }
755
756     stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
757
758     /* START */
759     memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
760
761     portion = ((CvMTStumpTrainParams*)trainParams)->portion;
762     
763     if( portion < 1 )
764     {
765         /* auto portion */
766         portion = n;
767         #ifdef _OPENMP
768         portion /= omp_get_max_threads();        
769         #endif /* _OPENMP */        
770     }
771
772     stump->eval = cvEvalStumpClassifier;
773     stump->tune = NULL;
774     stump->save = NULL;
775     stump->release = cvReleaseStumpClassifier;
776
777     stump->lerror = FLT_MAX;
778     stump->rerror = FLT_MAX;
779     stump->left  = 0.0F;
780     stump->right = 0.0F;
781
782     compidx = 0;
783     #ifdef _OPENMP
784     #pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
785                                  optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
786                                  ti, tj, tk, t_data, t_cstep, t_sstep, matcstep,  \
787                                  matsstep, t_idx)
788     #endif /* _OPENMP */
789     {
790         lerror = FLT_MAX;
791         rerror = FLT_MAX;
792         left  = 0.0F;
793         right = 0.0F;
794         threshold = 0.0F;
795         optcompidx = 0;
796
797         sumw   = FLT_MAX;
798         sumwy  = FLT_MAX;
799         sumwyy = FLT_MAX;
800
801         t_compidx = 0;
802         t_n = 0;
803         
804         ti = 0;
805         tj = 0;
806         tk = 0;
807
808         t_data = NULL;
809         t_cstep = 0;
810         t_sstep = 0;
811
812         matcstep = 0;
813         matsstep = 0;
814
815         t_idx = NULL;
816
817         mat.data.ptr = NULL;
818         
819         if( datan < n )
820         {
821             /* prepare matrix for callback */
822             if( CV_IS_ROW_SAMPLE( flags ) )
823             {
824                 mat = cvMat( m, portion, CV_32FC1, 0 );
825                 matcstep = CV_ELEM_SIZE( mat.type );
826                 matsstep = mat.step;
827             }
828             else
829             {
830                 mat = cvMat( portion, m, CV_32FC1, 0 );
831                 matcstep = mat.step;
832                 matsstep = CV_ELEM_SIZE( mat.type );
833             }
834             mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
835         }
836
837         if( filter != NULL || sortedn < n )
838         {
839             t_idx = (int*) cvAlloc( sizeof( int ) * m );
840             if( sortedn == 0 || filter == NULL )
841             {
842                 if( idxdata != NULL )
843                 {
844                     for( ti = 0; ti < l; ti++ )
845                     {
846                         t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
847                     }
848                 }
849                 else
850                 {
851                     for( ti = 0; ti < l; ti++ )
852                     {
853                         t_idx[ti] = ti;
854                     }
855                 }                
856             }
857         }
858
859         #ifdef _OPENMP
860         #pragma omp critical(c_compidx)
861         #endif /* _OPENMP */
862         {
863             t_compidx = compidx;
864             compidx += portion;
865         }
866         while( t_compidx < n )
867         {
868             t_n = portion;
869             if( t_compidx < datan )
870             {
871                 t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
872                 t_data = data;
873                 t_cstep = cstep;
874                 t_sstep = sstep;
875             }
876             else
877             {
878                 t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
879                 t_cstep = matcstep;
880                 t_sstep = matsstep;
881                 t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );
882
883                 /* calculate components */
884                 ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
885                         sampleIdx, compIdx, t_compidx, t_n,
886                         ((CvMTStumpTrainParams*)trainParams)->userdata );
887             }
888
889             if( sorteddata != NULL )
890             {
891                 if( filter != NULL )
892                 {
893                     /* have sorted indices and filter */
894                     switch( sortedtype )
895                     {
896                         case CV_16SC1:
897                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
898                             {
899                                 tk = 0;
900                                 for( tj = 0; tj < sortedm; tj++ )
901                                 {
902                                     int curidx = (int) ( *((short*) (sorteddata
903                                             + ti * sortedcstep + tj * sortedsstep)) );
904                                     if( filter[curidx] != 0 )
905                                     {
906                                         t_idx[tk++] = curidx;
907                                     }
908                                 }
909                                 if( findStumpThreshold_32s[stumperror]( 
910                                         t_data + ti * t_cstep, t_sstep,
911                                         wdata, wstep, ydata, ystep,
912                                         (uchar*) t_idx, sizeof( int ), tk,
913                                         &lerror, &rerror,
914                                         &threshold, &left, &right, 
915                                         &sumw, &sumwy, &sumwyy ) )
916                                 {
917                                     optcompidx = ti;
918                                 }
919                             }
920                             break;
921                         case CV_32SC1:
922                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
923                             {
924                                 tk = 0;
925                                 for( tj = 0; tj < sortedm; tj++ )
926                                 {
927                                     int curidx = (int) ( *((int*) (sorteddata
928                                             + ti * sortedcstep + tj * sortedsstep)) );
929                                     if( filter[curidx] != 0 )
930                                     {
931                                         t_idx[tk++] = curidx;
932                                     }
933                                 }
934                                 if( findStumpThreshold_32s[stumperror]( 
935                                         t_data + ti * t_cstep, t_sstep,
936                                         wdata, wstep, ydata, ystep,
937                                         (uchar*) t_idx, sizeof( int ), tk,
938                                         &lerror, &rerror,
939                                         &threshold, &left, &right, 
940                                         &sumw, &sumwy, &sumwyy ) )
941                                 {
942                                     optcompidx = ti;
943                                 }
944                             }
945                             break;
946                         case CV_32FC1:
947                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
948                             {
949                                 tk = 0;
950                                 for( tj = 0; tj < sortedm; tj++ )
951                                 {
952                                     int curidx = (int) ( *((float*) (sorteddata
953                                             + ti * sortedcstep + tj * sortedsstep)) );
954                                     if( filter[curidx] != 0 )
955                                     {
956                                         t_idx[tk++] = curidx;
957                                     }
958                                 }
959                                 if( findStumpThreshold_32s[stumperror]( 
960                                         t_data + ti * t_cstep, t_sstep,
961                                         wdata, wstep, ydata, ystep,
962                                         (uchar*) t_idx, sizeof( int ), tk,
963                                         &lerror, &rerror,
964                                         &threshold, &left, &right, 
965                                         &sumw, &sumwy, &sumwyy ) )
966                                 {
967                                     optcompidx = ti;
968                                 }
969                             }
970                             break;
971                         default:
972                             assert( 0 );
973                             break;
974                     }
975                 }
976                 else
977                 {
978                     /* have sorted indices */
979                     switch( sortedtype )
980                     {
981                         case CV_16SC1:
982                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
983                             {
984                                 if( findStumpThreshold_16s[stumperror]( 
985                                         t_data + ti * t_cstep, t_sstep,
986                                         wdata, wstep, ydata, ystep,
987                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
988                                         &lerror, &rerror,
989                                         &threshold, &left, &right, 
990                                         &sumw, &sumwy, &sumwyy ) )
991                                 {
992                                     optcompidx = ti;
993                                 }
994                             }
995                             break;
996                         case CV_32SC1:
997                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
998                             {
999                                 if( findStumpThreshold_32s[stumperror]( 
1000                                         t_data + ti * t_cstep, t_sstep,
1001                                         wdata, wstep, ydata, ystep,
1002                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1003                                         &lerror, &rerror,
1004                                         &threshold, &left, &right, 
1005                                         &sumw, &sumwy, &sumwyy ) )
1006                                 {
1007                                     optcompidx = ti;
1008                                 }
1009                             }
1010                             break;
1011                         case CV_32FC1:
1012                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
1013                             {
1014                                 if( findStumpThreshold_32f[stumperror]( 
1015                                         t_data + ti * t_cstep, t_sstep,
1016                                         wdata, wstep, ydata, ystep,
1017                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1018                                         &lerror, &rerror,
1019                                         &threshold, &left, &right, 
1020                                         &sumw, &sumwy, &sumwyy ) )
1021                                 {
1022                                     optcompidx = ti;
1023                                 }
1024                             }
1025                             break;
1026                         default:
1027                             assert( 0 );
1028                             break;
1029                     }
1030                 }
1031             }
1032
1033             ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
1034             for( ; ti < t_compidx + t_n; ti++ )
1035             {
1036                 va.data = t_data + ti * t_cstep;
1037                 va.step = t_sstep;
1038                 icvSortIndexedValArray_32s( t_idx, l, &va );
1039                 if( findStumpThreshold_32s[stumperror]( 
1040                         t_data + ti * t_cstep, t_sstep,
1041                         wdata, wstep, ydata, ystep,
1042                         (uchar*)t_idx, sizeof( int ), l,
1043                         &lerror, &rerror,
1044                         &threshold, &left, &right, 
1045                         &sumw, &sumwy, &sumwyy ) )
1046                 {
1047                     optcompidx = ti;
1048                 }
1049             }
1050             #ifdef _OPENMP
1051             #pragma omp critical(c_compidx)
1052             #endif /* _OPENMP */
1053             {
1054                 t_compidx = compidx;
1055                 compidx += portion;
1056             }
1057         } /* while have training data */
1058
1059         /* get the best classifier */
1060         #ifdef _OPENMP
1061         #pragma omp critical(c_beststump)
1062         #endif /* _OPENMP */
1063         {
1064             if( lerror + rerror < stump->lerror + stump->rerror )
1065             {
1066                 stump->lerror    = lerror;
1067                 stump->rerror    = rerror;
1068                 stump->compidx   = optcompidx;
1069                 stump->threshold = threshold;
1070                 stump->left      = left;
1071                 stump->right     = right;
1072             }
1073         }
1074
1075         /* free allocated memory */
1076         if( mat.data.ptr != NULL )
1077         {
1078             cvFree( &(mat.data.ptr) );
1079         }
1080         if( t_idx != NULL )
1081         {
1082             cvFree( &t_idx );
1083         }
1084     } /* end of parallel region */
1085
1086     /* END */
1087
1088     /* free allocated memory */
1089     if( filter != NULL )
1090     {
1091         cvFree( &filter );
1092     }
1093
1094     if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
1095     {
1096         stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
1097         stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
1098     }
1099
1100     return (CvClassifier*) stump;
1101 }
1102
1103 CV_BOOST_IMPL
1104 float cvEvalCARTClassifier( CvClassifier* classifier, CvMat* sample )
1105 {
1106     CV_FUNCNAME( "cvEvalCARTClassifier" );
1107
1108     int idx;
1109
1110     __BEGIN__;
1111
1112
1113     CV_ASSERT( classifier != NULL );
1114     CV_ASSERT( sample != NULL );
1115     CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1116     CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1117
1118     idx = 0;
1119     if( sample->rows == 1 )
1120     {
1121         do
1122         {
1123             if( (CV_MAT_ELEM( (*sample), float, 0,
1124                     ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1125                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1126             {
1127                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1128             }
1129             else
1130             {
1131                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1132             }
1133         } while( idx > 0 );
1134     }
1135     else
1136     {
1137         do
1138         {
1139             if( (CV_MAT_ELEM( (*sample), float,
1140                     ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1141                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1142             {
1143                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1144             }
1145             else
1146             {
1147                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1148             }
1149         } while( idx > 0 );
1150     } 
1151
1152     __END__;
1153
1154     return ((CvCARTClassifier*) classifier)->val[-idx];
1155 }
1156
1157 CV_BOOST_IMPL
1158 float cvEvalCARTClassifierIdx( CvClassifier* classifier, CvMat* sample )
1159 {
1160     CV_FUNCNAME( "cvEvalCARTClassifierIdx" );
1161
1162     int idx;
1163
1164     __BEGIN__;
1165
1166
1167     CV_ASSERT( classifier != NULL );
1168     CV_ASSERT( sample != NULL );
1169     CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1170     CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1171
1172     idx = 0;
1173     if( sample->rows == 1 )
1174     {
1175         do
1176         {
1177             if( (CV_MAT_ELEM( (*sample), float, 0,
1178                     ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1179                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1180             {
1181                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1182             }
1183             else
1184             {
1185                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1186             }
1187         } while( idx > 0 );
1188     }
1189     else
1190     {
1191         do
1192         {
1193             if( (CV_MAT_ELEM( (*sample), float,
1194                     ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1195                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1196             {
1197                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1198             }
1199             else
1200             {
1201                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1202             }
1203         } while( idx > 0 );
1204     } 
1205
1206     __END__;
1207
1208     return (float) (-idx);
1209 }
1210
1211 CV_BOOST_IMPL
1212 void cvReleaseCARTClassifier( CvClassifier** classifier )
1213 {
1214     cvFree( classifier );
1215     *classifier = NULL;
1216 }
1217
1218 void CV_CDECL icvDefaultSplitIdx_R( int compidx, float threshold,
1219                                     CvMat* idx, CvMat** left, CvMat** right,
1220                                     void* userdata )
1221 {
1222     CvMat* trainData = (CvMat*) userdata;
1223     int i = 0;
1224
1225     *left = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1226     *right = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1227     (*left)->cols = (*right)->cols = 0;
1228     if( idx == NULL )
1229     {
1230         for( i = 0; i < trainData->rows; i++ )
1231         {
1232             if( CV_MAT_ELEM( *trainData, float, i, compidx ) < threshold )
1233             {
1234                 (*left)->data.fl[(*left)->cols++] = (float) i;
1235             }
1236             else
1237             {
1238                 (*right)->data.fl[(*right)->cols++] = (float) i;
1239             }
1240         }
1241     }
1242     else
1243     {
1244         uchar* idxdata;
1245         int idxnum;
1246         int idxstep;
1247         int index;
1248
1249         idxdata = idx->data.ptr;
1250         idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1251         idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1252         for( i = 0; i < idxnum; i++ )
1253         {
1254             index = (int) *((float*) (idxdata + i * idxstep));
1255             if( CV_MAT_ELEM( *trainData, float, index, compidx ) < threshold )
1256             {
1257                 (*left)->data.fl[(*left)->cols++] = (float) index;
1258             }
1259             else
1260             {
1261                 (*right)->data.fl[(*right)->cols++] = (float) index;
1262             }
1263         }
1264     }
1265 }
1266
1267 void CV_CDECL icvDefaultSplitIdx_C( int compidx, float threshold,
1268                                     CvMat* idx, CvMat** left, CvMat** right,
1269                                     void* userdata )
1270 {
1271     CvMat* trainData = (CvMat*) userdata;
1272     int i = 0;
1273
1274     *left = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1275     *right = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1276     (*left)->cols = (*right)->cols = 0;
1277     if( idx == NULL )
1278     {
1279         for( i = 0; i < trainData->cols; i++ )
1280         {
1281             if( CV_MAT_ELEM( *trainData, float, compidx, i ) < threshold )
1282             {
1283                 (*left)->data.fl[(*left)->cols++] = (float) i;
1284             }
1285             else
1286             {
1287                 (*right)->data.fl[(*right)->cols++] = (float) i;
1288             }
1289         }
1290     }
1291     else
1292     {
1293         uchar* idxdata;
1294         int idxnum;
1295         int idxstep;
1296         int index;
1297
1298         idxdata = idx->data.ptr;
1299         idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1300         idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1301         for( i = 0; i < idxnum; i++ )
1302         {
1303             index = (int) *((float*) (idxdata + i * idxstep));
1304             if( CV_MAT_ELEM( *trainData, float, compidx, index ) < threshold )
1305             {
1306                 (*left)->data.fl[(*left)->cols++] = (float) index;
1307             }
1308             else
1309             {
1310                 (*right)->data.fl[(*right)->cols++] = (float) index;
1311             }
1312         }
1313     }
1314 }
1315
/* internal structure used in CART creation: one tree node, either accepted
   into the tree or still a candidate waiting to be chosen */
typedef struct CvCARTNode
{
    CvMat* sampleIdx;           /* indices of the training samples that reach this node */
    CvStumpClassifier* stump;   /* stump trained on <sampleIdx>; defines this node's split */
    int parent;                 /* index of the parent among the accepted nodes */
    int leftflag;               /* 1 if this is the parent's left child, 0 if the right one */
    float errdrop;              /* parent-side error minus (stump->lerror + stump->rerror) */
} CvCARTNode;
1325
1326 CV_BOOST_IMPL
1327 CvClassifier* cvCreateCARTClassifier( CvMat* trainData,
1328                      int flags,
1329                      CvMat* trainClasses,
1330                      CvMat* typeMask,
1331                      CvMat* missedMeasurementsMask,
1332                      CvMat* compIdx,
1333                      CvMat* sampleIdx,
1334                      CvMat* weights,
1335                      CvClassifierTrainParams* trainParams )
1336 {
1337     CvCARTClassifier* cart = NULL;
1338     size_t datasize = 0;
1339     int count = 0;
1340     int i = 0;
1341     int j = 0;
1342     
1343     CvCARTNode* intnode = NULL;
1344     CvCARTNode* list = NULL;
1345     int listcount = 0;
1346     CvMat* lidx = NULL;
1347     CvMat* ridx = NULL;
1348     
1349     float maxerrdrop = 0.0F;
1350     int idx = 0;
1351
1352     void (*splitIdxCallback)( int compidx, float threshold,
1353                               CvMat* idx, CvMat** left, CvMat** right,
1354                               void* userdata );
1355     void* userdata;
1356
1357     count = ((CvCARTTrainParams*) trainParams)->count;
1358     
1359     assert( count > 0 );
1360
1361     datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count + 
1362         sizeof( float ) * (count + 1);
1363     
1364     cart = (CvCARTClassifier*) cvAlloc( datasize );
1365     memset( cart, 0, datasize );
1366     
1367     cart->count = count;
1368     
1369     cart->eval = cvEvalCARTClassifier;
1370     cart->save = NULL;
1371     cart->release = cvReleaseCARTClassifier;
1372
1373     cart->compidx = (int*) (cart + 1);
1374     cart->threshold = (float*) (cart->compidx + count);
1375     cart->left  = (int*) (cart->threshold + count);
1376     cart->right = (int*) (cart->left + count);
1377     cart->val = (float*) (cart->right + count);
1378
1379     datasize = sizeof( CvCARTNode ) * (count + count);
1380     intnode = (CvCARTNode*) cvAlloc( datasize );
1381     memset( intnode, 0, datasize );
1382     list = (CvCARTNode*) (intnode + count);
1383
1384     splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
1385     userdata = ((CvCARTTrainParams*) trainParams)->userdata;
1386     if( splitIdxCallback == NULL )
1387     {
1388         splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
1389             ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;
1390         userdata = trainData;
1391     }
1392
1393     /* create root of the tree */
1394     intnode[0].sampleIdx = sampleIdx;
1395     intnode[0].stump = (CvStumpClassifier*)
1396         ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1397             trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
1398             ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1399     cart->left[0] = cart->right[0] = 0;
1400
1401     /* build tree */
1402     listcount = 0;
1403     for( i = 1; i < count; i++ )
1404     {
1405         /* split last added node */
1406         splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
1407             intnode[i-1].sampleIdx, &lidx, &ridx, userdata );
1408         
1409         if( intnode[i-1].stump->lerror != 0.0F )
1410         {
1411             list[listcount].sampleIdx = lidx;
1412             list[listcount].stump = (CvStumpClassifier*)
1413                 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1414                     trainClasses, typeMask, missedMeasurementsMask, compIdx,
1415                     list[listcount].sampleIdx,
1416                     weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1417             list[listcount].errdrop = intnode[i-1].stump->lerror
1418                 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1419             list[listcount].leftflag = 1;
1420             list[listcount].parent = i-1;
1421             listcount++;
1422         }
1423         else
1424         {
1425             cvReleaseMat( &lidx );
1426         }
1427         if( intnode[i-1].stump->rerror != 0.0F )
1428         {
1429             list[listcount].sampleIdx = ridx;
1430             list[listcount].stump = (CvStumpClassifier*)
1431                 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1432                     trainClasses, typeMask, missedMeasurementsMask, compIdx,
1433                     list[listcount].sampleIdx,
1434                     weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1435             list[listcount].errdrop = intnode[i-1].stump->rerror
1436                 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1437             list[listcount].leftflag = 0;
1438             list[listcount].parent = i-1;
1439             listcount++;
1440         }
1441         else
1442         {
1443             cvReleaseMat( &ridx );
1444         }
1445         
1446         if( listcount == 0 ) break;
1447
1448         /* find the best node to be added to the tree */
1449         idx = 0;
1450         maxerrdrop = list[idx].errdrop;
1451         for( j = 1; j < listcount; j++ )
1452         {
1453             if( list[j].errdrop > maxerrdrop )
1454             {
1455                 idx = j;
1456                 maxerrdrop = list[j].errdrop;
1457             }
1458         }
1459         intnode[i] = list[idx];
1460         if( list[idx].leftflag )
1461         {
1462             cart->left[list[idx].parent] = i;
1463         }
1464         else
1465         {
1466             cart->right[list[idx].parent] = i;
1467         }
1468         if( idx != (listcount - 1) )
1469         {
1470             list[idx] = list[listcount - 1];
1471         }
1472         listcount--;
1473     }
1474
1475     /* fill <cart> fields */
1476     j = 0;
1477     cart->count = 0;
1478     for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1479     {
1480         cart->count++;
1481         cart->compidx[i] = intnode[i].stump->compidx;
1482         cart->threshold[i] = intnode[i].stump->threshold;
1483         
1484         /* leaves */
1485         if( cart->left[i] <= 0 )
1486         {
1487             cart->left[i] = -j;
1488             cart->val[j] = intnode[i].stump->left;
1489             j++;
1490         }
1491         if( cart->right[i] <= 0 )
1492         {
1493             cart->right[i] = -j;
1494             cart->val[j] = intnode[i].stump->right;
1495             j++;
1496         }
1497     }
1498     
1499     /* CLEAN UP */
1500     for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1501     {
1502         intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
1503         if( i != 0 )
1504         {
1505             cvReleaseMat( &(intnode[i].sampleIdx) );
1506         }
1507     }
1508     for( i = 0; i < listcount; i++ )
1509     {
1510         list[i].stump->release( (CvClassifier**) &(list[i].stump) );
1511         cvReleaseMat( &(list[i].sampleIdx) );
1512     }
1513     
1514     cvFree( &intnode );
1515
1516     return (CvClassifier*) cart;
1517 }
1518
1519 /****************************************************************************************\
1520 *                                        Boosting                                        *
1521 \****************************************************************************************/
1522
/* State shared by the 2-class boosting training functions; created by
   icvBoostStartTraining. */
typedef struct CvBoostTrainer
{
    CvBoostType type;      /* boosting variant passed to icvBoostStartTraining */
    int count;             /* (idx) ? number_of_indices : number_of_samples */
    int* idx;              /* sample indices, stored in the same allocation right
                              after this struct; NULL when all samples are used */
    float* F;              /* left NULL by icvBoostStartTraining; NOTE(review): presumably
                              per-sample accumulated classifier response for variants
                              that need it - confirm against its users */
} CvBoostTrainer;
1530
1531 /*
1532  * cvBoostStartTraining, cvBoostNextWeakClassifier, cvBoostEndTraining
1533  *
 * These functions train a two-class boosting classifier
 * using any suitable weak classifier.
1536  */
1537
1538 CV_BOOST_IMPL
1539 CvBoostTrainer* icvBoostStartTraining( CvMat* trainClasses,
1540                                        CvMat* weakTrainVals,
1541                                        CvMat* weights,
1542                                        CvMat* sampleIdx,
1543                                        CvBoostType type )
1544 {
1545     uchar* ydata;
1546     int ystep;
1547     int m;
1548     uchar* traindata;
1549     int trainstep;
1550     int trainnum;
1551     int i;
1552     int idx;
1553
1554     size_t datasize;
1555     CvBoostTrainer* ptr;
1556
1557     int idxnum;
1558     int idxstep;
1559     uchar* idxdata;
1560
1561     assert( trainClasses != NULL );
1562     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1563     assert( weakTrainVals != NULL );
1564     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1565
1566     CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1567     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1568
1569     assert( m == trainnum );
1570
1571     idxnum = 0;
1572     idxstep = 0;
1573     idxdata = NULL;
1574     if( sampleIdx )
1575     {
1576         CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1577     }
1578         
1579     datasize = sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum;
1580     ptr = (CvBoostTrainer*) cvAlloc( datasize );
1581     memset( ptr, 0, datasize );
1582     ptr->F = NULL;
1583     ptr->idx = NULL;
1584
1585     ptr->count = m;
1586     ptr->type = type;
1587     
1588     if( idxnum > 0 )
1589     {
1590         CvScalar s;
1591
1592         ptr->idx = (int*) (ptr + 1);
1593         ptr->count = idxnum;
1594         for( i = 0; i < ptr->count; i++ )
1595         {
1596             cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1597             ptr->idx[i] = (int) s.val[0];
1598         }
1599     }
1600     for( i = 0; i < ptr->count; i++ )
1601     {
1602         idx = (ptr->idx) ? ptr->idx[i] : i;
1603
1604         *((float*) (traindata + idx * trainstep)) = 
1605             2.0F * (*((float*) (ydata + idx * ystep))) - 1.0F;
1606     }
1607
1608     return ptr;
1609 }
1610
1611 /*
1612  *
1613  * Discrete AdaBoost functions
1614  *
1615  */
1616 CV_BOOST_IMPL
1617 float icvBoostNextWeakClassifierDAB( CvMat* weakEvalVals,
1618                                      CvMat* trainClasses,
1619                                      CvMat* weakTrainVals,
1620                                      CvMat* weights,
1621                                      CvBoostTrainer* trainer )
1622 {
1623     uchar* evaldata;
1624     int evalstep;
1625     int m;
1626     uchar* ydata;
1627     int ystep;
1628     int ynum;
1629     uchar* wdata;
1630     int wstep;
1631     int wnum;
1632
1633     float sumw;
1634     float err;
1635     int i;
1636     int idx;
1637
1638     assert( weakEvalVals != NULL );
1639     assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1640     assert( trainClasses != NULL );
1641     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1642     assert( weights != NULL );
1643     assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1644
1645     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1646     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1647     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1648
1649     assert( m == ynum );
1650     assert( m == wnum );
1651
1652     sumw = 0.0F;
1653     err = 0.0F;
1654     for( i = 0; i < trainer->count; i++ )
1655     {
1656         idx = (trainer->idx) ? trainer->idx[i] : i;
1657
1658         sumw += *((float*) (wdata + idx*wstep));
1659         err += (*((float*) (wdata + idx*wstep))) *
1660             ( (*((float*) (evaldata + idx*evalstep))) != 
1661                 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F );
1662     }
1663     err /= sumw;
1664     err = -cvLogRatio( err );
1665     
1666     for( i = 0; i < trainer->count; i++ )
1667     {
1668         idx = (trainer->idx) ? trainer->idx[i] : i;
1669
1670         *((float*) (wdata + idx*wstep)) *= expf( err * 
1671             ((*((float*) (evaldata + idx*evalstep))) != 
1672                 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F) );
1673         sumw += *((float*) (wdata + idx*wstep));
1674     }
1675     for( i = 0; i < trainer->count; i++ )
1676     {
1677         idx = (trainer->idx) ? trainer->idx[i] : i;
1678
1679         *((float*) (wdata + idx * wstep)) /= sumw;
1680     }
1681     
1682     return err;
1683 }
1684
1685 /*
1686  *
1687  * Real AdaBoost functions
1688  *
1689  */
1690 CV_BOOST_IMPL
1691 float icvBoostNextWeakClassifierRAB( CvMat* weakEvalVals,
1692                                      CvMat* trainClasses,
1693                                      CvMat* weakTrainVals,
1694                                      CvMat* weights,
1695                                      CvBoostTrainer* trainer )
1696 {
1697     uchar* evaldata;
1698     int evalstep;
1699     int m;
1700     uchar* ydata;
1701     int ystep;
1702     int ynum;
1703     uchar* wdata;
1704     int wstep;
1705     int wnum;
1706
1707     float sumw;
1708     int i, idx;
1709
1710     assert( weakEvalVals != NULL );
1711     assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1712     assert( trainClasses != NULL );
1713     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1714     assert( weights != NULL );
1715     assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1716
1717     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1718     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1719     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1720
1721     assert( m == ynum );
1722     assert( m == wnum );
1723
1724
1725     sumw = 0.0F;
1726     for( i = 0; i < trainer->count; i++ )
1727     {
1728         idx = (trainer->idx) ? trainer->idx[i] : i;
1729
1730         *((float*) (wdata + idx*wstep)) *= expf( (-(*((float*) (ydata + idx*ystep))) + 0.5F)
1731             * cvLogRatio( *((float*) (evaldata + idx*evalstep)) ) );
1732         sumw += *((float*) (wdata + idx*wstep));
1733     }
1734     for( i = 0; i < trainer->count; i++ )
1735     {
1736         idx = (trainer->idx) ? trainer->idx[i] : i;
1737
1738         *((float*) (wdata + idx*wstep)) /= sumw;
1739     }
1740     
1741     return 1.0F;
1742 }
1743
1744 /*
1745  *
1746  * LogitBoost functions
1747  *
1748  */
1749 #define CV_LB_PROB_THRESH      0.01F
1750 #define CV_LB_WEIGHT_THRESHOLD 0.0001F
1751
1752 CV_BOOST_IMPL
1753 void icvResponsesAndWeightsLB( int num, uchar* wdata, int wstep,
1754                                uchar* ydata, int ystep,
1755                                uchar* fdata, int fstep,
1756                                uchar* traindata, int trainstep,
1757                                int* indices )
1758 {
1759     int i, idx;
1760     float p;
1761
1762     for( i = 0; i < num; i++ )
1763     {
1764         idx = (indices) ? indices[i] : i;
1765
1766         p = 1.0F / (1.0F + expf( -(*((float*) (fdata + idx*fstep)))) );
1767         *((float*) (wdata + idx*wstep)) = MAX( p * (1.0F - p), CV_LB_WEIGHT_THRESHOLD );
1768         if( *((float*) (ydata + idx*ystep)) == 1.0F )
1769         {
1770             *((float*) (traindata + idx*trainstep)) = 
1771                 1.0F / (MAX( p, CV_LB_PROB_THRESH ));
1772         }
1773         else
1774         {
1775             *((float*) (traindata + idx*trainstep)) = 
1776                 -1.0F / (MAX( 1.0F - p, CV_LB_PROB_THRESH ));
1777         }
1778     }
1779 }
1780
1781 CV_BOOST_IMPL
1782 CvBoostTrainer* icvBoostStartTrainingLB( CvMat* trainClasses,
1783                                          CvMat* weakTrainVals,
1784                                          CvMat* weights,
1785                                          CvMat* sampleIdx,
1786                                          CvBoostType type )
1787 {
1788     size_t datasize;
1789     CvBoostTrainer* ptr;
1790
1791     uchar* ydata;
1792     int ystep;
1793     int m;
1794     uchar* traindata;
1795     int trainstep;
1796     int trainnum;
1797     uchar* wdata;
1798     int wstep;
1799     int wnum;
1800     int i;
1801
1802     int idxnum;
1803     int idxstep;
1804     uchar* idxdata;
1805
1806     assert( trainClasses != NULL );
1807     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1808     assert( weakTrainVals != NULL );
1809     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1810     assert( weights != NULL );
1811     assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1812
1813     CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1814     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1815     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1816
1817     assert( m == trainnum );
1818     assert( m == wnum );
1819
1820
1821     idxnum = 0;
1822     idxstep = 0;
1823     idxdata = NULL;
1824     if( sampleIdx )
1825     {
1826         CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1827     }
1828         
1829     datasize = sizeof( *ptr ) + sizeof( *ptr->F ) * m + sizeof( *ptr->idx ) * idxnum;
1830     ptr = (CvBoostTrainer*) cvAlloc( datasize );
1831     memset( ptr, 0, datasize );
1832     ptr->F = (float*) (ptr + 1);
1833     ptr->idx = NULL;
1834
1835     ptr->count = m;
1836     ptr->type = type;
1837     
1838     if( idxnum > 0 )
1839     {
1840         CvScalar s;
1841
1842         ptr->idx = (int*) (ptr->F + m);
1843         ptr->count = idxnum;
1844         for( i = 0; i < ptr->count; i++ )
1845         {
1846             cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1847             ptr->idx[i] = (int) s.val[0];
1848         }
1849     }
1850
1851     for( i = 0; i < m; i++ )
1852     {
1853         ptr->F[i] = 0.0F;
1854     }
1855
1856     icvResponsesAndWeightsLB( ptr->count, wdata, wstep, ydata, ystep,
1857                               (uchar*) ptr->F, sizeof( *ptr->F ),
1858                               traindata, trainstep, ptr->idx );
1859
1860     return ptr;
1861 }
1862
1863 CV_BOOST_IMPL
1864 float icvBoostNextWeakClassifierLB( CvMat* weakEvalVals,
1865                                     CvMat* trainClasses,
1866                                     CvMat* weakTrainVals,
1867                                     CvMat* weights,
1868                                     CvBoostTrainer* trainer )
1869 {
1870     uchar* evaldata;
1871     int evalstep;
1872     int m;
1873     uchar* ydata;
1874     int ystep;
1875     int ynum;
1876     uchar* traindata;
1877     int trainstep;
1878     int trainnum;
1879     uchar* wdata;
1880     int wstep;
1881     int wnum;
1882     int i, idx;
1883
1884     assert( weakEvalVals != NULL );
1885     assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1886     assert( trainClasses != NULL );
1887     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1888     assert( weakTrainVals != NULL );
1889     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1890     assert( weights != NULL );
1891     assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1892
1893     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1894     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1895     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1896     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1897
1898     assert( m == ynum );
1899     assert( m == wnum );
1900     assert( m == trainnum );
1901     //assert( m == trainer->count );
1902
1903     for( i = 0; i < trainer->count; i++ )
1904     {
1905         idx = (trainer->idx) ? trainer->idx[i] : i;
1906
1907         trainer->F[idx] += *((float*) (evaldata + idx * evalstep));
1908     }
1909     
1910     icvResponsesAndWeightsLB( trainer->count, wdata, wstep, ydata, ystep,
1911                               (uchar*) trainer->F, sizeof( *trainer->F ),
1912                               traindata, trainstep, trainer->idx );
1913
1914     return 1.0F;
1915 }
1916
1917 /*
1918  *
1919  * Gentle AdaBoost
1920  *
1921  */
1922 CV_BOOST_IMPL
1923 float icvBoostNextWeakClassifierGAB( CvMat* weakEvalVals,
1924                                      CvMat* trainClasses,
1925                                      CvMat* weakTrainVals,
1926                                      CvMat* weights,
1927                                      CvBoostTrainer* trainer )
1928 {
1929     uchar* evaldata;
1930     int evalstep;
1931     int m;
1932     uchar* ydata;
1933     int ystep;
1934     int ynum;
1935     uchar* wdata;
1936     int wstep;
1937     int wnum;
1938
1939     int i, idx;
1940     float sumw;
1941
1942     assert( weakEvalVals != NULL );
1943     assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1944     assert( trainClasses != NULL );
1945     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1946     assert( weights != NULL );
1947     assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1948
1949     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1950     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1951     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1952
1953     assert( m == ynum );
1954     assert( m == wnum );
1955
1956     sumw = 0.0F;
1957     for( i = 0; i < trainer->count; i++ )
1958     {
1959         idx = (trainer->idx) ? trainer->idx[i] : i;
1960
1961         *((float*) (wdata + idx*wstep)) *= 
1962             expf( -(*((float*) (evaldata + idx*evalstep)))
1963                   * ( 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F ) );
1964         sumw += *((float*) (wdata + idx*wstep));
1965     }
1966     
1967     for( i = 0; i < trainer->count; i++ )
1968     {
1969         idx = (trainer->idx) ? trainer->idx[i] : i;
1970
1971         *((float*) (wdata + idx*wstep)) /= sumw;
1972     }
1973
1974     return 1.0F;
1975 }
1976
1977 typedef CvBoostTrainer* (*CvBoostStartTraining)( CvMat* trainClasses,
1978                                                  CvMat* weakTrainVals,
1979                                                  CvMat* weights,
1980                                                  CvMat* sampleIdx,
1981                                                  CvBoostType type );
1982
1983 typedef float (*CvBoostNextWeakClassifier)( CvMat* weakEvalVals,
1984                                             CvMat* trainClasses,
1985                                             CvMat* weakTrainVals,
1986                                             CvMat* weights,
1987                                             CvBoostTrainer* data );
1988
1989 CvBoostStartTraining startTraining[4] = {
1990         icvBoostStartTraining,
1991         icvBoostStartTraining,
1992         icvBoostStartTrainingLB,
1993         icvBoostStartTraining
1994     };
1995
1996 CvBoostNextWeakClassifier nextWeakClassifier[4] = {
1997         icvBoostNextWeakClassifierDAB,
1998         icvBoostNextWeakClassifierRAB,
1999         icvBoostNextWeakClassifierLB,
2000         icvBoostNextWeakClassifierGAB
2001     };
2002
2003 /*
2004  *
2005  * Dispatchers
2006  *
2007  */
2008 CV_BOOST_IMPL
2009 CvBoostTrainer* cvBoostStartTraining( CvMat* trainClasses,
2010                                       CvMat* weakTrainVals,
2011                                       CvMat* weights,
2012                                       CvMat* sampleIdx,
2013                                       CvBoostType type )
2014 {
2015     return startTraining[type]( trainClasses, weakTrainVals, weights, sampleIdx, type );
2016 }
2017
2018 CV_BOOST_IMPL
2019 void cvBoostEndTraining( CvBoostTrainer** trainer )
2020 {
2021     cvFree( trainer );
2022     *trainer = NULL;
2023 }
2024
2025 CV_BOOST_IMPL
2026 float cvBoostNextWeakClassifier( CvMat* weakEvalVals,
2027                                  CvMat* trainClasses,
2028                                  CvMat* weakTrainVals,
2029                                  CvMat* weights,
2030                                  CvBoostTrainer* trainer )
2031 {
2032     return nextWeakClassifier[trainer->type]( weakEvalVals, trainClasses,
2033         weakTrainVals, weights, trainer    );
2034 }
2035
2036 /****************************************************************************************\
2037 *                                    Boosted tree models                                 *
2038 \****************************************************************************************/
2039
/*
 * State of an iterative boosted-tree training run (cvBtStart / cvBtNext).
 * Allocated by cvBtStart; when type > CV_GABCLASS the array f lives in the
 * same allocation, directly after the struct.
 */
typedef struct CvBtTrainer
{
    /* {{ external */
    CvMat* trainData;          /* training samples (layout given by flags) */
    int flags;                 /* sample layout flags of trainData */

    CvMat* trainClasses;       /* responses / class labels, one per sample */
    int m;                     /* number of elements in trainClasses */
    uchar* ydata;              /* raw data pointer of trainClasses */
    int ystep;                 /* byte distance between consecutive labels */

    CvMat* sampleIdx;          /* optional subset of sample indices, NULL = all */
    int numsamples;            /* number of used samples (sampleIdx length or m) */

    float param[2];            /* algorithm parameters copied in cvBtStart;
                                  param[1] is the Huber quantile for CV_MREG and
                                  the weight-trimming rate for CV_L2CLASS.
                                  NOTE(review): meaning of param[0] is not
                                  evident here — confirm */
    CvBoostType type;          /* boosting / regression variant */
    int numclasses;            /* number of classes for CV_LKCLASS, 1 otherwise */
    /* }} external */

    CvMTStumpTrainParams stumpParams;   /* split criterion for the tree stumps */
    CvCARTTrainParams  cartParams;      /* tree parameters (numsplits, stump trainer) */

    float* f;          /* F_(m-1): current model output, m * numclasses values
                          (storage exists only for type > CV_GABCLASS) */
    CvMat* y;          /* yhat: pseudo-responses the next tree is fitted to */
    CvMat* weights;    /* per-sample weights passed to the tree trainer */
    CvBoostTrainer* boosttrainer;  /* 2-class boosting state (type <= CV_GABCLASS) */
} CvBtTrainer;
2067
/*
 * cvBtStart, cvBtNext, cvBtEnd
 *
 * These functions iteratively train a 2-class classifier
 * (CV_DABCLASS - CV_GABCLASS, CV_L2CLASS), a K-class classifier
 * (CV_LKCLASS), or fit a regression model (CV_LSREG, CV_LADREG, CV_MREG),
 * using decision trees as weak learners.
 */

/*
 * Computes the initial ("zero") approximation F_0 used before the first tree
 * is added. approx points to trainer->numclasses floats (allocated in
 * cvBtStart).
 */
typedef void (*CvZeroApproxFunc)( float* approx, CvBtTrainer* trainer );
2078
2079 /* Mean zero approximation */
2080 void icvZeroApproxMean( float* approx, CvBtTrainer* trainer )
2081 {
2082     int i;
2083     int idx;
2084
2085     approx[0] = 0.0F;
2086     for( i = 0; i < trainer->numsamples; i++ )
2087     {
2088         idx = icvGetIdxAt( trainer->sampleIdx, i );
2089         approx[0] += *((float*) (trainer->ydata + idx * trainer->ystep));
2090     }
2091     approx[0] /= (float) trainer->numsamples;
2092 }
2093
2094 /*
2095  * Median zero approximation
2096  */
2097 void icvZeroApproxMed( float* approx, CvBtTrainer* trainer )
2098 {
2099     int i;
2100     int idx;
2101
2102     for( i = 0; i < trainer->numsamples; i++ )
2103     {
2104         idx = icvGetIdxAt( trainer->sampleIdx, i );
2105         trainer->f[i] = *((float*) (trainer->ydata + idx * trainer->ystep));
2106     }
2107     
2108     icvSort_32f( trainer->f, trainer->numsamples, 0 );
2109     approx[0] = trainer->f[trainer->numsamples / 2];
2110 }
2111
2112 /*
2113  * 0.5 * log( mean(y) / (1 - mean(y)) ) where y in {0, 1}
2114  */
2115 void icvZeroApproxLog( float* approx, CvBtTrainer* trainer )
2116 {
2117     float y_mean;
2118
2119     icvZeroApproxMean( &y_mean, trainer );
2120     approx[0] = 0.5F * cvLogRatio( y_mean );
2121 }
2122
2123 /*
2124  * 0 zero approximation
2125  */
2126 void icvZeroApprox0( float* approx, CvBtTrainer* trainer )
2127 {
2128     int i;
2129
2130     for( i = 0; i < trainer->numclasses; i++ )
2131     {
2132         approx[i] = 0.0F;
2133     }
2134 }
2135
2136 static CvZeroApproxFunc icvZeroApproxFunc[] =
2137 {
2138     icvZeroApprox0,    /* CV_DABCLASS */
2139     icvZeroApprox0,    /* CV_RABCLASS */
2140     icvZeroApprox0,    /* CV_LBCLASS  */
2141     icvZeroApprox0,    /* CV_GABCLASS */
2142     icvZeroApproxLog,  /* CV_L2CLASS  */
2143     icvZeroApprox0,    /* CV_LKCLASS  */
2144     icvZeroApproxMean, /* CV_LSREG    */
2145     icvZeroApproxMed,  /* CV_LADREG   */
2146     icvZeroApproxMed,  /* CV_MREG     */
2147 };
2148
2149 CV_BOOST_IMPL
2150 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer );
2151
2152 CV_BOOST_IMPL
2153 CvBtTrainer* cvBtStart( CvCARTClassifier** trees,
2154                         CvMat* trainData,
2155                         int flags,
2156                         CvMat* trainClasses,
2157                         CvMat* sampleIdx,
2158                         int numsplits,
2159                         CvBoostType type,
2160                         int numclasses,
2161                         float* param )
2162 {
2163     CvBtTrainer* ptr;
2164
2165     CV_FUNCNAME( "cvBtStart" );
2166
2167     __BEGIN__;
2168
2169     size_t data_size;
2170     float* zero_approx;
2171     int m;
2172     int i, j;
2173     
2174     if( trees == NULL )
2175     {
2176         CV_ERROR( CV_StsNullPtr, "Invalid trees parameter" );
2177     }
2178     
2179     if( type < CV_DABCLASS || type > CV_MREG ) 
2180     {
2181         CV_ERROR( CV_StsUnsupportedFormat, "Unsupported type parameter" );
2182     }
2183     if( type == CV_LKCLASS )
2184     {
2185         CV_ASSERT( numclasses >= 2 );
2186     }
2187     else
2188     {
2189         numclasses = 1;
2190     }
2191
2192     m = MAX( trainClasses->rows, trainClasses->cols );
2193     ptr = NULL;
2194     data_size = sizeof( *ptr );
2195     if( type > CV_GABCLASS )
2196     {
2197         data_size += m * numclasses * sizeof( *(ptr->f) );
2198     }
2199     CV_CALL( ptr = (CvBtTrainer*) cvAlloc( data_size ) );
2200     memset( ptr, 0, data_size );
2201     ptr->f = (float*) (ptr + 1);
2202
2203     ptr->trainData = trainData;
2204     ptr->flags = flags;
2205     ptr->trainClasses = trainClasses;
2206     CV_MAT2VEC( *trainClasses, ptr->ydata, ptr->ystep, ptr->m );
2207     
2208     memset( &(ptr->cartParams), 0, sizeof( ptr->cartParams ) );
2209     memset( &(ptr->stumpParams), 0, sizeof( ptr->stumpParams ) );
2210
2211     switch( type )
2212     {
2213         case CV_DABCLASS:
2214             ptr->stumpParams.error = CV_MISCLASSIFICATION;
2215             ptr->stumpParams.type  = CV_CLASSIFICATION_CLASS;
2216             break;
2217         case CV_RABCLASS:
2218             ptr->stumpParams.error = CV_GINI;
2219             ptr->stumpParams.type  = CV_CLASSIFICATION;
2220             break;
2221         default:
2222             ptr->stumpParams.error = CV_SQUARE;
2223             ptr->stumpParams.type  = CV_REGRESSION;
2224     }
2225     ptr->cartParams.count = numsplits;
2226     ptr->cartParams.stumpTrainParams = (CvClassifierTrainParams*) &(ptr->stumpParams);
2227     ptr->cartParams.stumpConstructor = cvCreateMTStumpClassifier;
2228
2229     ptr->param[0] = param[0];
2230     ptr->param[1] = param[1];
2231     ptr->type = type;
2232     ptr->numclasses = numclasses;
2233
2234     CV_CALL( ptr->y = cvCreateMat( 1, m, CV_32FC1 ) );
2235     ptr->sampleIdx = sampleIdx;
2236     ptr->numsamples = ( sampleIdx == NULL ) ? ptr->m
2237                              : MAX( sampleIdx->rows, sampleIdx->cols );
2238     
2239     ptr->weights = cvCreateMat( 1, m, CV_32FC1 );
2240     cvSet( ptr->weights, cvScalar( 1.0 ) );    
2241     
2242     if( type <= CV_GABCLASS )
2243     {
2244         ptr->boosttrainer = cvBoostStartTraining( ptr->trainClasses, ptr->y,
2245             ptr->weights, NULL, type );
2246
2247         CV_CALL( cvBtNext( trees, ptr ) );
2248     }
2249     else
2250     {
2251         data_size = sizeof( *zero_approx ) * numclasses;
2252         CV_CALL( zero_approx = (float*) cvAlloc( data_size ) );
2253         icvZeroApproxFunc[type]( zero_approx, ptr );
2254         for( i = 0; i < m; i++ )
2255         {
2256             for( j = 0; j < numclasses; j++ )
2257             {
2258                 ptr->f[i * numclasses + j] = zero_approx[j];
2259             }
2260         }
2261
2262         CV_CALL( cvBtNext( trees, ptr ) );
2263
2264         for( i = 0; i < numclasses; i++ )
2265         {
2266             for( j = 0; j <= trees[i]->count; j++ )
2267             {
2268                 trees[i]->val[j] += zero_approx[i];
2269             }
2270         }    
2271         CV_CALL( cvFree( &zero_approx ) );
2272     }
2273
2274     __END__;
2275
2276     return ptr;
2277 }
2278
2279 void icvBtNext_LSREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2280 {
2281     int i;
2282
2283     /* yhat_i = y_i - F_(m-1)(x_i) */
2284     for( i = 0; i < trainer->m; i++ )
2285     {
2286         trainer->y->data.fl[i] = 
2287             *((float*) (trainer->ydata + i * trainer->ystep)) - trainer->f[i];
2288     }
2289
2290     trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2291         trainer->flags,
2292         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2293         (CvClassifierTrainParams*) &trainer->cartParams );
2294 }
2295
2296
2297 void icvBtNext_LADREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2298 {
2299     CvCARTClassifier* ptr;
2300     int i, j;
2301     CvMat sample;
2302     int sample_step;
2303     uchar* sample_data;
2304     int index;
2305     
2306     int data_size;
2307     int* idx;
2308     float* resp;
2309     int respnum;
2310     float val;
2311
2312     data_size = trainer->m * sizeof( *idx );
2313     idx = (int*) cvAlloc( data_size );
2314     data_size = trainer->m * sizeof( *resp );
2315     resp = (float*) cvAlloc( data_size );
2316
2317     /* yhat_i = sign(y_i - F_(m-1)(x_i)) */
2318     for( i = 0; i < trainer->numsamples; i++ )
2319     {
2320         index = icvGetIdxAt( trainer->sampleIdx, i );
2321         trainer->y->data.fl[index] = (float)
2322              CV_SIGN( *((float*) (trainer->ydata + index * trainer->ystep))
2323                      - trainer->f[index] );
2324     }
2325
2326     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2327         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2328         (CvClassifierTrainParams*) &trainer->cartParams );
2329
2330     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2331     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2332     sample_data = sample.data.ptr;
2333     for( i = 0; i < trainer->numsamples; i++ )
2334     {
2335         index = icvGetIdxAt( trainer->sampleIdx, i );
2336         sample.data.ptr = sample_data + index * sample_step;
2337         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2338     }
2339     for( j = 0; j <= ptr->count; j++ )
2340     {
2341         respnum = 0;
2342         for( i = 0; i < trainer->numsamples; i++ )
2343         {
2344             index = icvGetIdxAt( trainer->sampleIdx, i );
2345             if( idx[index] == j )
2346             {
2347                 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2348                                   - trainer->f[index];
2349             }
2350         }
2351         if( respnum > 0 )
2352         {
2353             icvSort_32f( resp, respnum, 0 );
2354             val = resp[respnum / 2];
2355         }
2356         else
2357         {
2358             val = 0.0F;
2359         }
2360         ptr->val[j] = val;
2361     }
2362
2363     cvFree( &idx );
2364     cvFree( &resp );
2365     
2366     trees[0] = ptr;
2367 }
2368
2369
2370 void icvBtNext_MREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2371 {
2372     CvCARTClassifier* ptr;
2373     int i, j;
2374     CvMat sample;
2375     int sample_step;
2376     uchar* sample_data;
2377     
2378     int data_size;
2379     int* idx;
2380     float* resid;
2381     float* resp;
2382     int respnum;
2383     float rhat;
2384     float val;
2385     float delta;
2386     int index;
2387
2388     data_size = trainer->m * sizeof( *idx );
2389     idx = (int*) cvAlloc( data_size );
2390     data_size = trainer->m * sizeof( *resp );
2391     resp = (float*) cvAlloc( data_size );
2392     data_size = trainer->m * sizeof( *resid );
2393     resid = (float*) cvAlloc( data_size );
2394
2395     /* resid_i = (y_i - F_(m-1)(x_i)) */
2396     for( i = 0; i < trainer->numsamples; i++ )
2397     {
2398         index = icvGetIdxAt( trainer->sampleIdx, i );
2399         resid[index] = *((float*) (trainer->ydata + index * trainer->ystep))
2400                        - trainer->f[index];
2401         /* for delta */
2402         resp[i] = (float) fabs( resid[index] );
2403     }
2404     
2405     /* delta = quantile_alpha{abs(resid_i)} */
2406     icvSort_32f( resp, trainer->numsamples, 0 );
2407     delta = resp[(int)(trainer->param[1] * (trainer->numsamples - 1))];
2408
2409     /* yhat_i */
2410     for( i = 0; i < trainer->numsamples; i++ )
2411     {
2412         index = icvGetIdxAt( trainer->sampleIdx, i );
2413         trainer->y->data.fl[index] = MIN( delta, ((float) fabs( resid[index] )) ) *
2414                                  CV_SIGN( resid[index] );
2415     }
2416     
2417     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2418         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2419         (CvClassifierTrainParams*) &trainer->cartParams );
2420
2421     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2422     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2423     sample_data = sample.data.ptr;
2424     for( i = 0; i < trainer->numsamples; i++ )
2425     {
2426         index = icvGetIdxAt( trainer->sampleIdx, i );
2427         sample.data.ptr = sample_data + index * sample_step;
2428         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2429     }
2430     for( j = 0; j <= ptr->count; j++ )
2431     {
2432         respnum = 0;
2433
2434         for( i = 0; i < trainer->numsamples; i++ )
2435         {
2436             index = icvGetIdxAt( trainer->sampleIdx, i );
2437             if( idx[index] == j )
2438             {
2439                 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2440                                   - trainer->f[index];
2441             }
2442         }
2443         if( respnum > 0 )
2444         {
2445             /* rhat = median(y_i - F_(m-1)(x_i)) */
2446             icvSort_32f( resp, respnum, 0 );
2447             rhat = resp[respnum / 2];
2448             
2449             /* val = sum{sign(r_i - rhat_i) * min(delta, abs(r_i - rhat_i)}
2450              * r_i = y_i - F_(m-1)(x_i)
2451              */
2452             val = 0.0F;
2453             for( i = 0; i < respnum; i++ )
2454             {
2455                 val += CV_SIGN( resp[i] - rhat )
2456                        * MIN( delta, (float) fabs( resp[i] - rhat ) );
2457             }
2458
2459             val = rhat + val / (float) respnum;
2460         }
2461         else
2462         {
2463             val = 0.0F;
2464         }
2465
2466         ptr->val[j] = val;
2467
2468     }
2469
2470     cvFree( &resid );
2471     cvFree( &resp );
2472     cvFree( &idx );
2473     
2474     trees[0] = ptr;
2475 }
2476
/*
 * Overflow guards for the logistic two-class (CV_L2CLASS) step.
 * icvBtNext_L2CLASS evaluates exp( val_f ) only when val_f < CV_LOG_VAL_MAX;
 * exp( 18 ) is about 6.6e7, the same order of magnitude as CV_VAL_MAX.
 * NOTE(review): when val_f >= CV_LOG_VAL_MAX that code substitutes
 * CV_LOG_VAL_MAX itself instead of a large value such as CV_VAL_MAX —
 * confirm which was intended.
 * The commented-out pair corresponds to limits near the range of double
 * (exp( 700 ) is about 1e304).
 */
//#define CV_VAL_MAX 1e304

//#define CV_LOG_VAL_MAX 700.0

#define CV_VAL_MAX 1e+8

#define CV_LOG_VAL_MAX 18.0
2484
2485 void icvBtNext_L2CLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2486 {
2487     CvCARTClassifier* ptr;
2488     int i, j;
2489     CvMat sample;
2490     int sample_step;
2491     uchar* sample_data;
2492     
2493     int data_size;
2494     int* idx;
2495     int respnum;
2496     float val;
2497     double val_f;
2498
2499     float sum_weights;
2500     float* weights;
2501     float* sorted_weights;
2502     CvMat* trimmed_idx;
2503     CvMat* sample_idx;
2504     int index;
2505     int trimmed_num;
2506
2507     data_size = trainer->m * sizeof( *idx );
2508     idx = (int*) cvAlloc( data_size );
2509
2510     data_size = trainer->m * sizeof( *weights );
2511     weights = (float*) cvAlloc( data_size );
2512     data_size = trainer->m * sizeof( *sorted_weights );
2513     sorted_weights = (float*) cvAlloc( data_size );
2514     
2515     /* yhat_i = (4 * y_i - 2) / ( 1 + exp( (4 * y_i - 2) * F_(m-1)(x_i) ) ).
2516      *   y_i in {0, 1}
2517      */
2518     sum_weights = 0.0F;
2519     for( i = 0; i < trainer->numsamples; i++ )
2520     {
2521         index = icvGetIdxAt( trainer->sampleIdx, i );
2522         val = 4.0F * (*((float*) (trainer->ydata + index * trainer->ystep))) - 2.0F;
2523         val_f = val * trainer->f[index];
2524         val_f = ( val_f < CV_LOG_VAL_MAX ) ? exp( val_f ) : CV_LOG_VAL_MAX;
2525         val = (float) ( (double) val / ( 1.0 + val_f ) );
2526         trainer->y->data.fl[index] = val;
2527         val = (float) fabs( val );
2528         weights[index] = val * (2.0F - val);
2529         sorted_weights[i] = weights[index];
2530         sum_weights += sorted_weights[i];
2531     }
2532     
2533     trimmed_idx = NULL;
2534     sample_idx = trainer->sampleIdx;
2535     trimmed_num = trainer->numsamples;
2536     if( trainer->param[1] < 1.0F )
2537     {
2538         /* perform weight trimming */
2539         
2540         float threshold;
2541         int count;
2542         
2543         icvSort_32f( sorted_weights, trainer->numsamples, 0 );
2544
2545         sum_weights *= (1.0F - trainer->param[1]);
2546         
2547         i = -1;
2548         do { sum_weights -= sorted_weights[++i]; }
2549         while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2550         
2551         threshold = sorted_weights[i];
2552
2553         while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2554
2555         if( i > 0 )
2556         {
2557             trimmed_num = trainer->numsamples - i;            
2558             trimmed_idx = cvCreateMat( 1, trimmed_num, CV_32FC1 );
2559             count = 0;
2560             for( i = 0; i < trainer->numsamples; i++ )
2561             {
2562                 index = icvGetIdxAt( trainer->sampleIdx, i );
2563                 if( weights[index] >= threshold )
2564                 {
2565                     CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2566                     count++;
2567                 }
2568             }
2569             
2570             assert( count == trimmed_num );
2571
2572             sample_idx = trimmed_idx;
2573
2574             printf( "Used samples %%: %g\n", 
2575                 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2576         }
2577     }
2578
2579     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2580         trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2581         (CvClassifierTrainParams*) &trainer->cartParams );
2582
2583     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2584     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2585     sample_data = sample.data.ptr;
2586     for( i = 0; i < trimmed_num; i++ )
2587     {
2588         index = icvGetIdxAt( sample_idx, i );
2589         sample.data.ptr = sample_data + index * sample_step;
2590         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2591     }
2592     for( j = 0; j <= ptr->count; j++ )
2593     {
2594         respnum = 0;
2595         val = 0.0F;
2596         sum_weights = 0.0F;
2597         for( i = 0; i < trimmed_num; i++ )
2598         {
2599             index = icvGetIdxAt( sample_idx, i );
2600             if( idx[index] == j )
2601             {
2602                 val += trainer->y->data.fl[index];
2603                 sum_weights += weights[index];
2604                 respnum++;
2605             }
2606         }
2607         if( sum_weights > 0.0F )
2608         {
2609             val /= sum_weights;
2610         }
2611         else
2612         {
2613             val = 0.0F;
2614         }
2615         ptr->val[j] = val;
2616     }
2617     
2618     if( trimmed_idx != NULL ) cvReleaseMat( &trimmed_idx );
2619     cvFree( &sorted_weights );
2620     cvFree( &weights );
2621     cvFree( &idx );
2622     
2623     trees[0] = ptr;
2624 }
2625
2626 void icvBtNext_LKCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2627 {
2628     int i, j, k, kk, num;
2629     CvMat sample;
2630     int sample_step;
2631     uchar* sample_data;
2632     
2633     int data_size;
2634     int* idx;
2635     int respnum;
2636     float val;
2637
2638     float sum_weights;
2639     float* weights;
2640     float* sorted_weights;
2641     CvMat* trimmed_idx;
2642     CvMat* sample_idx;
2643     int index;
2644     int trimmed_num;
2645     double sum_exp_f;
2646     double exp_f;
2647     double f_k;
2648
2649     data_size = trainer->m * sizeof( *idx );
2650     idx = (int*) cvAlloc( data_size );
2651     data_size = trainer->m * sizeof( *weights );
2652     weights = (float*) cvAlloc( data_size );
2653     data_size = trainer->m * sizeof( *sorted_weights );
2654     sorted_weights = (float*) cvAlloc( data_size );
2655     trimmed_idx = cvCreateMat( 1, trainer->numsamples, CV_32FC1 );
2656
2657     for( k = 0; k < trainer->numclasses; k++ )
2658     {
2659         /* yhat_i = y_i - p_k(x_i), y_i in {0, 1}      */
2660         /* p_k(x_i) = exp(f_k(x_i)) / (sum_exp_f(x_i)) */
2661         sum_weights = 0.0F;
2662         for( i = 0; i < trainer->numsamples; i++ )
2663         {
2664             index = icvGetIdxAt( trainer->sampleIdx, i );
2665             /* p_k(x_i) = 1 / (1 + sum(exp(f_kk(x_i) - f_k(x_i)))), kk != k */
2666             num = index * trainer->numclasses;
2667             f_k = (double) trainer->f[num + k];
2668             sum_exp_f = 1.0;
2669             for( kk = 0; kk < trainer->numclasses; kk++ )
2670             {
2671                 if( kk == k ) continue;
2672                 exp_f = (double) trainer->f[num + kk] - f_k;
2673                 exp_f = (exp_f < CV_LOG_VAL_MAX) ? exp( exp_f ) : CV_VAL_MAX;
2674                 if( exp_f == CV_VAL_MAX || exp_f >= (CV_VAL_MAX - sum_exp_f) )
2675                 {
2676                     sum_exp_f = CV_VAL_MAX;
2677                     break;
2678                 }
2679                 sum_exp_f += exp_f;
2680             }
2681
2682             val = (float) ( (*((float*) (trainer->ydata + index * trainer->ystep))) 
2683                             == (float) k );
2684             val -= (float) ( (sum_exp_f == CV_VAL_MAX) ? 0.0 : ( 1.0 / sum_exp_f ) );
2685
2686             assert( val >= -1.0F );
2687             assert( val <= 1.0F );
2688
2689             trainer->y->data.fl[index] = val;
2690             val = (float) fabs( val );
2691             weights[index] = val * (1.0F - val);
2692             sorted_weights[i] = weights[index];
2693             sum_weights += sorted_weights[i];
2694         }
2695
2696         sample_idx = trainer->sampleIdx;
2697         trimmed_num = trainer->numsamples;
2698         if( trainer->param[1] < 1.0F )
2699         {
2700             /* perform weight trimming */
2701         
2702             float threshold;
2703             int count;
2704         
2705             icvSort_32f( sorted_weights, trainer->numsamples, 0 );
2706
2707             sum_weights *= (1.0F - trainer->param[1]);
2708         
2709             i = -1;
2710             do { sum_weights -= sorted_weights[++i]; }
2711             while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2712         
2713             threshold = sorted_weights[i];
2714
2715             while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2716
2717             if( i > 0 )
2718             {
2719                 trimmed_num = trainer->numsamples - i;            
2720                 trimmed_idx->cols = trimmed_num;
2721                 count = 0;
2722                 for( i = 0; i < trainer->numsamples; i++ )
2723                 {
2724                     index = icvGetIdxAt( trainer->sampleIdx, i );
2725                     if( weights[index] >= threshold )
2726                     {
2727                         CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2728                         count++;
2729                     }
2730                 }
2731             
2732                 assert( count == trimmed_num );
2733
2734                 sample_idx = trimmed_idx;
2735
2736                 printf( "k: %d Used samples %%: %g\n", k, 
2737                     (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2738             }
2739         } /* weight trimming */
2740
2741         trees[k] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2742             trainer->flags, trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2743             (CvClassifierTrainParams*) &trainer->cartParams );
2744
2745         CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2746         CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2747         sample_data = sample.data.ptr;
2748         for( i = 0; i < trimmed_num; i++ )
2749         {
2750             index = icvGetIdxAt( sample_idx, i );
2751             sample.data.ptr = sample_data + index * sample_step;
2752             idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) trees[k],
2753                                                         &sample );
2754         }
2755         for( j = 0; j <= trees[k]->count; j++ )
2756         {
2757             respnum = 0;
2758             val = 0.0F;
2759             sum_weights = 0.0F;
2760             for( i = 0; i < trimmed_num; i++ )
2761             {
2762                 index = icvGetIdxAt( sample_idx, i );
2763                 if( idx[index] == j )
2764                 {
2765                     val += trainer->y->data.fl[index];
2766                     sum_weights += weights[index];
2767                     respnum++;
2768                 }
2769             }
2770             if( sum_weights > 0.0F )
2771             {
2772                 val = ((float) (trainer->numclasses - 1)) * val /
2773                       ((float) (trainer->numclasses)) / sum_weights;
2774             }
2775             else
2776             {
2777                 val = 0.0F;
2778             }
2779             trees[k]->val[j] = val;
2780         }
2781     } /* for each class */
2782     
2783     cvReleaseMat( &trimmed_idx );
2784     cvFree( &sorted_weights );
2785     cvFree( &weights );
2786     cvFree( &idx );
2787 }
2788
2789
2790 void icvBtNext_XXBCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2791 {
2792     float alpha;
2793     int i;
2794     CvMat* weak_eval_vals;
2795     CvMat* sample_idx;
2796     int num_samples;
2797     CvMat sample;
2798     uchar* sample_data;
2799     int sample_step;
2800
2801     weak_eval_vals = cvCreateMat( 1, trainer->m, CV_32FC1 );
2802
2803     sample_idx = cvTrimWeights( trainer->weights, trainer->sampleIdx,
2804                                 trainer->param[1] );
2805     num_samples = ( sample_idx == NULL )
2806         ? trainer->m : MAX( sample_idx->rows, sample_idx->cols );
2807
2808     printf( "Used samples %%: %g\n", 
2809         (float) num_samples / (float) trainer->numsamples * 100.0F );
2810
2811     trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2812         trainer->flags, trainer->y, NULL, NULL, NULL,
2813         sample_idx, trainer->weights,
2814         (CvClassifierTrainParams*) &trainer->cartParams );
2815     
2816     /* evaluate samples */
2817     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2818     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2819     sample_data = sample.data.ptr;
2820     
2821     for( i = 0; i < trainer->m; i++ )
2822     {
2823         sample.data.ptr = sample_data + i * sample_step;
2824         weak_eval_vals->data.fl[i] = trees[0]->eval( (CvClassifier*) trees[0], &sample );
2825     }
2826
2827     alpha = cvBoostNextWeakClassifier( weak_eval_vals, trainer->trainClasses,
2828         trainer->y, trainer->weights, trainer->boosttrainer );
2829     
2830     /* multiply tree by alpha */
2831     for( i = 0; i <= trees[0]->count; i++ )
2832     {
2833         trees[0]->val[i] *= alpha;
2834     }
2835     if( trainer->type == CV_RABCLASS )
2836     {
2837         for( i = 0; i <= trees[0]->count; i++ )
2838         {
2839             trees[0]->val[i] = cvLogRatio( trees[0]->val[i] );
2840         }
2841     }
2842     
2843     if( sample_idx != NULL && sample_idx != trainer->sampleIdx )
2844     {
2845         cvReleaseMat( &sample_idx );
2846     }
2847     cvReleaseMat( &weak_eval_vals );
2848 }
2849
2850 typedef void (*CvBtNextFunc)( CvCARTClassifier** trees, CvBtTrainer* trainer );
2851
2852 static CvBtNextFunc icvBtNextFunc[] =
2853 {
2854     icvBtNext_XXBCLASS,
2855     icvBtNext_XXBCLASS,
2856     icvBtNext_XXBCLASS,
2857     icvBtNext_XXBCLASS,
2858     icvBtNext_L2CLASS,
2859     icvBtNext_LKCLASS,
2860     icvBtNext_LSREG,
2861     icvBtNext_LADREG,
2862     icvBtNext_MREG
2863 };
2864
2865 CV_BOOST_IMPL
2866 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer )
2867 {
2868
2869     CV_FUNCNAME( "cvBtNext" );
2870
2871     __BEGIN__;
2872
2873     int i, j;
2874     int index;
2875     CvMat sample;
2876     int sample_step;
2877     uchar* sample_data;
2878
2879     icvBtNextFunc[trainer->type]( trees, trainer );        
2880
2881     /* shrinkage */
2882     if( trainer->param[0] != 1.0F )
2883     {
2884         for( j = 0; j < trainer->numclasses; j++ )
2885         {
2886             for( i = 0; i <= trees[j]->count; i++ )
2887             {
2888                 trees[j]->val[i] *= trainer->param[0];
2889             }
2890         }
2891     }
2892
2893     if( trainer->type > CV_GABCLASS )
2894     {
2895         /* update F_(m-1) */
2896         CV_GET_SAMPLE( *(trainer->trainData), trainer->flags, 0, sample );
2897         CV_GET_SAMPLE_STEP( *(trainer->trainData), trainer->flags, sample_step );
2898         sample_data = sample.data.ptr;
2899         for( i = 0; i < trainer->numsamples; i++ )
2900         {
2901             index = icvGetIdxAt( trainer->sampleIdx, i );
2902             sample.data.ptr = sample_data + index * sample_step;
2903             for( j = 0; j < trainer->numclasses; j++ )
2904             {            
2905                 trainer->f[index * trainer->numclasses + j] += 
2906                     trees[j]->eval( (CvClassifier*) (trees[j]), &sample );
2907             }
2908         }
2909     }
2910     
2911     __END__;
2912 }
2913
2914 CV_BOOST_IMPL
2915 void cvBtEnd( CvBtTrainer** trainer )
2916 {
2917     CV_FUNCNAME( "cvBtEnd" );
2918     
2919     __BEGIN__;
2920     
2921     if( trainer == NULL || (*trainer) == NULL )
2922     {
2923         CV_ERROR( CV_StsNullPtr, "Invalid trainer parameter" );
2924     }
2925     
2926     if( (*trainer)->y != NULL )
2927     {
2928         CV_CALL( cvReleaseMat( &((*trainer)->y) ) );
2929     }
2930     if( (*trainer)->weights != NULL )
2931     {
2932         CV_CALL( cvReleaseMat( &((*trainer)->weights) ) );
2933     }
2934     if( (*trainer)->boosttrainer != NULL )
2935     {
2936         CV_CALL( cvBoostEndTraining( &((*trainer)->boosttrainer) ) );
2937     }
2938     CV_CALL( cvFree( trainer ) );
2939
2940     __END__;
2941 }
2942
2943 /****************************************************************************************\
2944 *                         Boosted tree model as a classifier                             *
2945 \****************************************************************************************/
2946
2947 CV_BOOST_IMPL
2948 float cvEvalBtClassifier( CvClassifier* classifier, CvMat* sample )
2949 {
2950     float val;
2951
2952     CV_FUNCNAME( "cvEvalBtClassifier" );
2953
2954     __BEGIN__;
2955     
2956     int i;
2957
2958     val = 0.0F;
2959     if( CV_IS_TUNABLE( classifier->flags ) )
2960     {
2961         CvSeqReader reader;
2962         CvCARTClassifier* tree;
2963
2964         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
2965         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2966         {
2967             CV_READ_SEQ_ELEM( tree, reader );
2968             val += tree->eval( (CvClassifier*) tree, sample );
2969         }
2970     }
2971     else
2972     {
2973         CvCARTClassifier** ptree;
2974
2975         ptree = ((CvBtClassifier*) classifier)->trees;
2976         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2977         {
2978             val += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
2979             ptree++;
2980         }
2981     }
2982
2983     __END__;
2984
2985     return val;
2986 }
2987
2988 CV_BOOST_IMPL
2989 float cvEvalBtClassifier2( CvClassifier* classifier, CvMat* sample )
2990 {
2991     float val;
2992
2993     CV_FUNCNAME( "cvEvalBtClassifier2" );
2994
2995     __BEGIN__;
2996     
2997     CV_CALL( val = cvEvalBtClassifier( classifier, sample ) );
2998
2999     __END__;
3000
3001     return (float) (val >= 0.0F);
3002 }
3003
3004 CV_BOOST_IMPL
3005 float cvEvalBtClassifierK( CvClassifier* classifier, CvMat* sample )
3006 {
3007     int cls;
3008
3009     CV_FUNCNAME( "cvEvalBtClassifierK" );
3010
3011     __BEGIN__;
3012     
3013     int i, k;
3014     float max_val;
3015     int numclasses;
3016
3017     float* vals;
3018     size_t data_size;
3019
3020     numclasses = ((CvBtClassifier*) classifier)->numclasses;
3021     data_size = sizeof( *vals ) * numclasses;
3022     CV_CALL( vals = (float*) cvAlloc( data_size ) );
3023     memset( vals, 0, data_size );
3024
3025     if( CV_IS_TUNABLE( classifier->flags ) )
3026     {
3027         CvSeqReader reader;
3028         CvCARTClassifier* tree;
3029
3030         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3031         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3032         {
3033             for( k = 0; k < numclasses; k++ )
3034             {
3035                 CV_READ_SEQ_ELEM( tree, reader );
3036                 vals[k] += tree->eval( (CvClassifier*) tree, sample );
3037             }
3038         }
3039
3040     }
3041     else
3042     {
3043         CvCARTClassifier** ptree;
3044
3045         ptree = ((CvBtClassifier*) classifier)->trees;
3046         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3047         {
3048             for( k = 0; k < numclasses; k++ )
3049             {
3050                 vals[k] += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
3051                 ptree++;
3052             }
3053         }
3054     }
3055
3056     cls = 0;
3057     max_val = vals[cls];
3058     for( k = 1; k < numclasses; k++ )
3059     {
3060         if( vals[k] > max_val )
3061         {
3062             max_val = vals[k];
3063             cls = k;
3064         }
3065     }
3066
3067     CV_CALL( cvFree( &vals ) );
3068
3069     __END__;
3070
3071     return (float) cls;
3072 }
3073
3074 typedef float (*CvEvalBtClassifier)( CvClassifier* classifier, CvMat* sample );
3075
3076 static CvEvalBtClassifier icvEvalBtClassifier[] =
3077 {
3078     cvEvalBtClassifier2,
3079     cvEvalBtClassifier2,
3080     cvEvalBtClassifier2,
3081     cvEvalBtClassifier2,
3082     cvEvalBtClassifier2,
3083     cvEvalBtClassifierK,
3084     cvEvalBtClassifier,
3085     cvEvalBtClassifier,
3086     cvEvalBtClassifier
3087 };
3088
3089 CV_BOOST_IMPL
3090 int cvSaveBtClassifier( CvClassifier* classifier, const char* filename )
3091 {
3092     CV_FUNCNAME( "cvSaveBtClassifier" );
3093
3094     __BEGIN__;
3095
3096     FILE* file;
3097     int i, j;
3098     CvSeqReader reader;
3099     CvCARTClassifier* tree;
3100
3101     CV_ASSERT( classifier );
3102     CV_ASSERT( filename );
3103     
3104     if( !icvMkDir( filename ) || !(file = fopen( filename, "w" )) )
3105     {
3106         CV_ERROR( CV_StsError, "Unable to create file" );
3107     }
3108
3109     if( CV_IS_TUNABLE( classifier->flags ) )
3110     {
3111         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3112     }
3113     fprintf( file, "%d %d\n%d\n%d\n", (int) ((CvBtClassifier*) classifier)->type,
3114                                       ((CvBtClassifier*) classifier)->numclasses,
3115                                       ((CvBtClassifier*) classifier)->numfeatures,
3116                                       ((CvBtClassifier*) classifier)->numiter );
3117     
3118     for( i = 0; i < ((CvBtClassifier*) classifier)->numclasses *
3119                     ((CvBtClassifier*) classifier)->numiter; i++ )
3120     {
3121         if( CV_IS_TUNABLE( classifier->flags ) )
3122         {
3123             CV_READ_SEQ_ELEM( tree, reader );
3124         }
3125         else
3126         {
3127             tree = ((CvBtClassifier*) classifier)->trees[i];
3128         }
3129
3130         fprintf( file, "%d\n", tree->count );
3131         for( j = 0; j < tree->count; j++ )
3132         {
3133             fprintf( file, "%d %g %d %d\n", tree->compidx[j],
3134                                             tree->threshold[j],
3135                                             tree->left[j],
3136                                             tree->right[j] );
3137         }
3138         for( j = 0; j <= tree->count; j++ )
3139         {
3140             fprintf( file, "%g ", tree->val[j] );
3141         }
3142         fprintf( file, "\n" );
3143     }
3144
3145     fclose( file );
3146
3147     __END__;
3148
3149     return 1;
3150 }
3151
3152
3153 CV_BOOST_IMPL
3154 void cvReleaseBtClassifier( CvClassifier** ptr )
3155 {
3156     CV_FUNCNAME( "cvReleaseBtClassifier" );
3157
3158     __BEGIN__;
3159
3160     int i;
3161
3162     if( ptr == NULL || *ptr == NULL )
3163     {
3164         CV_ERROR( CV_StsNullPtr, "" );
3165     }
3166     if( CV_IS_TUNABLE( (*ptr)->flags ) )
3167     {
3168         CvSeqReader reader;
3169         CvCARTClassifier* tree;
3170
3171         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) *ptr)->seq, &reader ) );
3172         for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3173                         ((CvBtClassifier*) *ptr)->numiter; i++ )
3174         {
3175             CV_READ_SEQ_ELEM( tree, reader );
3176             tree->release( (CvClassifier**) (&tree) );
3177         }
3178         CV_CALL( cvReleaseMemStorage( &(((CvBtClassifier*) *ptr)->seq->storage) ) );
3179     }
3180     else
3181     {
3182         CvCARTClassifier** ptree;
3183
3184         ptree = ((CvBtClassifier*) *ptr)->trees;
3185         for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3186                         ((CvBtClassifier*) *ptr)->numiter; i++ )
3187         {
3188             (*ptree)->release( (CvClassifier**) ptree );
3189             ptree++;
3190         }
3191     }
3192
3193     CV_CALL( cvFree( ptr ) );
3194     *ptr = NULL;
3195
3196     __END__;
3197 }
3198
3199 void cvTuneBtClassifier( CvClassifier* classifier, CvMat*, int flags,
3200                          CvMat*, CvMat* , CvMat*, CvMat*, CvMat* )
3201 {
3202     CV_FUNCNAME( "cvTuneBtClassifier" );
3203
3204     __BEGIN__;
3205
3206     size_t data_size;
3207
3208     if( CV_IS_TUNABLE( flags ) )
3209     {
3210         if( !CV_IS_TUNABLE( classifier->flags ) )
3211         {
3212             CV_ERROR( CV_StsUnsupportedFormat,
3213                       "Classifier does not support tune function" );
3214         }
3215         else
3216         {
3217             /* tune classifier */
3218             CvCARTClassifier** trees;
3219
3220             printf( "Iteration %d\n", ((CvBtClassifier*) classifier)->numiter + 1 );
3221
3222             data_size = sizeof( *trees ) * ((CvBtClassifier*) classifier)->numclasses;
3223             CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3224             CV_CALL( cvBtNext( trees,
3225                 (CvBtTrainer*) ((CvBtClassifier*) classifier)->trainer ) );
3226             CV_CALL( cvSeqPushMulti( ((CvBtClassifier*) classifier)->seq,
3227                 trees, ((CvBtClassifier*) classifier)->numclasses ) );
3228             CV_CALL( cvFree( &trees ) );
3229             ((CvBtClassifier*) classifier)->numiter++;
3230         }
3231     }
3232     else
3233     {
3234         if( CV_IS_TUNABLE( classifier->flags ) )
3235         {
3236             /* convert */
3237             void* ptr;
3238
3239             assert( ((CvBtClassifier*) classifier)->seq->total ==
3240                         ((CvBtClassifier*) classifier)->numiter *
3241                         ((CvBtClassifier*) classifier)->numclasses );
3242
3243             data_size = sizeof( ((CvBtClassifier*) classifier)->trees[0] ) *
3244                 ((CvBtClassifier*) classifier)->seq->total;
3245             CV_CALL( ptr = cvAlloc( data_size ) );
3246             CV_CALL( cvCvtSeqToArray( ((CvBtClassifier*) classifier)->seq, ptr ) );
3247             CV_CALL( cvReleaseMemStorage( 
3248                     &(((CvBtClassifier*) classifier)->seq->storage) ) );
3249             ((CvBtClassifier*) classifier)->trees = (CvCARTClassifier**) ptr;
3250             classifier->flags &= ~CV_TUNABLE;
3251             CV_CALL( cvBtEnd( (CvBtTrainer**)
3252                 &(((CvBtClassifier*) classifier)->trainer )) );
3253             ((CvBtClassifier*) classifier)->trainer = NULL;
3254         }
3255     }
3256
3257     __END__;
3258 }
3259
3260 CvBtClassifier* icvAllocBtClassifier( CvBoostType type, int flags, int numclasses,
3261                                       int numiter )
3262 {
3263     CvBtClassifier* ptr;
3264     size_t data_size;
3265
3266     assert( numclasses >= 1 );
3267     assert( numiter >= 0 );
3268     assert( ( numclasses == 1 ) || (type == CV_LKCLASS) );
3269
3270     data_size = sizeof( *ptr );
3271     ptr = (CvBtClassifier*) cvAlloc( data_size );
3272     memset( ptr, 0, data_size );
3273
3274     if( CV_IS_TUNABLE( flags ) )
3275     {
3276         ptr->seq = cvCreateSeq( 0, sizeof( *(ptr->seq) ), sizeof( *(ptr->trees) ),
3277                                 cvCreateMemStorage() );
3278         ptr->numiter = 0;
3279     }
3280     else
3281     {
3282         data_size = numclasses * numiter * sizeof( *(ptr->trees) );
3283         ptr->trees = (CvCARTClassifier**) cvAlloc( data_size );
3284         memset( ptr->trees, 0, data_size );
3285
3286         ptr->numiter = numiter;
3287     }
3288
3289     ptr->flags = flags;
3290     ptr->numclasses = numclasses;
3291     ptr->type = type;
3292
3293     ptr->eval = icvEvalBtClassifier[(int) type];
3294     ptr->tune = cvTuneBtClassifier;
3295     ptr->save = cvSaveBtClassifier;
3296     ptr->release = cvReleaseBtClassifier;
3297
3298     return ptr;
3299 }
3300
3301 CV_BOOST_IMPL
3302 CvClassifier* cvCreateBtClassifier( CvMat* trainData,
3303                                     int flags,
3304                                     CvMat* trainClasses,
3305                                     CvMat* typeMask,
3306                                     CvMat* missedMeasurementsMask,
3307                                     CvMat* compIdx,
3308                                     CvMat* sampleIdx,
3309                                     CvMat* weights,
3310                                     CvClassifierTrainParams* trainParams )
3311 {
3312     CvBtClassifier* ptr;
3313
3314     CV_FUNCNAME( "cvCreateBtClassifier" );
3315
3316     __BEGIN__;
3317     CvBoostType type;
3318     int num_classes;
3319     int num_iter;
3320     int i;
3321     CvCARTClassifier** trees;
3322     size_t data_size;
3323
3324     CV_ASSERT( trainData != NULL );
3325     CV_ASSERT( trainClasses != NULL );
3326     CV_ASSERT( typeMask == NULL );
3327     CV_ASSERT( missedMeasurementsMask == NULL );
3328     CV_ASSERT( compIdx == NULL );
3329     CV_ASSERT( weights == NULL );
3330     CV_ASSERT( trainParams != NULL );
3331
3332     type = ((CvBtClassifierTrainParams*) trainParams)->type;
3333     
3334     if( type >= CV_DABCLASS && type <= CV_GABCLASS && sampleIdx )
3335     {
3336         CV_ERROR( CV_StsBadArg, "Sample indices are not supported for this type" );
3337     }
3338
3339     if( type == CV_LKCLASS )
3340     {
3341         double min_val;
3342         double max_val;
3343
3344         cvMinMaxLoc( trainClasses, &min_val, &max_val );
3345         num_classes = (int) (max_val + 1.0);
3346         
3347         CV_ASSERT( num_classes >= 2 );
3348     }
3349     else
3350     {
3351         num_classes = 1;
3352     }
3353     num_iter = ((CvBtClassifierTrainParams*) trainParams)->numiter;
3354     
3355     CV_ASSERT( num_iter > 0 );
3356
3357     ptr = icvAllocBtClassifier( type, CV_TUNABLE | flags, num_classes, num_iter );
3358     ptr->numfeatures = (CV_IS_ROW_SAMPLE( flags )) ? trainData->cols : trainData->rows;
3359     
3360     i = 0;
3361
3362     printf( "Iteration %d\n", 1 );
3363
3364     data_size = sizeof( *trees ) * ptr->numclasses;
3365     CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3366
3367     CV_CALL( ptr->trainer = cvBtStart( trees, trainData, flags, trainClasses, sampleIdx,
3368         ((CvBtClassifierTrainParams*) trainParams)->numsplits, type, num_classes,
3369         &(((CvBtClassifierTrainParams*) trainParams)->param[0]) ) );
3370
3371     CV_CALL( cvSeqPushMulti( ptr->seq, trees, ptr->numclasses ) );
3372     CV_CALL( cvFree( &trees ) );
3373     ptr->numiter++;
3374     
3375     for( i = 1; i < num_iter; i++ )
3376     {
3377         ptr->tune( (CvClassifier*) ptr, NULL, CV_TUNABLE, NULL, NULL, NULL, NULL, NULL );
3378     }
3379     if( !CV_IS_TUNABLE( flags ) )
3380     {
3381         /* convert */
3382         ptr->tune( (CvClassifier*) ptr, NULL, 0, NULL, NULL, NULL, NULL, NULL );
3383     }
3384
3385     __END__;
3386
3387     return (CvClassifier*) ptr;
3388 }
3389
3390 CV_BOOST_IMPL
3391 CvClassifier* cvCreateBtClassifierFromFile( const char* filename )
3392 {
3393     CvBtClassifier* ptr;
3394
3395     CV_FUNCNAME( "cvCreateBtClassifierFromFile" );
3396     
3397     __BEGIN__;
3398
3399     FILE* file;
3400     int i, j;
3401     int data_size;
3402     int num_classifiers;
3403     int num_features;
3404     int num_classes;
3405     int type;
3406
3407     CV_ASSERT( filename != NULL );
3408
3409     ptr = NULL;
3410     file = fopen( filename, "r" );
3411     if( !file )
3412     {
3413         CV_ERROR( CV_StsError, "Unable to open file" );
3414     }
3415     
3416     fscanf( file, "%d %d %d %d", &type, &num_classes, &num_features, &num_classifiers );
3417
3418     CV_ASSERT( type >= (int) CV_DABCLASS && type <= (int) CV_MREG );
3419     CV_ASSERT( num_features > 0 );
3420     CV_ASSERT( num_classifiers > 0 );
3421
3422     if( (CvBoostType) type != CV_LKCLASS )
3423     {
3424         num_classes = 1;
3425     }
3426     ptr = icvAllocBtClassifier( (CvBoostType) type, 0, num_classes, num_classifiers );
3427     ptr->numfeatures = num_features;
3428     
3429     for( i = 0; i < num_classes * num_classifiers; i++ )
3430     {
3431         int count;
3432         CvCARTClassifier* tree;
3433
3434         fscanf( file, "%d", &count );
3435
3436         data_size = sizeof( *tree )
3437             + count * ( sizeof( *(tree->compidx) ) + sizeof( *(tree->threshold) ) +
3438                         sizeof( *(tree->right) ) + sizeof( *(tree->left) ) )
3439             + (count + 1) * ( sizeof( *(tree->val) ) );
3440         CV_CALL( tree = (CvCARTClassifier*) cvAlloc( data_size ) );
3441         memset( tree, 0, data_size );
3442         tree->eval = cvEvalCARTClassifier;
3443         tree->tune = NULL;
3444         tree->save = NULL;
3445         tree->release = cvReleaseCARTClassifier;
3446         tree->compidx = (int*) ( tree + 1 );
3447         tree->threshold = (float*) ( tree->compidx + count );
3448         tree->left = (int*) ( tree->threshold + count );
3449         tree->right = (int*) ( tree->left + count );
3450         tree->val = (float*) ( tree->right + count );
3451
3452         tree->count = count;
3453         for( j = 0; j < tree->count; j++ )
3454         {
3455             fscanf( file, "%d %g %d %d", &(tree->compidx[j]),
3456                                          &(tree->threshold[j]),
3457                                          &(tree->left[j]),
3458                                          &(tree->right[j]) );
3459         }
3460         for( j = 0; j <= tree->count; j++ )
3461         {
3462             fscanf( file, "%g", &(tree->val[j]) );
3463         }
3464         ptr->trees[i] = tree;
3465     }
3466
3467     fclose( file );
3468
3469     __END__;
3470
3471     return (CvClassifier*) ptr;
3472 }
3473
3474 /****************************************************************************************\
3475 *                                    Utility functions                                   *
3476 \****************************************************************************************/
3477
3478 CV_BOOST_IMPL
3479 CvMat* cvTrimWeights( CvMat* weights, CvMat* idx, float factor )
3480 {
3481     CvMat* ptr;
3482
3483     CV_FUNCNAME( "cvTrimWeights" );
3484     __BEGIN__;
3485     int i, index, num;
3486     float sum_weights;
3487     uchar* wdata;
3488     size_t wstep;
3489     int wnum;
3490     float threshold;
3491     int count;
3492     float* sorted_weights;
3493
3494     CV_ASSERT( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
3495
3496     ptr = idx;
3497     sorted_weights = NULL;
3498
3499     if( factor > 0.0F && factor < 1.0F )
3500     {
3501         size_t data_size;
3502
3503         CV_MAT2VEC( *weights, wdata, wstep, wnum );
3504         num = ( idx == NULL ) ? wnum : MAX( idx->rows, idx->cols );
3505
3506         data_size = num * sizeof( *sorted_weights );
3507         sorted_weights = (float*) cvAlloc( data_size );
3508         memset( sorted_weights, 0, data_size );
3509
3510         sum_weights = 0.0F;
3511         for( i = 0; i < num; i++ )
3512         {
3513             index = icvGetIdxAt( idx, i );
3514             sorted_weights[i] = *((float*) (wdata + index * wstep));
3515             sum_weights += sorted_weights[i];
3516         }
3517
3518         icvSort_32f( sorted_weights, num, 0 );
3519
3520         sum_weights *= (1.0F - factor);
3521
3522         i = -1;
3523         do { sum_weights -= sorted_weights[++i]; }
3524         while( sum_weights > 0.0F && i < (num - 1) );
3525
3526         threshold = sorted_weights[i];
3527
3528         while( i > 0 && sorted_weights[i-1] == threshold ) i--;
3529
3530         if( i > 0 || ( idx != NULL && CV_MAT_TYPE( idx->type ) != CV_32FC1 ) )
3531         {
3532             CV_CALL( ptr = cvCreateMat( 1, num - i, CV_32FC1 ) );
3533             count = 0;
3534             for( i = 0; i < num; i++ )
3535             {
3536                 index = icvGetIdxAt( idx, i );
3537                 if( *((float*) (wdata + index * wstep)) >= threshold )
3538                 {
3539                     CV_MAT_ELEM( *ptr, float, 0, count ) = (float) index;
3540                     count++;
3541                 }
3542             }
3543         
3544             assert( count == ptr->cols );
3545         }
3546         cvFree( &sorted_weights );
3547     }
3548
3549     __END__;
3550
3551     return ptr;
3552 }
3553
3554
3555 CV_BOOST_IMPL
3556 void cvReadTrainData( const char* filename, int flags,
3557                       CvMat** trainData,
3558                       CvMat** trainClasses )
3559 {
3560
3561     CV_FUNCNAME( "cvReadTrainData" );
3562
3563     __BEGIN__;
3564
3565     FILE* file;
3566     int m, n;
3567     int i, j;
3568     float val;
3569
3570     if( filename == NULL )
3571     {
3572         CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3573     }
3574     if( trainData == NULL )
3575     {
3576         CV_ERROR( CV_StsNullPtr, "trainData must be not NULL" );
3577     }
3578     if( trainClasses == NULL )
3579     {
3580         CV_ERROR( CV_StsNullPtr, "trainClasses must be not NULL" );
3581     }
3582     
3583     *trainData = NULL;
3584     *trainClasses = NULL;
3585     file = fopen( filename, "r" );
3586     if( !file )
3587     {
3588         CV_ERROR( CV_StsError, "Unable to open file" );
3589     }
3590
3591     fscanf( file, "%d %d", &m, &n );
3592
3593     if( CV_IS_ROW_SAMPLE( flags ) )
3594     {
3595         CV_CALL( *trainData = cvCreateMat( m, n, CV_32FC1 ) );
3596     }
3597     else
3598     {
3599         CV_CALL( *trainData = cvCreateMat( n, m, CV_32FC1 ) );
3600     }
3601     
3602     CV_CALL( *trainClasses = cvCreateMat( 1, m, CV_32FC1 ) );
3603
3604     for( i = 0; i < m; i++ )
3605     {
3606         for( j = 0; j < n; j++ )
3607         {
3608             fscanf( file, "%f", &val );
3609             if( CV_IS_ROW_SAMPLE( flags ) )
3610             {
3611                 CV_MAT_ELEM( **trainData, float, i, j ) = val;
3612             }
3613             else
3614             {
3615                 CV_MAT_ELEM( **trainData, float, j, i ) = val;
3616             }
3617         }
3618         fscanf( file, "%f", &val );
3619         CV_MAT_ELEM( **trainClasses, float, 0, i ) = val;
3620     }
3621
3622     fclose( file );
3623
3624     __END__;
3625     
3626 }
3627
3628 CV_BOOST_IMPL
3629 void cvWriteTrainData( const char* filename, int flags,
3630                        CvMat* trainData, CvMat* trainClasses, CvMat* sampleIdx )
3631 {
3632     CV_FUNCNAME( "cvWriteTrainData" );
3633
3634     __BEGIN__;
3635
3636     FILE* file;
3637     int m, n;
3638     int i, j;
3639     int clsrow;
3640     int count;
3641     int idx;
3642     CvScalar sc;
3643
3644     if( filename == NULL )
3645     {
3646         CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3647     }
3648     if( trainData == NULL || CV_MAT_TYPE( trainData->type ) != CV_32FC1 )
3649     {
3650         CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainData" );
3651     }
3652     if( CV_IS_ROW_SAMPLE( flags ) )
3653     {
3654         m = trainData->rows;
3655         n = trainData->cols;
3656     }
3657     else
3658     {
3659         n = trainData->rows;
3660         m = trainData->cols;
3661     }
3662     if( trainClasses == NULL || CV_MAT_TYPE( trainClasses->type ) != CV_32FC1 ||
3663         MIN( trainClasses->rows, trainClasses->cols ) != 1 )
3664     {
3665         CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainClasses" );
3666     }
3667     clsrow = (trainClasses->rows == 1);
3668     if( m != ( (clsrow) ? trainClasses->cols : trainClasses->rows ) )
3669     {
3670         CV_ERROR( CV_StsUnmatchedSizes, "Incorrect trainData and trainClasses sizes" );
3671     }
3672     
3673     if( sampleIdx != NULL )
3674     {
3675         count = (sampleIdx->rows == 1) ? sampleIdx->cols : sampleIdx->rows;
3676     }
3677     else
3678     {
3679         count = m;
3680     }
3681     
3682
3683     file = fopen( filename, "w" );
3684     if( !file )
3685     {
3686         CV_ERROR( CV_StsError, "Unable to create file" );
3687     }
3688
3689     fprintf( file, "%d %d\n", count, n );
3690
3691     for( i = 0; i < count; i++ )
3692     {
3693         if( sampleIdx )
3694         {
3695             if( sampleIdx->rows == 1 )
3696             {
3697                 sc = cvGet2D( sampleIdx, 0, i );
3698             }
3699             else
3700             {
3701                 sc = cvGet2D( sampleIdx, i, 0 );
3702             }
3703             idx = (int) sc.val[0];
3704         }
3705         else
3706         {
3707             idx = i;
3708         }
3709         for( j = 0; j < n; j++ )
3710         {
3711             fprintf( file, "%g ", ( (CV_IS_ROW_SAMPLE( flags ))
3712                                     ? CV_MAT_ELEM( *trainData, float, idx, j ) 
3713                                     : CV_MAT_ELEM( *trainData, float, j, idx ) ) );
3714         }
3715         fprintf( file, "%g\n", ( (clsrow)
3716                                 ? CV_MAT_ELEM( *trainClasses, float, 0, idx )
3717                                 : CV_MAT_ELEM( *trainClasses, float, idx, 0 ) ) );
3718     }
3719
3720     fclose( file );
3721     
3722     __END__;
3723 }
3724
3725
3726 #define ICV_RAND_SHUFFLE( suffix, type )                                                 \
3727 void icvRandShuffle_##suffix( uchar* data, size_t step, int num )                        \
3728 {                                                                                        \
3729     CvRandState state;                                                                   \
3730     time_t seed;                                                                         \
3731     type tmp;                                                                            \
3732     int i;                                                                               \
3733     float rn;                                                                            \
3734                                                                                          \
3735     time( &seed );                                                                       \
3736                                                                                          \
3737     cvRandInit( &state, (double) 0, (double) 0, (int)seed );                             \
3738     for( i = 0; i < (num-1); i++ )                                                       \
3739     {                                                                                    \
3740         rn = ((float) cvRandNext( &state )) / (1.0F + UINT_MAX);                         \
3741         CV_SWAP( *((type*)(data + i * step)),                                            \
3742                  *((type*)(data + ( i + (int)( rn * (num - i ) ) )* step)),              \
3743                  tmp );                                                                  \
3744     }                                                                                    \
3745 }
3746
3747 ICV_RAND_SHUFFLE( 8U, uchar )
3748
3749 ICV_RAND_SHUFFLE( 16S, short )
3750
3751 ICV_RAND_SHUFFLE( 32S, int )
3752
3753 ICV_RAND_SHUFFLE( 32F, float )
3754
3755 CV_BOOST_IMPL
3756 void cvRandShuffleVec( CvMat* mat )
3757 {
3758     CV_FUNCNAME( "cvRandShuffle" );
3759
3760     __BEGIN__;
3761
3762     uchar* data;
3763     size_t step;
3764     int num;
3765
3766     if( (mat == NULL) || !CV_IS_MAT( mat ) || MIN( mat->rows, mat->cols ) != 1 )
3767     {
3768         CV_ERROR( CV_StsUnsupportedFormat, "" );
3769     }
3770
3771     CV_MAT2VEC( *mat, data, step, num );
3772     switch( CV_MAT_TYPE( mat->type ) )
3773     {
3774         case CV_8UC1:
3775             icvRandShuffle_8U( data, step, num);
3776             break;
3777         case CV_16SC1:
3778             icvRandShuffle_16S( data, step, num);
3779             break;
3780         case CV_32SC1:
3781             icvRandShuffle_32S( data, step, num);
3782             break;
3783         case CV_32FC1:
3784             icvRandShuffle_32F( data, step, num);
3785             break;
3786         default:
3787             CV_ERROR( CV_StsUnsupportedFormat, "" );
3788     }
3789
3790     __END__;
3791 }
3792
3793 /* End of file. */