1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
58 #include <_cvcommon.h>
59 #include <cvclassifier.h>
/* Strided view over float values, passed as the `aux` context to the
 * icvSortIndexedValArray_* sorts. CMP_VALUES reads its `data` and `step`
 * members; the struct body itself is not visible in this excerpt. */
67 typedef struct CvValArray
/* CMP_VALUES( idx1, idx2 ): "less than" predicate for the indexed sorts.
 * idx1/idx2 are treated as row indices into the CvValArray `aux`, and the
 * floats stored at aux->data + idx * aux->step are compared.
 * CV_IMPLEMENT_QSORT_EX then generates icvSortIndexedValArray_16s/_32s/_32f.
 * Judging by their call sites, they take (index array, count, aux) and
 * reorder short/int/float index arrays by the values they reference. */
73 #define CMP_VALUES( idx1, idx2 ) \
74 ( *( (float*) (aux->data + ((int) (idx1)) * aux->step ) ) < \
75 *( (float*) (aux->data + ((int) (idx2)) * aux->step ) ) )
77 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_16s, short, CMP_VALUES, CvValArray* )
79 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_32s, int, CMP_VALUES, CvValArray* )
81 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_32f, float, CMP_VALUES, CvValArray* )
/* cvGetSortedIndices
 *
 * Computes per-row "argsort" indices. Every row i of idx is filled with
 * 0, 1, ..., idx->cols - 1 and then sorted so that its entries are ordered
 * by the corresponding float values in val.
 *
 * val      - CV_32FC1 matrix of values (asserted).
 * idx      - output index matrix; CV_16SC1, CV_32SC1 or CV_32FC1 (asserted).
 * sortcols - selects the layout:
 *            - In one branch idx has val's transposed shape
 *              (idx->rows == val->cols).
 *            - In the other branch idx and val have the same shape.
 *            The branch condition is not visible here; presumably a non-zero
 *            sortcols selects the transposed case.
 */
84 void cvGetSortedIndices( CvMat* val, CvMat* idx, int sortcols )
96 assert( idx != NULL );
97 assert( val != NULL );
99 idxtype = CV_MAT_TYPE( idx->type );
100 assert( idxtype == CV_16SC1 || idxtype == CV_32SC1 || idxtype == CV_32FC1 );
101 assert( CV_MAT_TYPE( val->type ) == CV_32FC1 );
104 assert( idx->rows == val->cols );
105 assert( idx->cols == val->rows );
106 istep = CV_ELEM_SIZE( val->type );
111 assert( idx->rows == val->rows );
112 assert( idx->cols == val->cols );
114 jstep = CV_ELEM_SIZE( val->type );
117 va.data = val->data.ptr;
/* The three loop nests below handle CV_16SC1, CV_32SC1 and CV_32FC1 index
 * matrices respectively (the dispatch on idxtype is not visible here). Each
 * fills row i with 0..cols-1 and then sorts that row by value. */
122 for( i = 0; i < idx->rows; i++ )
124 for( j = 0; j < idx->cols; j++ )
126 CV_MAT_ELEM( *idx, short, i, j ) = (short) j;
128 icvSortIndexedValArray_16s( (short*) (idx->data.ptr + i * idx->step),
135 for( i = 0; i < idx->rows; i++ )
137 for( j = 0; j < idx->cols; j++ )
139 CV_MAT_ELEM( *idx, int, i, j ) = j;
141 icvSortIndexedValArray_32s( (int*) (idx->data.ptr + i * idx->step),
148 for( i = 0; i < idx->rows; i++ )
150 for( j = 0; j < idx->cols; j++ )
152 CV_MAT_ELEM( *idx, float, i, j ) = (float) j;
154 icvSortIndexedValArray_32f( (float*) (idx->data.ptr + i * idx->step),
/* cvReleaseStumpClassifier
 *
 * Release callback installed in stump->release by the stump constructors.
 * A stump is allocated as a single cvAlloc block, so one cvFree releases it.
 * NOTE(review): whether *classifier is reset to NULL afterwards depends on
 * cvFree and on lines not visible here; confirm.
 */
167 void cvReleaseStumpClassifier( CvClassifier** classifier )
169 cvFree( classifier );
/* cvEvalStumpClassifier
 *
 * Evaluates a decision stump on one sample. Returns stump->left when
 * sample(0, compidx) < threshold, and stump->right otherwise.
 * sample must be CV_32FC1 (asserted).
 * NOTE(review): only row 0 is read, so the sample is assumed to be a 1xN row
 * vector. cvEvalCARTClassifier also accepts Nx1 column vectors; confirm that
 * callers never pass a column vector here.
 */
174 float cvEvalStumpClassifier( CvClassifier* classifier, CvMat* sample )
176 assert( classifier != NULL );
177 assert( sample != NULL );
178 assert( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
180 if( (CV_MAT_ELEM( (*sample), float, 0,
181 ((CvStumpClassifier*) classifier)->compidx )) <
182 ((CvStumpClassifier*) classifier)->threshold )
184 return ((CvStumpClassifier*) classifier)->left;
188 return ((CvStumpClassifier*) classifier)->right;
/* ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )
 *
 * Generates icvFindStumpThreshold_<suffix>, the split search for one feature
 * component.
 *
 * - `type` is the element type of the index array idxdata. That array lists
 *   sample indices in ascending order of the component value; the debug
 *   assert on prevval/curval checks this ordering.
 * - `error` is a code snippet that sets curleft/curright and
 *   curlerror/currerror for the current split position from the running sums.
 *
 * What is visible here:
 *  - On the first call (*sumw == FLT_MAX), the totals *sumw, *sumwy and
 *    *sumwyy over the indexed samples are computed. Later calls reuse them.
 *  - Samples are scanned in sorted order. wl/wyl/wyyl accumulate the
 *    weighted sums of the left part; the right part is derived from the
 *    totals (e.g. wyr = *sumwy - wyl).
 *  - Before `error` runs, curleft/curright are the weighted means of y on
 *    each side (0 when a side has no weight).
 *  - When curlerror + currerror beats *lerror + *rerror, the errors are
 *    stored. The threshold is then moved to the midpoint between the current
 *    value and the previous one. The store of *left/*right and any guard on
 *    the midpoint are in lines not visible here.
 *  - The trailing while loop advances i, presumably over runs of equal
 *    values, so that no threshold is placed between identical values. Its
 *    full condition is not visible.
 *
 * NOTE(review): callers test the int result before recording the component
 * index. Presumably non-zero means "a better split was found"; confirm
 * against the return statement, which is not visible here.
 */
194 #define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error ) \
195 CV_BOOST_IMPL int icvFindStumpThreshold_##suffix( \
196 uchar* data, size_t datastep, \
197 uchar* wdata, size_t wstep, \
198 uchar* ydata, size_t ystep, \
199 uchar* idxdata, size_t idxstep, int num, \
202 float* threshold, float* left, float* right, \
203 float* sumw, float* sumwy, float* sumwyy ) \
212 float curleft = 0.0F; \
213 float curright = 0.0F; \
214 float* prevval = NULL; \
215 float* curval = NULL; \
216 float curlerror = 0.0F; \
217 float currerror = 0.0F; \
224 wposl = wposr = 0.0F; \
225 if( *sumw == FLT_MAX ) \
227 /* calculate sums */ \
235 for( i = 0; i < num; i++ ) \
237 idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
238 w = (float*) (wdata + idx * wstep); \
240 y = (float*) (ydata + idx * ystep); \
243 *sumwyy += wy * (*y); \
247 for( i = 0; i < num; i++ ) \
249 idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
250 curval = (float*) (data + idx * datastep); \
251 /* for debug purpose */ \
252 if( i > 0 ) assert( (*prevval) <= (*curval) ); \
254 wyr = *sumwy - wyl; \
257 if( wl > 0.0 ) curleft = wyl / wl; \
258 else curleft = 0.0F; \
260 if( wr > 0.0 ) curright = wyr / wr; \
261 else curright = 0.0F; \
265 if( curlerror + currerror < (*lerror) + (*rerror) ) \
267 (*lerror) = curlerror; \
268 (*rerror) = currerror; \
269 *threshold = *curval; \
271 *threshold = 0.5F * (*threshold + *prevval); \
280 wl += *((float*) (wdata + idx * wstep)); \
281 wyl += (*((float*) (wdata + idx * wstep))) \
282 * (*((float*) (ydata + idx * ystep))); \
283 wyyl += *((float*) (wdata + idx * wstep)) \
284 * (*((float*) (ydata + idx * ystep))) \
285 * (*((float*) (ydata + idx * ystep))); \
287 while( (++i) < num && \
288 ( *((float*) (data + (idx = \
289 (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep)) \
293 } /* for each value */ \
/* Misclassification-error criterion. Assume labels y are +1/-1, as
 * icvBoostStartTraining produces when it maps 0/1 classes to 2 * y - 1.
 * Under that assumption:
 *  - wposl = 0.5 * (wl + wyl) is the positive-class weight on the left;
 *  - curleft becomes the weighted fraction of positives on that side;
 *  - each side's error is the smaller of its positive and negative weight.
 * The right side is handled the same way. */
298 /* misclassification error
299 * err = MIN( wpos, wneg );
301 #define ICV_DEF_FIND_STUMP_THRESHOLD_MISC( suffix, type ) \
302 ICV_DEF_FIND_STUMP_THRESHOLD( misc_##suffix, type, \
303 wposl = 0.5F * ( wl + wyl ); \
304 wposr = 0.5F * ( wr + wyr ); \
305 curleft = 0.5F * ( 1.0F + curleft ); \
306 curright = 0.5F * ( 1.0F + curright ); \
307 curlerror = MIN( wposl, wl - wposl ); \
308 currerror = MIN( wposr, wr - wposr ); \
/* Gini criterion. Assume labels y are +1/-1 and let p = curleft (or
 * curright) be the weighted fraction of positives on a side. That side's
 * error is then 2 * wpos * (1 - p) = 2 * wpos * wneg / (wpos + wneg). */
312 * err = 2 * wpos * wneg /(wpos + wneg)
314 #define ICV_DEF_FIND_STUMP_THRESHOLD_GINI( suffix, type ) \
315 ICV_DEF_FIND_STUMP_THRESHOLD( gini_##suffix, type, \
316 wposl = 0.5F * ( wl + wyl ); \
317 wposr = 0.5F * ( wr + wyr ); \
318 curleft = 0.5F * ( 1.0F + curleft ); \
319 curright = 0.5F * ( 1.0F + curright ); \
320 curlerror = 2.0F * wposl * ( 1.0F - curleft ); \
321 currerror = 2.0F * wposr * ( 1.0F - curright ); \
/* Entropy criterion. Assume labels y are +1/-1 and let p be the weighted
 * fraction of positives on a side. That side contributes
 * -wpos * log(p) - wneg * log(1 - p).
 * CV_ENTROPY_THRESHOLD keeps logf() away from zero: each log term is skipped
 * when its argument would be (numerically) 0. */
324 #define CV_ENTROPY_THRESHOLD FLT_MIN
327 * err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg))
329 #define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type ) \
330 ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type, \
331 wposl = 0.5F * ( wl + wyl ); \
332 wposr = 0.5F * ( wr + wyr ); \
333 curleft = 0.5F * ( 1.0F + curleft ); \
334 curright = 0.5F * ( 1.0F + curright ); \
335 curlerror = currerror = 0.0F; \
336 if( curleft > CV_ENTROPY_THRESHOLD ) \
337 curlerror -= wposl * logf( curleft ); \
338 if( curleft < 1.0F - CV_ENTROPY_THRESHOLD ) \
339 curlerror -= (wl - wposl) * logf( 1.0F - curleft ); \
341 if( curright > CV_ENTROPY_THRESHOLD ) \
342 currerror -= wposr * logf( curright ); \
343 if( curright < 1.0F - CV_ENTROPY_THRESHOLD ) \
344 currerror -= (wr - wposr) * logf( 1.0F - curright ); \
/* Least-squares criterion. curleft/curright stay the weighted means of y on
 * each side. Each side's error, sum(w * (y - mean)^2), is expanded as
 * sum(w*y^2) + mean^2 * sum(w) - 2 * mean * sum(w*y) so that it can be
 * computed from the running sums. */
347 /* least sum of squares error */
348 #define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type ) \
349 ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type, \
350 /* calculate error (sum of squares) */ \
351 /* err = sum( w * (y - left(right)Val)^2 ) */ \
352 curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl; \
353 currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
/* Instantiate every split criterion for short (16s), int (32s) and
 * float (32f) index arrays. */
356 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )
358 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32s, int )
360 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32f, float )
363 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 16s, short )
365 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32s, int )
367 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32f, float )
370 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 16s, short )
372 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32s, int )
374 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32f, float )
377 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 16s, short )
379 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32s, int )
381 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32f, float )
/* Signature shared by all icvFindStumpThreshold_* functions. The parameter
 * lines between `num` and `threshold` are not visible here; call sites pass
 * the left/right error accumulators in that position. */
383 typedef int (*CvFindThresholdFunc)( uchar* data, size_t datastep,
384 uchar* wdata, size_t wstep,
385 uchar* ydata, size_t ystep,
386 uchar* idxdata, size_t idxstep, int num,
389 float* threshold, float* left, float* right,
390 float* sumw, float* sumwy, float* sumwyy );
/* Dispatch tables, one per index element type (16s, 32s, 32f). Each is
 * indexed by the stump error type in the order misc, gini, entropy, sq. */
392 CvFindThresholdFunc findStumpThreshold_16s[4] = {
393 icvFindStumpThreshold_misc_16s,
394 icvFindStumpThreshold_gini_16s,
395 icvFindStumpThreshold_entropy_16s,
396 icvFindStumpThreshold_sq_16s
399 CvFindThresholdFunc findStumpThreshold_32s[4] = {
400 icvFindStumpThreshold_misc_32s,
401 icvFindStumpThreshold_gini_32s,
402 icvFindStumpThreshold_entropy_32s,
403 icvFindStumpThreshold_sq_32s
406 CvFindThresholdFunc findStumpThreshold_32f[4] = {
407 icvFindStumpThreshold_misc_32f,
408 icvFindStumpThreshold_gini_32f,
409 icvFindStumpThreshold_entropy_32f,
410 icvFindStumpThreshold_sq_32f
/* cvCreateStumpClassifier
 *
 * Trains a decision stump on a single thread. For every feature component,
 * the selected samples are sorted by that component's values and the
 * configured criterion (CvStumpTrainParams::error) picks the best threshold.
 * The best split over all components is kept in the returned stump.
 *
 * trainData     - CV_32FC1 samples. CV_IS_ROW_SAMPLE( flags ) selects the
 *                 row or column layout.
 * trainClasses,
 * weights       - CV_32FC1 vectors with one entry per sample.
 * sampleIdx     - optional CV_32FC1 vector of sample indices. When NULL,
 *                 presumably all samples are used; that branch is not fully
 *                 visible.
 * missedMeasurementsMask,
 * compIdx       - unsupported; must be NULL (asserted).
 * trainParams   - CvStumpTrainParams (error criterion and problem type).
 *
 * Returns a cvAlloc'd stump with its eval/release callbacks set; release it
 * through stump->release. Arguments are checked with assert() only.
 * NOTE(review): the error index used to select from findStumpThreshold_32s
 * is not range-checked against the 4-entry table.
 */
414 CvClassifier* cvCreateStumpClassifier( CvMat* trainData,
418 CvMat* missedMeasurementsMask,
422 CvClassifierTrainParams* trainParams
425 CvStumpClassifier* stump = NULL;
426 int m = 0; /* number of samples */
427 int n = 0; /* number of components */
433 uchar* idxdata = NULL;
435 int l = 0; /* number of indices */
/* FLT_MAX marks the weighted totals as "not yet computed". The threshold
 * search fills them on its first call and reuses them for later components. */
442 float sumw = FLT_MAX;
443 float sumw1 = FLT_MAX;
444 float sumw0 = FLT_MAX;
445 float sumwy = FLT_MAX;
446 float sumwyy = FLT_MAX;
448 assert( trainData != NULL );
449 assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
450 assert( trainClasses != NULL );
451 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
452 assert( missedMeasurementsMask == NULL );
453 assert( compIdx == NULL );
454 assert( weights != NULL );
455 assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
456 assert( trainParams != NULL );
/* cstep is the byte step between components; sstep is the byte step
 * between samples. */
458 data = trainData->data.ptr;
459 if( CV_IS_ROW_SAMPLE( flags ) )
461 cstep = CV_ELEM_SIZE( trainData->type );
462 sstep = trainData->step;
468 sstep = CV_ELEM_SIZE( trainData->type );
469 cstep = trainData->step;
474 ydata = trainClasses->data.ptr;
475 if( trainClasses->rows == 1 )
477 assert( trainClasses->cols == m );
478 ystep = CV_ELEM_SIZE( trainClasses->type );
482 assert( trainClasses->rows == m );
483 ystep = trainClasses->step;
486 wdata = weights->data.ptr;
487 if( weights->rows == 1 )
489 assert( weights->cols == m );
490 wstep = CV_ELEM_SIZE( weights->type );
494 assert( weights->rows == m );
495 wstep = weights->step;
/* Optional subset of samples; its indices are stored as floats. */
499 if( sampleIdx != NULL )
501 assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
503 idxdata = sampleIdx->data.ptr;
504 if( sampleIdx->rows == 1 )
507 idxstep = CV_ELEM_SIZE( sampleIdx->type );
512 idxstep = sampleIdx->step;
/* idx is the working permutation of the l used samples; it is re-sorted for
 * each component. The stump starts with FLT_MAX errors so that the first
 * candidate split is always accepted. */
517 idx = (int*) cvAlloc( l * sizeof( int ) );
518 stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
521 memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
523 stump->eval = cvEvalStumpClassifier;
526 stump->release = cvReleaseStumpClassifier;
528 stump->lerror = FLT_MAX;
529 stump->rerror = FLT_MAX;
/* Fill idx from sampleIdx (or with the identity, presumably). Then, for
 * every component i, sort idx by that component's values and run the
 * threshold search, which updates the stump in place. */
534 if( sampleIdx != NULL )
536 for( i = 0; i < l; i++ )
538 idx[i] = (int) *((float*) (idxdata + i*idxstep));
543 for( i = 0; i < l; i++ )
549 for( i = 0; i < n; i++ )
553 va.data = data + i * ((size_t) cstep);
555 icvSortIndexedValArray_32s( idx, l, &va );
556 if( findStumpThreshold_32s[(int) ((CvStumpTrainParams*) trainParams)->error]
557 ( data + i * ((size_t) cstep), sstep,
558 wdata, wstep, ydata, ystep, (uchar*) idx, sizeof( int ), l,
559 &(stump->lerror), &(stump->rerror),
560 &(stump->threshold), &(stump->left), &(stump->right),
561 &sumw, &sumwy, &sumwyy ) )
565 } /* for each component */
/* For classification, map each leaf value to -1/+1 using a 0.5 cut. For the
 * misc/gini/entropy criteria, these values are positive-class fractions. */
571 if( ((CvStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
573 stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
574 stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
577 return (CvClassifier*) stump;
/* Multithreaded variant of cvCreateStumpClassifier.
 *
 * Work distribution:
 *  - Components are processed in work items of `portion` components. When
 *    built with OpenMP, `portion` is presumably divided by the number of
 *    threads.
 *  - Each thread repeatedly claims the next work item inside the c_compidx
 *    critical section and evaluates it.
 *  - At the end, each thread merges its best split into the shared stump
 *    inside the c_beststump critical section.
 *
 * Data sources:
 *  - Components below datan come from trainData. Others come from the
 *    getTrainData callback. When trainData is NULL, every component comes
 *    from the callback.
 *  - When trainParams->sortedIdx is given, its rows hold pre-sorted sample
 *    indices for the first sortedn components. The remaining components are
 *    sorted on the fly with icvSortIndexedValArray_32s.
 *  - When sampleIdx restricts the samples and sorted indices exist, a
 *    per-sample `filter` mask removes unselected samples from the pre-sorted
 *    lists.
 *
 * NOTE(review): sample indices read from sampleIdx index `filter` (size m)
 * without a bounds check.
 */
581 * cvCreateMTStumpClassifier
583 * Multithreaded stump classifier constructor
584 * Includes huge train data support through callback function
587 CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,
591 CvMat* missedMeasurementsMask,
595 CvClassifierTrainParams* trainParams )
597 CvStumpClassifier* stump = NULL;
598 int m = 0; /* number of samples */
599 int n = 0; /* number of components */
603 int datan = 0; /* num components */
606 uchar* idxdata = NULL;
608 int l = 0; /* number of indices */
612 uchar* sorteddata = NULL;
614 size_t sortedcstep = 0; /* component step */
615 size_t sortedsstep = 0; /* sample step */
616 int sortedn = 0; /* num components */
617 int sortedm = 0; /* num samples */
626 /* private variables */
655 /* end private variables */
657 assert( trainParams != NULL );
658 assert( trainClasses != NULL );
659 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
660 assert( missedMeasurementsMask == NULL );
661 assert( compIdx == NULL );
663 stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
665 ydata = trainClasses->data.ptr;
666 if( trainClasses->rows == 1 )
668 m = trainClasses->cols;
669 ystep = CV_ELEM_SIZE( trainClasses->type );
673 m = trainClasses->rows;
674 ystep = trainClasses->step;
677 wdata = weights->data.ptr;
678 if( weights->rows == 1 )
680 assert( weights->cols == m );
681 wstep = CV_ELEM_SIZE( weights->type );
685 assert( weights->rows == m );
686 wstep = weights->step;
/* Optional pre-sorted indices: one row per component, one column per
 * sorted sample entry. */
689 if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
692 CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
693 assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
694 || sortedtype == CV_32FC1 );
695 sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
696 sortedsstep = CV_ELEM_SIZE( sortedtype );
697 sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
698 sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
699 sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
/* Without trainData, all numcomp components come from getTrainData. */
702 if( trainData == NULL )
704 assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
705 n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
710 assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
711 data = trainData->data.ptr;
712 if( CV_IS_ROW_SAMPLE( flags ) )
714 cstep = CV_ELEM_SIZE( trainData->type );
715 sstep = trainData->step;
716 assert( m == trainData->rows );
717 datan = n = trainData->cols;
721 sstep = CV_ELEM_SIZE( trainData->type );
722 cstep = trainData->step;
723 assert( m == trainData->cols );
724 datan = n = trainData->rows;
/* With a callback as well, n may exceed the datan components held in
 * trainData. */
726 if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
728 n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
731 assert( datan <= n );
/* Sample subset. When pre-sorted indices exist, a 0/1 filter over all m
 * samples is built instead of relying on an index list. */
733 if( sampleIdx != NULL )
735 assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
736 idxdata = sampleIdx->data.ptr;
737 idxstep = ( sampleIdx->rows == 1 )
738 ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
739 l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;
741 if( sorteddata != NULL )
743 filter = (char*) cvAlloc( sizeof( char ) * m );
744 memset( (void*) filter, 0, sizeof( char ) * m );
745 for( i = 0; i < l; i++ )
747 filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;
756 stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
759 memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
/* Number of components per work item. */
761 portion = ((CvMTStumpTrainParams*)trainParams)->portion;
768 portion /= omp_get_max_threads();
772 stump->eval = cvEvalStumpClassifier;
775 stump->release = cvReleaseStumpClassifier;
777 stump->lerror = FLT_MAX;
778 stump->rerror = FLT_MAX;
/* Parallel region. Each thread keeps its own best split (lerror, rerror,
 * threshold, left, right, optcompidx) and its own copy of the weighted
 * totals, and merges the best split into `stump` at the end. */
784 #pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
785 optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
786 ti, tj, tk, t_data, t_cstep, t_sstep, matcstep, \
821 /* prepare matrix for callback */
822 if( CV_IS_ROW_SAMPLE( flags ) )
824 mat = cvMat( m, portion, CV_32FC1, 0 );
825 matcstep = CV_ELEM_SIZE( mat.type );
830 mat = cvMat( portion, m, CV_32FC1, 0 );
832 matsstep = CV_ELEM_SIZE( mat.type );
834 mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
/* Per-thread index buffer, used for filtered pre-sorted components and for
 * components sorted on the fly. */
837 if( filter != NULL || sortedn < n )
839 t_idx = (int*) cvAlloc( sizeof( int ) * m );
840 if( sortedn == 0 || filter == NULL )
842 if( idxdata != NULL )
844 for( ti = 0; ti < l; ti++ )
846 t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
851 for( ti = 0; ti < l; ti++ )
860 #pragma omp critical(c_compidx)
866 while( t_compidx < n )
869 if( t_compidx < datan )
871 t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
878 t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
881 t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );
883 /* calculate components */
884 ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
885 sampleIdx, compIdx, t_compidx, t_n,
886 ((CvMTStumpTrainParams*)trainParams)->userdata );
/* Components with pre-sorted indices. With a filter, the three branches
 * below (16s, 32s and 32f sortedIdx) first compact the selected indices into
 * t_idx. Without a filter, the sorted row is passed to the search directly. */
889 if( sorteddata != NULL )
893 /* have sorted indices and filter */
897 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
900 for( tj = 0; tj < sortedm; tj++ )
902 int curidx = (int) ( *((short*) (sorteddata
903 + ti * sortedcstep + tj * sortedsstep)) );
904 if( filter[curidx] != 0 )
906 t_idx[tk++] = curidx;
909 if( findStumpThreshold_32s[stumperror](
910 t_data + ti * t_cstep, t_sstep,
911 wdata, wstep, ydata, ystep,
912 (uchar*) t_idx, sizeof( int ), tk,
914 &threshold, &left, &right,
915 &sumw, &sumwy, &sumwyy ) )
922 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
925 for( tj = 0; tj < sortedm; tj++ )
927 int curidx = (int) ( *((int*) (sorteddata
928 + ti * sortedcstep + tj * sortedsstep)) );
929 if( filter[curidx] != 0 )
931 t_idx[tk++] = curidx;
934 if( findStumpThreshold_32s[stumperror](
935 t_data + ti * t_cstep, t_sstep,
936 wdata, wstep, ydata, ystep,
937 (uchar*) t_idx, sizeof( int ), tk,
939 &threshold, &left, &right,
940 &sumw, &sumwy, &sumwyy ) )
947 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
950 for( tj = 0; tj < sortedm; tj++ )
952 int curidx = (int) ( *((float*) (sorteddata
953 + ti * sortedcstep + tj * sortedsstep)) );
954 if( filter[curidx] != 0 )
956 t_idx[tk++] = curidx;
959 if( findStumpThreshold_32s[stumperror](
960 t_data + ti * t_cstep, t_sstep,
961 wdata, wstep, ydata, ystep,
962 (uchar*) t_idx, sizeof( int ), tk,
964 &threshold, &left, &right,
965 &sumw, &sumwy, &sumwyy ) )
978 /* have sorted indices */
982 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
984 if( findStumpThreshold_16s[stumperror](
985 t_data + ti * t_cstep, t_sstep,
986 wdata, wstep, ydata, ystep,
987 sorteddata + ti * sortedcstep, sortedsstep, sortedm,
989 &threshold, &left, &right,
990 &sumw, &sumwy, &sumwyy ) )
997 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
999 if( findStumpThreshold_32s[stumperror](
1000 t_data + ti * t_cstep, t_sstep,
1001 wdata, wstep, ydata, ystep,
1002 sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1004 &threshold, &left, &right,
1005 &sumw, &sumwy, &sumwyy ) )
1012 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
1014 if( findStumpThreshold_32f[stumperror](
1015 t_data + ti * t_cstep, t_sstep,
1016 wdata, wstep, ydata, ystep,
1017 sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1019 &threshold, &left, &right,
1020 &sumw, &sumwy, &sumwyy ) )
1033 ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
1034 for( ; ti < t_compidx + t_n; ti++ )
1036 va.data = t_data + ti * t_cstep;
1038 icvSortIndexedValArray_32s( t_idx, l, &va );
1039 if( findStumpThreshold_32s[stumperror](
1040 t_data + ti * t_cstep, t_sstep,
1041 wdata, wstep, ydata, ystep,
1042 (uchar*)t_idx, sizeof( int ), l,
1044 &threshold, &left, &right,
1045 &sumw, &sumwy, &sumwyy ) )
1051 #pragma omp critical(c_compidx)
1052 #endif /* _OPENMP */
1054 t_compidx = compidx;
1057 } /* while have training data */
1059 /* get the best classifier */
1061 #pragma omp critical(c_beststump)
1062 #endif /* _OPENMP */
1064 if( lerror + rerror < stump->lerror + stump->rerror )
1066 stump->lerror = lerror;
1067 stump->rerror = rerror;
1068 stump->compidx = optcompidx;
1069 stump->threshold = threshold;
1071 stump->right = right;
1075 /* free allocated memory */
1076 if( mat.data.ptr != NULL )
1078 cvFree( &(mat.data.ptr) );
1084 } /* end of parallel region */
1088 /* free allocated memory */
1089 if( filter != NULL )
1094 if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
1096 stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
1097 stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
1100 return (CvClassifier*) stump;
/* cvEvalCARTClassifier
 *
 * Evaluates a CART tree on one CV_32FC1 sample, given as a row or column
 * vector (validated with CV_ASSERT). At each internal node idx, the sample
 * goes left when sample[compidx[idx]] < threshold[idx] and right otherwise.
 * Leaves are encoded as non-positive child indices (-j), and the stored leaf
 * value val[-idx] is returned.
 * The start node and loop condition are not visible here; presumably the
 * walk starts at node 0 and continues while idx > 0.
 */
1104 float cvEvalCARTClassifier( CvClassifier* classifier, CvMat* sample )
1106 CV_FUNCNAME( "cvEvalCARTClassifier" );
1113 CV_ASSERT( classifier != NULL );
1114 CV_ASSERT( sample != NULL );
1115 CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1116 CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1119 if( sample->rows == 1 )
1123 if( (CV_MAT_ELEM( (*sample), float, 0,
1124 ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1125 ((CvCARTClassifier*) classifier)->threshold[idx] )
1127 idx = ((CvCARTClassifier*) classifier)->left[idx];
1131 idx = ((CvCARTClassifier*) classifier)->right[idx];
1139 if( (CV_MAT_ELEM( (*sample), float,
1140 ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1141 ((CvCARTClassifier*) classifier)->threshold[idx] )
1143 idx = ((CvCARTClassifier*) classifier)->left[idx];
1147 idx = ((CvCARTClassifier*) classifier)->right[idx];
1154 return ((CvCARTClassifier*) classifier)->val[-idx];
/* cvEvalCARTClassifierIdx
 *
 * Same tree walk as cvEvalCARTClassifier, but returns the reached leaf's
 * index (-idx, as a float) instead of its stored value. The sample must be a
 * CV_32FC1 row or column vector (validated with CV_ASSERT).
 */
1158 float cvEvalCARTClassifierIdx( CvClassifier* classifier, CvMat* sample )
1160 CV_FUNCNAME( "cvEvalCARTClassifierIdx" );
1167 CV_ASSERT( classifier != NULL );
1168 CV_ASSERT( sample != NULL );
1169 CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1170 CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1173 if( sample->rows == 1 )
1177 if( (CV_MAT_ELEM( (*sample), float, 0,
1178 ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1179 ((CvCARTClassifier*) classifier)->threshold[idx] )
1181 idx = ((CvCARTClassifier*) classifier)->left[idx];
1185 idx = ((CvCARTClassifier*) classifier)->right[idx];
1193 if( (CV_MAT_ELEM( (*sample), float,
1194 ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1195 ((CvCARTClassifier*) classifier)->threshold[idx] )
1197 idx = ((CvCARTClassifier*) classifier)->left[idx];
1201 idx = ((CvCARTClassifier*) classifier)->right[idx];
1208 return (float) (-idx);
/* Release callback for CART classifiers. The tree header and all of its
 * arrays live in a single cvAlloc block, so one cvFree releases them. */
1212 void cvReleaseCARTClassifier( CvClassifier** classifier )
1214 cvFree( classifier );
/* icvDefaultSplitIdx_R
 *
 * Default CART split callback for the row-sample layout (one sample per row
 * of trainData, which is passed through userdata). Splits sample indices
 * into *left (component compidx < threshold) and *right. Both are newly
 * created 1xN CV_32FC1 matrices whose cols is set to the number of indices
 * stored.
 *  - The first loop covers every row of trainData, presumably when idx is
 *    NULL.
 *  - The second loop covers only the indices listed in idx, which are read
 *    as floats.
 * The caller is responsible for releasing both matrices.
 */
1218 void CV_CDECL icvDefaultSplitIdx_R( int compidx, float threshold,
1219 CvMat* idx, CvMat** left, CvMat** right,
1222 CvMat* trainData = (CvMat*) userdata;
1225 *left = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1226 *right = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1227 (*left)->cols = (*right)->cols = 0;
1230 for( i = 0; i < trainData->rows; i++ )
1232 if( CV_MAT_ELEM( *trainData, float, i, compidx ) < threshold )
1234 (*left)->data.fl[(*left)->cols++] = (float) i;
1238 (*right)->data.fl[(*right)->cols++] = (float) i;
1249 idxdata = idx->data.ptr;
1250 idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1251 idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1252 for( i = 0; i < idxnum; i++ )
1254 index = (int) *((float*) (idxdata + i * idxstep));
1255 if( CV_MAT_ELEM( *trainData, float, index, compidx ) < threshold )
1257 (*left)->data.fl[(*left)->cols++] = (float) index;
1261 (*right)->data.fl[(*right)->cols++] = (float) index;
/* icvDefaultSplitIdx_C
 *
 * Column-sample counterpart of icvDefaultSplitIdx_R: one sample per column
 * of trainData (passed through userdata), with component compidx taken from
 * row compidx. Produces *left (value < threshold) and *right as newly
 * created 1xN CV_32FC1 index matrices whose cols is set to the number of
 * indices stored.
 *  - The first loop covers every column, presumably when idx is NULL.
 *  - The second loop covers only the float-encoded indices in idx.
 */
1267 void CV_CDECL icvDefaultSplitIdx_C( int compidx, float threshold,
1268 CvMat* idx, CvMat** left, CvMat** right,
1271 CvMat* trainData = (CvMat*) userdata;
1274 *left = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1275 *right = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1276 (*left)->cols = (*right)->cols = 0;
1279 for( i = 0; i < trainData->cols; i++ )
1281 if( CV_MAT_ELEM( *trainData, float, compidx, i ) < threshold )
1283 (*left)->data.fl[(*left)->cols++] = (float) i;
1287 (*right)->data.fl[(*right)->cols++] = (float) i;
1298 idxdata = idx->data.ptr;
1299 idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1300 idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1301 for( i = 0; i < idxnum; i++ )
1303 index = (int) *((float*) (idxdata + i * idxstep));
1304 if( CV_MAT_ELEM( *trainData, float, compidx, index ) < threshold )
1306 (*left)->data.fl[(*left)->cols++] = (float) index;
1310 (*right)->data.fl[(*right)->cols++] = (float) index;
/* Tree node or expansion candidate used while growing a CART. Besides
 * `stump`, the construction code uses the members sampleIdx, errdrop,
 * leftflag and parent; their declarations are not visible here. */
1316 /* internal structure used in CART creation */
1317 typedef struct CvCARTNode
1320 CvStumpClassifier* stump;
/* cvCreateCARTClassifier
 *
 * Grows a CART tree with at most count = CvCARTTrainParams::count internal
 * nodes, using best-first expansion:
 *  - The most recently added node is split with splitIdxCallback.
 *  - For each side whose stump error is non-zero, a child stump is trained
 *    on that side's samples and queued in `list` with
 *    errdrop = parent side error - child stump error.
 *  - The queued candidate with the largest errdrop becomes the next node.
 *  - Growth stops early when no candidates remain.
 *
 * Memory:
 *  - The classifier header and its compidx/threshold/left/right/val arrays
 *    are carved out of one cvAlloc block.
 *  - intnode (tree nodes) and list (candidates) share a second block of
 *    2 * count nodes.
 *
 * If trainParams->splitIdx is NULL, icvDefaultSplitIdx_R or _C (chosen by
 * CV_IS_ROW_SAMPLE( flags )) is used, with trainData as userdata.
 */
1327 CvClassifier* cvCreateCARTClassifier( CvMat* trainData,
1329 CvMat* trainClasses,
1331 CvMat* missedMeasurementsMask,
1335 CvClassifierTrainParams* trainParams )
1337 CvCARTClassifier* cart = NULL;
1338 size_t datasize = 0;
1343 CvCARTNode* intnode = NULL;
1344 CvCARTNode* list = NULL;
1349 float maxerrdrop = 0.0F;
1352 void (*splitIdxCallback)( int compidx, float threshold,
1353 CvMat* idx, CvMat** left, CvMat** right,
1357 count = ((CvCARTTrainParams*) trainParams)->count;
1359 assert( count > 0 );
/* Layout of the single block: header, then compidx[count],
 * threshold[count], left[count], right[count] and val[count + 1] (the leaf
 * values). */
1361 datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count +
1362 sizeof( float ) * (count + 1);
1364 cart = (CvCARTClassifier*) cvAlloc( datasize );
1365 memset( cart, 0, datasize );
1367 cart->count = count;
1369 cart->eval = cvEvalCARTClassifier;
1371 cart->release = cvReleaseCARTClassifier;
1373 cart->compidx = (int*) (cart + 1);
1374 cart->threshold = (float*) (cart->compidx + count);
1375 cart->left = (int*) (cart->threshold + count);
1376 cart->right = (int*) (cart->left + count);
1377 cart->val = (float*) (cart->right + count);
/* intnode[0..count-1] holds accepted nodes; list (the second half) holds
 * the candidates. */
1379 datasize = sizeof( CvCARTNode ) * (count + count);
1380 intnode = (CvCARTNode*) cvAlloc( datasize );
1381 memset( intnode, 0, datasize );
1382 list = (CvCARTNode*) (intnode + count);
1384 splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
1385 userdata = ((CvCARTTrainParams*) trainParams)->userdata;
1386 if( splitIdxCallback == NULL )
1388 splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
1389 ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;
1390 userdata = trainData;
1393 /* create root of the tree */
1394 intnode[0].sampleIdx = sampleIdx;
1395 intnode[0].stump = (CvStumpClassifier*)
1396 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1397 trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
1398 ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1399 cart->left[0] = cart->right[0] = 0;
1403 for( i = 1; i < count; i++ )
1405 /* split last added node */
1406 splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
1407 intnode[i-1].sampleIdx, &lidx, &ridx, userdata );
1409 if( intnode[i-1].stump->lerror != 0.0F )
1411 list[listcount].sampleIdx = lidx;
1412 list[listcount].stump = (CvStumpClassifier*)
1413 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1414 trainClasses, typeMask, missedMeasurementsMask, compIdx,
1415 list[listcount].sampleIdx,
1416 weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1417 list[listcount].errdrop = intnode[i-1].stump->lerror
1418 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1419 list[listcount].leftflag = 1;
1420 list[listcount].parent = i-1;
1425 cvReleaseMat( &lidx );
1427 if( intnode[i-1].stump->rerror != 0.0F )
1429 list[listcount].sampleIdx = ridx;
1430 list[listcount].stump = (CvStumpClassifier*)
1431 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1432 trainClasses, typeMask, missedMeasurementsMask, compIdx,
1433 list[listcount].sampleIdx,
1434 weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1435 list[listcount].errdrop = intnode[i-1].stump->rerror
1436 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1437 list[listcount].leftflag = 0;
1438 list[listcount].parent = i-1;
1443 cvReleaseMat( &ridx );
1446 if( listcount == 0 ) break;
1448 /* find the best node to be added to the tree */
1450 maxerrdrop = list[idx].errdrop;
1451 for( j = 1; j < listcount; j++ )
1453 if( list[j].errdrop > maxerrdrop )
1456 maxerrdrop = list[j].errdrop;
1459 intnode[i] = list[idx];
1460 if( list[idx].leftflag )
1462 cart->left[list[idx].parent] = i;
1466 cart->right[list[idx].parent] = i;
1468 if( idx != (listcount - 1) )
1470 list[idx] = list[listcount - 1];
/* Children that were never expanded (index <= 0) become leaves. They are
 * renumbered as -j, and the matching stump side value is stored in
 * val[j]. */
1475 /* fill <cart> fields */
1478 for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1481 cart->compidx[i] = intnode[i].stump->compidx;
1482 cart->threshold[i] = intnode[i].stump->threshold;
1485 if( cart->left[i] <= 0 )
1488 cart->val[j] = intnode[i].stump->left;
1491 if( cart->right[i] <= 0 )
1493 cart->right[i] = -j;
1494 cart->val[j] = intnode[i].stump->right;
1500 for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1502 intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
/* NOTE(review): intnode[0].sampleIdx is the caller's sampleIdx. The lines
 * before this release are not visible here; confirm that node 0 is
 * skipped. */
1505 cvReleaseMat( &(intnode[i].sampleIdx) );
1508 for( i = 0; i < listcount; i++ )
1510 list[i].stump->release( (CvClassifier**) &(list[i].stump) );
1511 cvReleaseMat( &(list[i].sampleIdx) );
1516 return (CvClassifier*) cart;
/* Training state shared by the icvBoost* functions. count is the number of
 * samples processed. idx, when non-NULL, is the list of sample indices
 * stored right after the struct. */
1519 /****************************************************************************************\
1521 \****************************************************************************************/
1523 typedef struct CvBoostTrainer
1526 int count; /* (idx) ? number_of_indices : number_of_samples */
/* icvBoostStartTraining
 *
 * 1. Asserts CV_32FC1 trainClasses/weakTrainVals of matching length.
 * 2. Allocates a CvBoostTrainer with trailing storage for the optional
 *    sample indices. Indices are converted from sampleIdx of any numeric
 *    type through cvRawDataToScalar.
 * 3. Sets weakTrainVals[idx] = 2 * trainClasses[idx] - 1 for every used
 *    sample, i.e. maps 0/1 class labels to -1/+1 weak-learner targets.
 *
 * NOTE(review): the NULL-sampleIdx path is in lines not visible here.
 * ptr->idx is presumably left NULL there, so all samples are used.
 */
1532 * cvBoostStartTraining, cvBoostNextWeakClassifier, cvBoostEndTraining
1534 * These functions perform training of 2-class boosting classifier
1535 * using ANY appropriate weak classifier
1539 CvBoostTrainer* icvBoostStartTraining( CvMat* trainClasses,
1540 CvMat* weakTrainVals,
1555 CvBoostTrainer* ptr;
1561 assert( trainClasses != NULL );
1562 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1563 assert( weakTrainVals != NULL );
1564 assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1566 CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1567 CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1569 assert( m == trainnum );
1576 CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1579 datasize = sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum;
1580 ptr = (CvBoostTrainer*) cvAlloc( datasize );
1581 memset( ptr, 0, datasize );
1592 ptr->idx = (int*) (ptr + 1);
1593 ptr->count = idxnum;
1594 for( i = 0; i < ptr->count; i++ )
1596 cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1597 ptr->idx[i] = (int) s.val[0];
1600 for( i = 0; i < ptr->count; i++ )
1602 idx = (ptr->idx) ? ptr->idx[i] : i;
1604 *((float*) (traindata + idx * trainstep)) =
1605 2.0F * (*((float*) (ydata + idx * ystep))) - 1.0F;
/* icvBoostNextWeakClassifierDAB
 *
 * Discrete AdaBoost weight update. weakEvalVals holds the weak classifier's
 * outputs, which are compared with the -1/+1 target 2 * y - 1:
 *  - err accumulates the weight of misclassified samples; sumw accumulates
 *    the total weight.
 *  - err is converted to -cvLogRatio( err ), presumably after normalising
 *    by sumw in a line not visible here.
 *  - Misclassified samples' weights are multiplied by expf( err ). All
 *    weights are then renormalised to sum to 1; sumw is presumably reset
 *    before the second loop.
 * The float return statement is not visible; presumably it returns the
 * stage coefficient err.
 */
1613 * Discrete AdaBoost functions
1617 float icvBoostNextWeakClassifierDAB( CvMat* weakEvalVals,
1618 CvMat* trainClasses,
1619 CvMat* weakTrainVals,
1621 CvBoostTrainer* trainer )
1638 assert( weakEvalVals != NULL );
1639 assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1640 assert( trainClasses != NULL );
1641 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1642 assert( weights != NULL );
1643 assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1645 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1646 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1647 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1649 assert( m == ynum );
1650 assert( m == wnum );
1654 for( i = 0; i < trainer->count; i++ )
1656 idx = (trainer->idx) ? trainer->idx[i] : i;
1658 sumw += *((float*) (wdata + idx*wstep));
1659 err += (*((float*) (wdata + idx*wstep))) *
1660 ( (*((float*) (evaldata + idx*evalstep))) !=
1661 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F );
1664 err = -cvLogRatio( err );
1666 for( i = 0; i < trainer->count; i++ )
1668 idx = (trainer->idx) ? trainer->idx[i] : i;
1670 *((float*) (wdata + idx*wstep)) *= expf( err *
1671 ((*((float*) (evaldata + idx*evalstep))) !=
1672 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F) );
1673 sumw += *((float*) (wdata + idx*wstep));
1675 for( i = 0; i < trainer->count; i++ )
1677 idx = (trainer->idx) ? trainer->idx[i] : i;
1679 *((float*) (wdata + idx * wstep)) /= sumw;
1687 * Real AdaBoost functions
1691 float icvBoostNextWeakClassifierRAB( CvMat* weakEvalVals,
1692 CvMat* trainClasses,
1693 CvMat* weakTrainVals,
1695 CvBoostTrainer* trainer )
1710 assert( weakEvalVals != NULL );
1711 assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1712 assert( trainClasses != NULL );
1713 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1714 assert( weights != NULL );
1715 assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1717 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1718 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1719 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1721 assert( m == ynum );
1722 assert( m == wnum );
1726 for( i = 0; i < trainer->count; i++ )
1728 idx = (trainer->idx) ? trainer->idx[i] : i;
1730 *((float*) (wdata + idx*wstep)) *= expf( (-(*((float*) (ydata + idx*ystep))) + 0.5F)
1731 * cvLogRatio( *((float*) (evaldata + idx*evalstep)) ) );
1732 sumw += *((float*) (wdata + idx*wstep));
1734 for( i = 0; i < trainer->count; i++ )
1736 idx = (trainer->idx) ? trainer->idx[i] : i;
1738 *((float*) (wdata + idx*wstep)) /= sumw;
1746 * LogitBoost functions
1749 #define CV_LB_PROB_THRESH 0.01F
1750 #define CV_LB_WEIGHT_THRESHOLD 0.0001F
1753 void icvResponsesAndWeightsLB( int num, uchar* wdata, int wstep,
1754 uchar* ydata, int ystep,
1755 uchar* fdata, int fstep,
1756 uchar* traindata, int trainstep,
1762 for( i = 0; i < num; i++ )
1764 idx = (indices) ? indices[i] : i;
1766 p = 1.0F / (1.0F + expf( -(*((float*) (fdata + idx*fstep)))) );
1767 *((float*) (wdata + idx*wstep)) = MAX( p * (1.0F - p), CV_LB_WEIGHT_THRESHOLD );
1768 if( *((float*) (ydata + idx*ystep)) == 1.0F )
1770 *((float*) (traindata + idx*trainstep)) =
1771 1.0F / (MAX( p, CV_LB_PROB_THRESH ));
1775 *((float*) (traindata + idx*trainstep)) =
1776 -1.0F / (MAX( 1.0F - p, CV_LB_PROB_THRESH ));
1782 CvBoostTrainer* icvBoostStartTrainingLB( CvMat* trainClasses,
1783 CvMat* weakTrainVals,
1789 CvBoostTrainer* ptr;
1806 assert( trainClasses != NULL );
1807 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1808 assert( weakTrainVals != NULL );
1809 assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1810 assert( weights != NULL );
1811 assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1813 CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1814 CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1815 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1817 assert( m == trainnum );
1818 assert( m == wnum );
1826 CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1829 datasize = sizeof( *ptr ) + sizeof( *ptr->F ) * m + sizeof( *ptr->idx ) * idxnum;
1830 ptr = (CvBoostTrainer*) cvAlloc( datasize );
1831 memset( ptr, 0, datasize );
1832 ptr->F = (float*) (ptr + 1);
1842 ptr->idx = (int*) (ptr->F + m);
1843 ptr->count = idxnum;
1844 for( i = 0; i < ptr->count; i++ )
1846 cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1847 ptr->idx[i] = (int) s.val[0];
1851 for( i = 0; i < m; i++ )
1856 icvResponsesAndWeightsLB( ptr->count, wdata, wstep, ydata, ystep,
1857 (uchar*) ptr->F, sizeof( *ptr->F ),
1858 traindata, trainstep, ptr->idx );
1864 float icvBoostNextWeakClassifierLB( CvMat* weakEvalVals,
1865 CvMat* trainClasses,
1866 CvMat* weakTrainVals,
1868 CvBoostTrainer* trainer )
1884 assert( weakEvalVals != NULL );
1885 assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1886 assert( trainClasses != NULL );
1887 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1888 assert( weakTrainVals != NULL );
1889 assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1890 assert( weights != NULL );
1891 assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1893 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1894 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1895 CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1896 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1898 assert( m == ynum );
1899 assert( m == wnum );
1900 assert( m == trainnum );
1901 //assert( m == trainer->count );
1903 for( i = 0; i < trainer->count; i++ )
1905 idx = (trainer->idx) ? trainer->idx[i] : i;
1907 trainer->F[idx] += *((float*) (evaldata + idx * evalstep));
1910 icvResponsesAndWeightsLB( trainer->count, wdata, wstep, ydata, ystep,
1911 (uchar*) trainer->F, sizeof( *trainer->F ),
1912 traindata, trainstep, trainer->idx );
1923 float icvBoostNextWeakClassifierGAB( CvMat* weakEvalVals,
1924 CvMat* trainClasses,
1925 CvMat* weakTrainVals,
1927 CvBoostTrainer* trainer )
1942 assert( weakEvalVals != NULL );
1943 assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1944 assert( trainClasses != NULL );
1945 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1946 assert( weights != NULL );
1947 assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1949 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1950 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1951 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1953 assert( m == ynum );
1954 assert( m == wnum );
1957 for( i = 0; i < trainer->count; i++ )
1959 idx = (trainer->idx) ? trainer->idx[i] : i;
1961 *((float*) (wdata + idx*wstep)) *=
1962 expf( -(*((float*) (evaldata + idx*evalstep)))
1963 * ( 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F ) );
1964 sumw += *((float*) (wdata + idx*wstep));
1967 for( i = 0; i < trainer->count; i++ )
1969 idx = (trainer->idx) ? trainer->idx[i] : i;
1971 *((float*) (wdata + idx*wstep)) /= sumw;
1977 typedef CvBoostTrainer* (*CvBoostStartTraining)( CvMat* trainClasses,
1978 CvMat* weakTrainVals,
1983 typedef float (*CvBoostNextWeakClassifier)( CvMat* weakEvalVals,
1984 CvMat* trainClasses,
1985 CvMat* weakTrainVals,
1987 CvBoostTrainer* data );
1989 CvBoostStartTraining startTraining[4] = {
1990 icvBoostStartTraining,
1991 icvBoostStartTraining,
1992 icvBoostStartTrainingLB,
1993 icvBoostStartTraining
1996 CvBoostNextWeakClassifier nextWeakClassifier[4] = {
1997 icvBoostNextWeakClassifierDAB,
1998 icvBoostNextWeakClassifierRAB,
1999 icvBoostNextWeakClassifierLB,
2000 icvBoostNextWeakClassifierGAB
2009 CvBoostTrainer* cvBoostStartTraining( CvMat* trainClasses,
2010 CvMat* weakTrainVals,
2015 return startTraining[type]( trainClasses, weakTrainVals, weights, sampleIdx, type );
2019 void cvBoostEndTraining( CvBoostTrainer** trainer )
2026 float cvBoostNextWeakClassifier( CvMat* weakEvalVals,
2027 CvMat* trainClasses,
2028 CvMat* weakTrainVals,
2030 CvBoostTrainer* trainer )
2032 return nextWeakClassifier[trainer->type]( weakEvalVals, trainClasses,
2033 weakTrainVals, weights, trainer );
2036 /****************************************************************************************\
2037 * Boosted tree models *
2038 \****************************************************************************************/
2040 typedef struct CvBtTrainer
2046 CvMat* trainClasses;
2059 CvMTStumpTrainParams stumpParams;
2060 CvCARTTrainParams cartParams;
2062 float* f; /* F_(m-1) */
2063 CvMat* y; /* yhat */
2065 CvBoostTrainer* boosttrainer;
2069 * cvBtStart, cvBtNext, cvBtEnd
2071 * These functions perform iterative training of
2072 * 2-class (CV_DABCLASS - CV_GABCLASS, CV_L2CLASS), K-class (CV_LKCLASS) classifier
2073 * or fit regression model (CV_LSREG, CV_LADREG, CV_MREG)
2074 * using decision tree as a weak classifier.
2077 typedef void (*CvZeroApproxFunc)( float* approx, CvBtTrainer* trainer );
2079 /* Mean zero approximation */
2080 void icvZeroApproxMean( float* approx, CvBtTrainer* trainer )
2086 for( i = 0; i < trainer->numsamples; i++ )
2088 idx = icvGetIdxAt( trainer->sampleIdx, i );
2089 approx[0] += *((float*) (trainer->ydata + idx * trainer->ystep));
2091 approx[0] /= (float) trainer->numsamples;
2095 * Median zero approximation
2097 void icvZeroApproxMed( float* approx, CvBtTrainer* trainer )
2102 for( i = 0; i < trainer->numsamples; i++ )
2104 idx = icvGetIdxAt( trainer->sampleIdx, i );
2105 trainer->f[i] = *((float*) (trainer->ydata + idx * trainer->ystep));
2108 icvSort_32f( trainer->f, trainer->numsamples, 0 );
2109 approx[0] = trainer->f[trainer->numsamples / 2];
2113 * 0.5 * log( mean(y) / (1 - mean(y)) ) where y in {0, 1}
2115 void icvZeroApproxLog( float* approx, CvBtTrainer* trainer )
2119 icvZeroApproxMean( &y_mean, trainer );
2120 approx[0] = 0.5F * cvLogRatio( y_mean );
2124 * 0 zero approximation
2126 void icvZeroApprox0( float* approx, CvBtTrainer* trainer )
2130 for( i = 0; i < trainer->numclasses; i++ )
2136 static CvZeroApproxFunc icvZeroApproxFunc[] =
2138 icvZeroApprox0, /* CV_DABCLASS */
2139 icvZeroApprox0, /* CV_RABCLASS */
2140 icvZeroApprox0, /* CV_LBCLASS */
2141 icvZeroApprox0, /* CV_GABCLASS */
2142 icvZeroApproxLog, /* CV_L2CLASS */
2143 icvZeroApprox0, /* CV_LKCLASS */
2144 icvZeroApproxMean, /* CV_LSREG */
2145 icvZeroApproxMed, /* CV_LADREG */
2146 icvZeroApproxMed, /* CV_MREG */
2150 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer );
2153 CvBtTrainer* cvBtStart( CvCARTClassifier** trees,
2156 CvMat* trainClasses,
2165 CV_FUNCNAME( "cvBtStart" );
2176 CV_ERROR( CV_StsNullPtr, "Invalid trees parameter" );
2179 if( type < CV_DABCLASS || type > CV_MREG )
2181 CV_ERROR( CV_StsUnsupportedFormat, "Unsupported type parameter" );
2183 if( type == CV_LKCLASS )
2185 CV_ASSERT( numclasses >= 2 );
2192 m = MAX( trainClasses->rows, trainClasses->cols );
2194 data_size = sizeof( *ptr );
2195 if( type > CV_GABCLASS )
2197 data_size += m * numclasses * sizeof( *(ptr->f) );
2199 CV_CALL( ptr = (CvBtTrainer*) cvAlloc( data_size ) );
2200 memset( ptr, 0, data_size );
2201 ptr->f = (float*) (ptr + 1);
2203 ptr->trainData = trainData;
2205 ptr->trainClasses = trainClasses;
2206 CV_MAT2VEC( *trainClasses, ptr->ydata, ptr->ystep, ptr->m );
2208 memset( &(ptr->cartParams), 0, sizeof( ptr->cartParams ) );
2209 memset( &(ptr->stumpParams), 0, sizeof( ptr->stumpParams ) );
2214 ptr->stumpParams.error = CV_MISCLASSIFICATION;
2215 ptr->stumpParams.type = CV_CLASSIFICATION_CLASS;
2218 ptr->stumpParams.error = CV_GINI;
2219 ptr->stumpParams.type = CV_CLASSIFICATION;
2222 ptr->stumpParams.error = CV_SQUARE;
2223 ptr->stumpParams.type = CV_REGRESSION;
2225 ptr->cartParams.count = numsplits;
2226 ptr->cartParams.stumpTrainParams = (CvClassifierTrainParams*) &(ptr->stumpParams);
2227 ptr->cartParams.stumpConstructor = cvCreateMTStumpClassifier;
2229 ptr->param[0] = param[0];
2230 ptr->param[1] = param[1];
2232 ptr->numclasses = numclasses;
2234 CV_CALL( ptr->y = cvCreateMat( 1, m, CV_32FC1 ) );
2235 ptr->sampleIdx = sampleIdx;
2236 ptr->numsamples = ( sampleIdx == NULL ) ? ptr->m
2237 : MAX( sampleIdx->rows, sampleIdx->cols );
2239 ptr->weights = cvCreateMat( 1, m, CV_32FC1 );
2240 cvSet( ptr->weights, cvScalar( 1.0 ) );
2242 if( type <= CV_GABCLASS )
2244 ptr->boosttrainer = cvBoostStartTraining( ptr->trainClasses, ptr->y,
2245 ptr->weights, NULL, type );
2247 CV_CALL( cvBtNext( trees, ptr ) );
2251 data_size = sizeof( *zero_approx ) * numclasses;
2252 CV_CALL( zero_approx = (float*) cvAlloc( data_size ) );
2253 icvZeroApproxFunc[type]( zero_approx, ptr );
2254 for( i = 0; i < m; i++ )
2256 for( j = 0; j < numclasses; j++ )
2258 ptr->f[i * numclasses + j] = zero_approx[j];
2262 CV_CALL( cvBtNext( trees, ptr ) );
2264 for( i = 0; i < numclasses; i++ )
2266 for( j = 0; j <= trees[i]->count; j++ )
2268 trees[i]->val[j] += zero_approx[i];
2271 CV_CALL( cvFree( &zero_approx ) );
2279 void icvBtNext_LSREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2283 /* yhat_i = y_i - F_(m-1)(x_i) */
2284 for( i = 0; i < trainer->m; i++ )
2286 trainer->y->data.fl[i] =
2287 *((float*) (trainer->ydata + i * trainer->ystep)) - trainer->f[i];
2290 trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2292 trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2293 (CvClassifierTrainParams*) &trainer->cartParams );
2297 void icvBtNext_LADREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2299 CvCARTClassifier* ptr;
2312 data_size = trainer->m * sizeof( *idx );
2313 idx = (int*) cvAlloc( data_size );
2314 data_size = trainer->m * sizeof( *resp );
2315 resp = (float*) cvAlloc( data_size );
2317 /* yhat_i = sign(y_i - F_(m-1)(x_i)) */
2318 for( i = 0; i < trainer->numsamples; i++ )
2320 index = icvGetIdxAt( trainer->sampleIdx, i );
2321 trainer->y->data.fl[index] = (float)
2322 CV_SIGN( *((float*) (trainer->ydata + index * trainer->ystep))
2323 - trainer->f[index] );
2326 ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2327 trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2328 (CvClassifierTrainParams*) &trainer->cartParams );
2330 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2331 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2332 sample_data = sample.data.ptr;
2333 for( i = 0; i < trainer->numsamples; i++ )
2335 index = icvGetIdxAt( trainer->sampleIdx, i );
2336 sample.data.ptr = sample_data + index * sample_step;
2337 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2339 for( j = 0; j <= ptr->count; j++ )
2342 for( i = 0; i < trainer->numsamples; i++ )
2344 index = icvGetIdxAt( trainer->sampleIdx, i );
2345 if( idx[index] == j )
2347 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2348 - trainer->f[index];
2353 icvSort_32f( resp, respnum, 0 );
2354 val = resp[respnum / 2];
2370 void icvBtNext_MREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2372 CvCARTClassifier* ptr;
2388 data_size = trainer->m * sizeof( *idx );
2389 idx = (int*) cvAlloc( data_size );
2390 data_size = trainer->m * sizeof( *resp );
2391 resp = (float*) cvAlloc( data_size );
2392 data_size = trainer->m * sizeof( *resid );
2393 resid = (float*) cvAlloc( data_size );
2395 /* resid_i = (y_i - F_(m-1)(x_i)) */
2396 for( i = 0; i < trainer->numsamples; i++ )
2398 index = icvGetIdxAt( trainer->sampleIdx, i );
2399 resid[index] = *((float*) (trainer->ydata + index * trainer->ystep))
2400 - trainer->f[index];
2402 resp[i] = (float) fabs( resid[index] );
2405 /* delta = quantile_alpha{abs(resid_i)} */
2406 icvSort_32f( resp, trainer->numsamples, 0 );
2407 delta = resp[(int)(trainer->param[1] * (trainer->numsamples - 1))];
2410 for( i = 0; i < trainer->numsamples; i++ )
2412 index = icvGetIdxAt( trainer->sampleIdx, i );
2413 trainer->y->data.fl[index] = MIN( delta, ((float) fabs( resid[index] )) ) *
2414 CV_SIGN( resid[index] );
2417 ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2418 trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2419 (CvClassifierTrainParams*) &trainer->cartParams );
2421 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2422 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2423 sample_data = sample.data.ptr;
2424 for( i = 0; i < trainer->numsamples; i++ )
2426 index = icvGetIdxAt( trainer->sampleIdx, i );
2427 sample.data.ptr = sample_data + index * sample_step;
2428 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2430 for( j = 0; j <= ptr->count; j++ )
2434 for( i = 0; i < trainer->numsamples; i++ )
2436 index = icvGetIdxAt( trainer->sampleIdx, i );
2437 if( idx[index] == j )
2439 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2440 - trainer->f[index];
2445 /* rhat = median(y_i - F_(m-1)(x_i)) */
2446 icvSort_32f( resp, respnum, 0 );
2447 rhat = resp[respnum / 2];
2449 /* val = sum{sign(r_i - rhat_i) * min(delta, abs(r_i - rhat_i)}
2450 * r_i = y_i - F_(m-1)(x_i)
2453 for( i = 0; i < respnum; i++ )
2455 val += CV_SIGN( resp[i] - rhat )
2456 * MIN( delta, (float) fabs( resp[i] - rhat ) );
2459 val = rhat + val / (float) respnum;
2477 //#define CV_VAL_MAX 1e304
2479 //#define CV_LOG_VAL_MAX 700.0
2481 #define CV_VAL_MAX 1e+8
2483 #define CV_LOG_VAL_MAX 18.0
2485 void icvBtNext_L2CLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2487 CvCARTClassifier* ptr;
2501 float* sorted_weights;
2507 data_size = trainer->m * sizeof( *idx );
2508 idx = (int*) cvAlloc( data_size );
2510 data_size = trainer->m * sizeof( *weights );
2511 weights = (float*) cvAlloc( data_size );
2512 data_size = trainer->m * sizeof( *sorted_weights );
2513 sorted_weights = (float*) cvAlloc( data_size );
2515 /* yhat_i = (4 * y_i - 2) / ( 1 + exp( (4 * y_i - 2) * F_(m-1)(x_i) ) ).
2519 for( i = 0; i < trainer->numsamples; i++ )
2521 index = icvGetIdxAt( trainer->sampleIdx, i );
2522 val = 4.0F * (*((float*) (trainer->ydata + index * trainer->ystep))) - 2.0F;
2523 val_f = val * trainer->f[index];
2524 val_f = ( val_f < CV_LOG_VAL_MAX ) ? exp( val_f ) : CV_LOG_VAL_MAX;
2525 val = (float) ( (double) val / ( 1.0 + val_f ) );
2526 trainer->y->data.fl[index] = val;
2527 val = (float) fabs( val );
2528 weights[index] = val * (2.0F - val);
2529 sorted_weights[i] = weights[index];
2530 sum_weights += sorted_weights[i];
2534 sample_idx = trainer->sampleIdx;
2535 trimmed_num = trainer->numsamples;
2536 if( trainer->param[1] < 1.0F )
2538 /* perform weight trimming */
2543 icvSort_32f( sorted_weights, trainer->numsamples, 0 );
2545 sum_weights *= (1.0F - trainer->param[1]);
2548 do { sum_weights -= sorted_weights[++i]; }
2549 while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2551 threshold = sorted_weights[i];
2553 while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2557 trimmed_num = trainer->numsamples - i;
2558 trimmed_idx = cvCreateMat( 1, trimmed_num, CV_32FC1 );
2560 for( i = 0; i < trainer->numsamples; i++ )
2562 index = icvGetIdxAt( trainer->sampleIdx, i );
2563 if( weights[index] >= threshold )
2565 CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2570 assert( count == trimmed_num );
2572 sample_idx = trimmed_idx;
2574 printf( "Used samples %%: %g\n",
2575 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2579 ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2580 trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2581 (CvClassifierTrainParams*) &trainer->cartParams );
2583 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2584 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2585 sample_data = sample.data.ptr;
2586 for( i = 0; i < trimmed_num; i++ )
2588 index = icvGetIdxAt( sample_idx, i );
2589 sample.data.ptr = sample_data + index * sample_step;
2590 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2592 for( j = 0; j <= ptr->count; j++ )
2597 for( i = 0; i < trimmed_num; i++ )
2599 index = icvGetIdxAt( sample_idx, i );
2600 if( idx[index] == j )
2602 val += trainer->y->data.fl[index];
2603 sum_weights += weights[index];
2607 if( sum_weights > 0.0F )
2618 if( trimmed_idx != NULL ) cvReleaseMat( &trimmed_idx );
2619 cvFree( &sorted_weights );
2626 void icvBtNext_LKCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2628 int i, j, k, kk, num;
2640 float* sorted_weights;
2649 data_size = trainer->m * sizeof( *idx );
2650 idx = (int*) cvAlloc( data_size );
2651 data_size = trainer->m * sizeof( *weights );
2652 weights = (float*) cvAlloc( data_size );
2653 data_size = trainer->m * sizeof( *sorted_weights );
2654 sorted_weights = (float*) cvAlloc( data_size );
2655 trimmed_idx = cvCreateMat( 1, trainer->numsamples, CV_32FC1 );
2657 for( k = 0; k < trainer->numclasses; k++ )
2659 /* yhat_i = y_i - p_k(x_i), y_i in {0, 1} */
2660 /* p_k(x_i) = exp(f_k(x_i)) / (sum_exp_f(x_i)) */
2662 for( i = 0; i < trainer->numsamples; i++ )
2664 index = icvGetIdxAt( trainer->sampleIdx, i );
2665 /* p_k(x_i) = 1 / (1 + sum(exp(f_kk(x_i) - f_k(x_i)))), kk != k */
2666 num = index * trainer->numclasses;
2667 f_k = (double) trainer->f[num + k];
2669 for( kk = 0; kk < trainer->numclasses; kk++ )
2671 if( kk == k ) continue;
2672 exp_f = (double) trainer->f[num + kk] - f_k;
2673 exp_f = (exp_f < CV_LOG_VAL_MAX) ? exp( exp_f ) : CV_VAL_MAX;
2674 if( exp_f == CV_VAL_MAX || exp_f >= (CV_VAL_MAX - sum_exp_f) )
2676 sum_exp_f = CV_VAL_MAX;
2682 val = (float) ( (*((float*) (trainer->ydata + index * trainer->ystep)))
2684 val -= (float) ( (sum_exp_f == CV_VAL_MAX) ? 0.0 : ( 1.0 / sum_exp_f ) );
2686 assert( val >= -1.0F );
2687 assert( val <= 1.0F );
2689 trainer->y->data.fl[index] = val;
2690 val = (float) fabs( val );
2691 weights[index] = val * (1.0F - val);
2692 sorted_weights[i] = weights[index];
2693 sum_weights += sorted_weights[i];
2696 sample_idx = trainer->sampleIdx;
2697 trimmed_num = trainer->numsamples;
2698 if( trainer->param[1] < 1.0F )
2700 /* perform weight trimming */
2705 icvSort_32f( sorted_weights, trainer->numsamples, 0 );
2707 sum_weights *= (1.0F - trainer->param[1]);
2710 do { sum_weights -= sorted_weights[++i]; }
2711 while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2713 threshold = sorted_weights[i];
2715 while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2719 trimmed_num = trainer->numsamples - i;
2720 trimmed_idx->cols = trimmed_num;
2722 for( i = 0; i < trainer->numsamples; i++ )
2724 index = icvGetIdxAt( trainer->sampleIdx, i );
2725 if( weights[index] >= threshold )
2727 CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2732 assert( count == trimmed_num );
2734 sample_idx = trimmed_idx;
2736 printf( "k: %d Used samples %%: %g\n", k,
2737 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2739 } /* weight trimming */
2741 trees[k] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2742 trainer->flags, trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2743 (CvClassifierTrainParams*) &trainer->cartParams );
2745 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2746 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2747 sample_data = sample.data.ptr;
2748 for( i = 0; i < trimmed_num; i++ )
2750 index = icvGetIdxAt( sample_idx, i );
2751 sample.data.ptr = sample_data + index * sample_step;
2752 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) trees[k],
2755 for( j = 0; j <= trees[k]->count; j++ )
2760 for( i = 0; i < trimmed_num; i++ )
2762 index = icvGetIdxAt( sample_idx, i );
2763 if( idx[index] == j )
2765 val += trainer->y->data.fl[index];
2766 sum_weights += weights[index];
2770 if( sum_weights > 0.0F )
2772 val = ((float) (trainer->numclasses - 1)) * val /
2773 ((float) (trainer->numclasses)) / sum_weights;
2779 trees[k]->val[j] = val;
2781 } /* for each class */
2783 cvReleaseMat( &trimmed_idx );
2784 cvFree( &sorted_weights );
2790 void icvBtNext_XXBCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2794 CvMat* weak_eval_vals;
2801 weak_eval_vals = cvCreateMat( 1, trainer->m, CV_32FC1 );
2803 sample_idx = cvTrimWeights( trainer->weights, trainer->sampleIdx,
2804 trainer->param[1] );
2805 num_samples = ( sample_idx == NULL )
2806 ? trainer->m : MAX( sample_idx->rows, sample_idx->cols );
2808 printf( "Used samples %%: %g\n",
2809 (float) num_samples / (float) trainer->numsamples * 100.0F );
2811 trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2812 trainer->flags, trainer->y, NULL, NULL, NULL,
2813 sample_idx, trainer->weights,
2814 (CvClassifierTrainParams*) &trainer->cartParams );
2816 /* evaluate samples */
2817 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2818 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2819 sample_data = sample.data.ptr;
2821 for( i = 0; i < trainer->m; i++ )
2823 sample.data.ptr = sample_data + i * sample_step;
2824 weak_eval_vals->data.fl[i] = trees[0]->eval( (CvClassifier*) trees[0], &sample );
2827 alpha = cvBoostNextWeakClassifier( weak_eval_vals, trainer->trainClasses,
2828 trainer->y, trainer->weights, trainer->boosttrainer );
2830 /* multiply tree by alpha */
2831 for( i = 0; i <= trees[0]->count; i++ )
2833 trees[0]->val[i] *= alpha;
2835 if( trainer->type == CV_RABCLASS )
2837 for( i = 0; i <= trees[0]->count; i++ )
2839 trees[0]->val[i] = cvLogRatio( trees[0]->val[i] );
2843 if( sample_idx != NULL && sample_idx != trainer->sampleIdx )
2845 cvReleaseMat( &sample_idx );
2847 cvReleaseMat( &weak_eval_vals );
2850 typedef void (*CvBtNextFunc)( CvCARTClassifier** trees, CvBtTrainer* trainer );
2852 static CvBtNextFunc icvBtNextFunc[] =
2866 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer )
2869 CV_FUNCNAME( "cvBtNext" );
2879 icvBtNextFunc[trainer->type]( trees, trainer );
2882 if( trainer->param[0] != 1.0F )
2884 for( j = 0; j < trainer->numclasses; j++ )
2886 for( i = 0; i <= trees[j]->count; i++ )
2888 trees[j]->val[i] *= trainer->param[0];
2893 if( trainer->type > CV_GABCLASS )
2895 /* update F_(m-1) */
2896 CV_GET_SAMPLE( *(trainer->trainData), trainer->flags, 0, sample );
2897 CV_GET_SAMPLE_STEP( *(trainer->trainData), trainer->flags, sample_step );
2898 sample_data = sample.data.ptr;
2899 for( i = 0; i < trainer->numsamples; i++ )
2901 index = icvGetIdxAt( trainer->sampleIdx, i );
2902 sample.data.ptr = sample_data + index * sample_step;
2903 for( j = 0; j < trainer->numclasses; j++ )
2905 trainer->f[index * trainer->numclasses + j] +=
2906 trees[j]->eval( (CvClassifier*) (trees[j]), &sample );
2915 void cvBtEnd( CvBtTrainer** trainer )
2917 CV_FUNCNAME( "cvBtEnd" );
2921 if( trainer == NULL || (*trainer) == NULL )
2923 CV_ERROR( CV_StsNullPtr, "Invalid trainer parameter" );
2926 if( (*trainer)->y != NULL )
2928 CV_CALL( cvReleaseMat( &((*trainer)->y) ) );
2930 if( (*trainer)->weights != NULL )
2932 CV_CALL( cvReleaseMat( &((*trainer)->weights) ) );
2934 if( (*trainer)->boosttrainer != NULL )
2936 CV_CALL( cvBoostEndTraining( &((*trainer)->boosttrainer) ) );
2938 CV_CALL( cvFree( trainer ) );
2943 /****************************************************************************************\
2944 * Boosted tree model as a classifier *
2945 \****************************************************************************************/
2948 float cvEvalBtClassifier( CvClassifier* classifier, CvMat* sample )
2952 CV_FUNCNAME( "cvEvalBtClassifier" );
2959 if( CV_IS_TUNABLE( classifier->flags ) )
2962 CvCARTClassifier* tree;
2964 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
2965 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2967 CV_READ_SEQ_ELEM( tree, reader );
2968 val += tree->eval( (CvClassifier*) tree, sample );
2973 CvCARTClassifier** ptree;
2975 ptree = ((CvBtClassifier*) classifier)->trees;
2976 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2978 val += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
2989 float cvEvalBtClassifier2( CvClassifier* classifier, CvMat* sample )
2993 CV_FUNCNAME( "cvEvalBtClassifier2" );
2997 CV_CALL( val = cvEvalBtClassifier( classifier, sample ) );
3001 return (float) (val >= 0.0F);
3005 float cvEvalBtClassifierK( CvClassifier* classifier, CvMat* sample )
3009 CV_FUNCNAME( "cvEvalBtClassifierK" );
3020 numclasses = ((CvBtClassifier*) classifier)->numclasses;
3021 data_size = sizeof( *vals ) * numclasses;
3022 CV_CALL( vals = (float*) cvAlloc( data_size ) );
3023 memset( vals, 0, data_size );
3025 if( CV_IS_TUNABLE( classifier->flags ) )
3028 CvCARTClassifier* tree;
3030 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3031 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3033 for( k = 0; k < numclasses; k++ )
3035 CV_READ_SEQ_ELEM( tree, reader );
3036 vals[k] += tree->eval( (CvClassifier*) tree, sample );
3043 CvCARTClassifier** ptree;
3045 ptree = ((CvBtClassifier*) classifier)->trees;
3046 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3048 for( k = 0; k < numclasses; k++ )
3050 vals[k] += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
3057 max_val = vals[cls];
3058 for( k = 1; k < numclasses; k++ )
3060 if( vals[k] > max_val )
3067 CV_CALL( cvFree( &vals ) );
3074 typedef float (*CvEvalBtClassifier)( CvClassifier* classifier, CvMat* sample );
3076 static CvEvalBtClassifier icvEvalBtClassifier[] =
3078 cvEvalBtClassifier2,
3079 cvEvalBtClassifier2,
3080 cvEvalBtClassifier2,
3081 cvEvalBtClassifier2,
3082 cvEvalBtClassifier2,
3083 cvEvalBtClassifierK,
3090 int cvSaveBtClassifier( CvClassifier* classifier, const char* filename )
3092 CV_FUNCNAME( "cvSaveBtClassifier" );
3099 CvCARTClassifier* tree;
3101 CV_ASSERT( classifier );
3102 CV_ASSERT( filename );
3104 if( !icvMkDir( filename ) || !(file = fopen( filename, "w" )) )
3106 CV_ERROR( CV_StsError, "Unable to create file" );
3109 if( CV_IS_TUNABLE( classifier->flags ) )
3111 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3113 fprintf( file, "%d %d\n%d\n%d\n", (int) ((CvBtClassifier*) classifier)->type,
3114 ((CvBtClassifier*) classifier)->numclasses,
3115 ((CvBtClassifier*) classifier)->numfeatures,
3116 ((CvBtClassifier*) classifier)->numiter );
3118 for( i = 0; i < ((CvBtClassifier*) classifier)->numclasses *
3119 ((CvBtClassifier*) classifier)->numiter; i++ )
3121 if( CV_IS_TUNABLE( classifier->flags ) )
3123 CV_READ_SEQ_ELEM( tree, reader );
3127 tree = ((CvBtClassifier*) classifier)->trees[i];
3130 fprintf( file, "%d\n", tree->count );
3131 for( j = 0; j < tree->count; j++ )
3133 fprintf( file, "%d %g %d %d\n", tree->compidx[j],
3138 for( j = 0; j <= tree->count; j++ )
3140 fprintf( file, "%g ", tree->val[j] );
3142 fprintf( file, "\n" );
3154 void cvReleaseBtClassifier( CvClassifier** ptr )
3156 CV_FUNCNAME( "cvReleaseBtClassifier" );
3162 if( ptr == NULL || *ptr == NULL )
3164 CV_ERROR( CV_StsNullPtr, "" );
3166 if( CV_IS_TUNABLE( (*ptr)->flags ) )
3169 CvCARTClassifier* tree;
3171 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) *ptr)->seq, &reader ) );
3172 for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3173 ((CvBtClassifier*) *ptr)->numiter; i++ )
3175 CV_READ_SEQ_ELEM( tree, reader );
3176 tree->release( (CvClassifier**) (&tree) );
3178 CV_CALL( cvReleaseMemStorage( &(((CvBtClassifier*) *ptr)->seq->storage) ) );
3182 CvCARTClassifier** ptree;
3184 ptree = ((CvBtClassifier*) *ptr)->trees;
3185 for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3186 ((CvBtClassifier*) *ptr)->numiter; i++ )
3188 (*ptree)->release( (CvClassifier**) ptree );
3193 CV_CALL( cvFree( ptr ) );
/*
 * cvTuneBtClassifier: the `tune` callback assigned in icvAllocBtClassifier.
 *
 * If `flags` has CV_TUNABLE set, one more boosting iteration is run on a
 * classifier that is still tunable (otherwise CV_StsUnsupportedFormat is
 * raised):
 *   - a buffer of `numclasses` tree pointers is allocated;
 *   - cvBtNext() fills it using the stored trainer;
 *   - the trees are appended to `seq` and the buffer is freed;
 *   - `numiter` is incremented.
 *
 * The second branch handles a classifier whose own flags still contain
 * CV_TUNABLE, and finalizes training:
 *   - the tree pointers in `seq` are copied into a flat `trees` array;
 *   - the sequence storage is released;
 *   - CV_TUNABLE is cleared from classifier->flags;
 *   - the trainer is ended with cvBtEnd() and set to NULL.
 * NOTE(review): lines 3230-3233 are missing from this listing; presumably
 * that branch is the `else` of `if( CV_IS_TUNABLE( flags ) )` -- confirm.
 *
 * The CvMat* parameters are unnamed and therefore unused. The signature
 * matches the `tune` slot this function is stored in.
 *
 * NOTE(review): this listing omits the braces and other short lines. The
 * declarations of `data_size` and `ptr` are not visible here.
 */
3199 void cvTuneBtClassifier( CvClassifier* classifier, CvMat*, int flags,
3200 CvMat*, CvMat* , CvMat*, CvMat*, CvMat* )
3202 CV_FUNCNAME( "cvTuneBtClassifier" );
3208 if( CV_IS_TUNABLE( flags ) )
3210 if( !CV_IS_TUNABLE( classifier->flags ) )
3212 CV_ERROR( CV_StsUnsupportedFormat,
3213 "Classifier does not support tune function" );
3217 /* tune classifier */
3218 CvCARTClassifier** trees;
/* Progress output goes to stdout. */
3220 printf( "Iteration %d\n", ((CvBtClassifier*) classifier)->numiter + 1 );
/* One new tree per class for this iteration. */
3222 data_size = sizeof( *trees ) * ((CvBtClassifier*) classifier)->numclasses;
3223 CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
/* NOTE(review): if a failing CV_CALL leaves the function early, `trees`
 * is not freed on that path -- confirm CV_CALL's error behaviour. */
3224 CV_CALL( cvBtNext( trees,
3225 (CvBtTrainer*) ((CvBtClassifier*) classifier)->trainer ) );
3226 CV_CALL( cvSeqPushMulti( ((CvBtClassifier*) classifier)->seq,
3227 trees, ((CvBtClassifier*) classifier)->numclasses ) );
3228 CV_CALL( cvFree( &trees ) );
3229 ((CvBtClassifier*) classifier)->numiter++;
/* Finalization: only meaningful while the classifier is still tunable. */
3234 if( CV_IS_TUNABLE( classifier->flags ) )
/* The sequence is expected to hold numiter * numclasses tree pointers. */
3239 assert( ((CvBtClassifier*) classifier)->seq->total ==
3240 ((CvBtClassifier*) classifier)->numiter *
3241 ((CvBtClassifier*) classifier)->numclasses );
3243 data_size = sizeof( ((CvBtClassifier*) classifier)->trees[0] ) *
3244 ((CvBtClassifier*) classifier)->seq->total;
3245 CV_CALL( ptr = cvAlloc( data_size ) );
3246 CV_CALL( cvCvtSeqToArray( ((CvBtClassifier*) classifier)->seq, ptr ) );
/* The tree pointers now live in `ptr`, so the sequence storage can go.
 * NOTE(review): `seq` itself was created inside this storage
 * (see icvAllocBtClassifier). It is not reset to NULL in the visible lines
 * and is left dangling after the release. */
3247 CV_CALL( cvReleaseMemStorage(
3248 &(((CvBtClassifier*) classifier)->seq->storage) ) );
3249 ((CvBtClassifier*) classifier)->trees = (CvCARTClassifier**) ptr;
/* From here on the classifier rejects further tuning. */
3250 classifier->flags &= ~CV_TUNABLE;
3251 CV_CALL( cvBtEnd( (CvBtTrainer**)
3252 &(((CvBtClassifier*) classifier)->trainer )) );
3253 ((CvBtClassifier*) classifier)->trainer = NULL;
/*
 * icvAllocBtClassifier: allocates a zero-filled CvBtClassifier. It also
 * installs the eval/tune/save/release callbacks.
 *
 * When CV_TUNABLE is in `flags`, trees are collected in a CvSeq of tree
 * pointers backed by a newly created CvMemStorage. The other path
 * (presumably the `else` branch; lines 3278-3281 are absent from this
 * listing) preallocates a zero-filled array of numclasses * numiter tree
 * pointers.
 *
 * `type` selects the evaluation function from icvEvalBtClassifier[]. It is
 * used as an array index without a range check in this function.
 *
 * NOTE(review): the line declaring the last parameter (`numiter`, line
 * 3261), the declaration of `data_size`, and the trailing return lines are
 * not in this listing.
 * NOTE(review): cvAlloc()'s result goes straight to memset() without a NULL
 * check. This relies on cvAlloc never returning NULL -- confirm.
 * NOTE(review): numclasses * numiter * sizeof(...) is not checked for
 * overflow.
 */
3260 CvBtClassifier* icvAllocBtClassifier( CvBoostType type, int flags, int numclasses,
3263 CvBtClassifier* ptr;
3266 assert( numclasses >= 1 );
3267 assert( numiter >= 0 );
/* Only CV_LKCLASS classifiers may have more than one class. */
3268 assert( ( numclasses == 1 ) || (type == CV_LKCLASS) );
3270 data_size = sizeof( *ptr );
3271 ptr = (CvBtClassifier*) cvAlloc( data_size );
3272 memset( ptr, 0, data_size );
3274 if( CV_IS_TUNABLE( flags ) )
3276 ptr->seq = cvCreateSeq( 0, sizeof( *(ptr->seq) ), sizeof( *(ptr->trees) ),
3277 cvCreateMemStorage() );
3282 data_size = numclasses * numiter * sizeof( *(ptr->trees) );
3283 ptr->trees = (CvCARTClassifier**) cvAlloc( data_size );
3284 memset( ptr->trees, 0, data_size );
/* NOTE(review): whether this assignment also runs on the CV_TUNABLE path
 * depends on lines missing from this listing. cvTuneBtClassifier
 * increments numiter once per iteration, so confirm it starts at 0
 * there. */
3286 ptr->numiter = numiter;
3290 ptr->numclasses = numclasses;
3293 ptr->eval = icvEvalBtClassifier[(int) type];
3294 ptr->tune = cvTuneBtClassifier;
3295 ptr->save = cvSaveBtClassifier;
3296 ptr->release = cvReleaseBtClassifier;
/*
 * cvCreateBtClassifier: trains a boosted-tree classifier.
 *
 * Supported inputs are trainData, trainClasses and trainParams, plus
 * sampleIdx for some types.
 *   - typeMask, missedMeasurementsMask, compIdx and weights must be NULL.
 *   - sampleIdx is rejected for types in the range CV_DABCLASS..CV_GABCLASS.
 *
 * For CV_LKCLASS the class count is (int)(max label + 1.0) and must be at
 * least 2.
 * NOTE(review): this assumes labels are 0..K-1. min_val is computed but not
 * checked in the visible lines. How num_classes is set for the other types
 * is on lines missing from this listing.
 *
 * Training steps:
 *   - the classifier is allocated as tunable;
 *   - cvBtStart() runs the first iteration and returns the trainer, which is
 *     stored in ptr->trainer;
 *   - ptr->tune() is called with CV_TUNABLE for each of the remaining
 *     num_iter - 1 iterations.
 * Unless the caller's `flags` contain CV_TUNABLE, a final ptr->tune() with
 * flags 0 finalizes the classifier so it can no longer be tuned.
 *
 * Progress ("Iteration N") is printed to stdout.
 *
 * NOTE(review): the parameter lines for flags, typeMask, compIdx, sampleIdx
 * and weights (3303, 3305, 3307-3309) and several local declarations are
 * not in this listing.
 * NOTE(review): if a CV_CALL fails after allocation, `ptr` and `trees`
 * may leak -- confirm the cleanup path.
 */
3302 CvClassifier* cvCreateBtClassifier( CvMat* trainData,
3304 CvMat* trainClasses,
3306 CvMat* missedMeasurementsMask,
3310 CvClassifierTrainParams* trainParams )
3312 CvBtClassifier* ptr;
3314 CV_FUNCNAME( "cvCreateBtClassifier" );
3321 CvCARTClassifier** trees;
/* Optional inputs that this implementation does not support must be NULL. */
3324 CV_ASSERT( trainData != NULL );
3325 CV_ASSERT( trainClasses != NULL );
3326 CV_ASSERT( typeMask == NULL );
3327 CV_ASSERT( missedMeasurementsMask == NULL );
3328 CV_ASSERT( compIdx == NULL );
3329 CV_ASSERT( weights == NULL );
3330 CV_ASSERT( trainParams != NULL );
3332 type = ((CvBtClassifierTrainParams*) trainParams)->type;
3334 if( type >= CV_DABCLASS && type <= CV_GABCLASS && sampleIdx )
3336 CV_ERROR( CV_StsBadArg, "Sample indices are not supported for this type" );
3339 if( type == CV_LKCLASS )
/* Class count is derived from the largest label value. */
3344 cvMinMaxLoc( trainClasses, &min_val, &max_val );
3345 num_classes = (int) (max_val + 1.0);
3347 CV_ASSERT( num_classes >= 2 );
3353 num_iter = ((CvBtClassifierTrainParams*) trainParams)->numiter;
3355 CV_ASSERT( num_iter > 0 );
/* Always allocate as tunable so iterations can be appended; finalization
 * happens at the end unless the caller asked to keep it tunable. */
3357 ptr = icvAllocBtClassifier( type, CV_TUNABLE | flags, num_classes, num_iter );
3358 ptr->numfeatures = (CV_IS_ROW_SAMPLE( flags )) ? trainData->cols : trainData->rows;
3362 printf( "Iteration %d\n", 1 );
3364 data_size = sizeof( *trees ) * ptr->numclasses;
3365 CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
/* First iteration: creates the trainer and produces one tree per class. */
3367 CV_CALL( ptr->trainer = cvBtStart( trees, trainData, flags, trainClasses, sampleIdx,
3368 ((CvBtClassifierTrainParams*) trainParams)->numsplits, type, num_classes,
3369 &(((CvBtClassifierTrainParams*) trainParams)->param[0]) ) );
3371 CV_CALL( cvSeqPushMulti( ptr->seq, trees, ptr->numclasses ) );
3372 CV_CALL( cvFree( &trees ) );
/* Remaining iterations go through the tune callback. */
3375 for( i = 1; i < num_iter; i++ )
3377 ptr->tune( (CvClassifier*) ptr, NULL, CV_TUNABLE, NULL, NULL, NULL, NULL, NULL );
3379 if( !CV_IS_TUNABLE( flags ) )
3382 ptr->tune( (CvClassifier*) ptr, NULL, 0, NULL, NULL, NULL, NULL, NULL );
3387 return (CvClassifier*) ptr;
/*
 * cvCreateBtClassifierFromFile: loads a (non-tunable) boosted-tree
 * classifier from a text file.
 *
 * Layout as read below:
 *   type num_classes num_features num_classifiers
 *   then num_classes * num_classifiers trees, each stored as:
 *     count
 *     count records: compidx threshold <int> right
 *     count + 1 values: val
 * `type` must be in CV_DABCLASS..CV_MREG. num_features and num_classifiers
 * must be positive.
 *
 * NOTE(review): no fscanf() return value is checked. A truncated or
 * malformed file leaves the targets unread and they are used anyway.
 * NOTE(review): num_classes is not validated here; icvAllocBtClassifier
 * only checks it with assert().
 * NOTE(review): no fclose() is visible in this listing. If a CV_ASSERT or
 * CV_ERROR after fopen() leaves the function early, confirm the FILE is
 * still closed.
 */
3391 CvClassifier* cvCreateBtClassifierFromFile( const char* filename )
3393 CvBtClassifier* ptr;
3395 CV_FUNCNAME( "cvCreateBtClassifierFromFile" );
3402 int num_classifiers;
3407 CV_ASSERT( filename != NULL );
3410 file = fopen( filename, "r" );
3413 CV_ERROR( CV_StsError, "Unable to open file" );
3416 fscanf( file, "%d %d %d %d", &type, &num_classes, &num_features, &num_classifiers );
3418 CV_ASSERT( type >= (int) CV_DABCLASS && type <= (int) CV_MREG );
3419 CV_ASSERT( num_features > 0 );
3420 CV_ASSERT( num_classifiers > 0 );
/* NOTE(review): the body of this `if` (line 3424) is missing from this
 * listing. It presumably adjusts num_classes for non-LK types -- confirm. */
3422 if( (CvBoostType) type != CV_LKCLASS )
3426 ptr = icvAllocBtClassifier( (CvBoostType) type, 0, num_classes, num_classifiers );
3427 ptr->numfeatures = num_features;
3429 for( i = 0; i < num_classes * num_classifiers; i++ )
3432 CvCARTClassifier* tree;
/* NOTE(review): `count` comes straight from the file and is not
 * validated. A negative or very large value makes data_size below wrong
 * (wrap-around) before cvAlloc(). */
3434 fscanf( file, "%d", &count );
/* Single allocation: tree header, then the compidx, threshold, left and
 * right arrays (count entries each), then val (count + 1 entries). */
3436 data_size = sizeof( *tree )
3437 + count * ( sizeof( *(tree->compidx) ) + sizeof( *(tree->threshold) ) +
3438 sizeof( *(tree->right) ) + sizeof( *(tree->left) ) )
3439 + (count + 1) * ( sizeof( *(tree->val) ) );
3440 CV_CALL( tree = (CvCARTClassifier*) cvAlloc( data_size ) );
3441 memset( tree, 0, data_size );
3442 tree->eval = cvEvalCARTClassifier;
3445 tree->release = cvReleaseCARTClassifier;
3446 tree->compidx = (int*) ( tree + 1 );
3447 tree->threshold = (float*) ( tree->compidx + count );
3448 tree->left = (int*) ( tree->threshold + count );
3449 tree->right = (int*) ( tree->left + count );
3450 tree->val = (float*) ( tree->right + count );
3452 tree->count = count;
3453 for( j = 0; j < tree->count; j++ )
/* NOTE(review): the third argument (line 3457) is missing from this
 * listing; presumably &(tree->left[j]) -- confirm. */
3455 fscanf( file, "%d %g %d %d", &(tree->compidx[j]),
3456 &(tree->threshold[j]),
3458 &(tree->right[j]) );
3460 for( j = 0; j <= tree->count; j++ )
3462 fscanf( file, "%g", &(tree->val[j]) );
3464 ptr->trees[i] = tree;
3471 return (CvClassifier*) ptr;
3474 /****************************************************************************************\
3475 * Utility functions *
3476 \****************************************************************************************/
/*
 * cvTrimWeights: weight trimming for a CV_32FC1 weight vector.
 *
 * Applies only when 0 < factor < 1. The selected weights (all of them, or
 * those addressed by `idx`) are copied and sorted, and their sum is scaled
 * by (1 - factor). The smallest weights are then subtracted until that
 * budget is used up; the weight reached becomes the threshold. Equal
 * weights below that position are kept as well.
 *
 * When something is dropped (i > 0), or `idx` is not CV_32FC1, a new
 * 1 x (num - i) CV_32FC1 matrix is created. It holds, as floats, the
 * indices of the samples whose weight is >= threshold.
 *
 * NOTE(review): the above assumes icvSort_32f(..., 0) sorts ascending, and
 * that `i` is reset to -1 on line 3522 (missing from this listing). What is
 * returned in the other cases is also on missing lines -- confirm.
 * NOTE(review): with num == 0 the do-while still runs once and reads
 * sorted_weights[0] from a zero-size buffer.
 */
3479 CvMat* cvTrimWeights( CvMat* weights, CvMat* idx, float factor )
3483 CV_FUNCNAME( "cvTrimWeights" );
3492 float* sorted_weights;
3494 CV_ASSERT( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
3497 sorted_weights = NULL;
3499 if( factor > 0.0F && factor < 1.0F )
3503 CV_MAT2VEC( *weights, wdata, wstep, wnum );
/* Number of candidate samples: every weight, or every entry of idx. */
3504 num = ( idx == NULL ) ? wnum : MAX( idx->rows, idx->cols );
3506 data_size = num * sizeof( *sorted_weights );
3507 sorted_weights = (float*) cvAlloc( data_size );
3508 memset( sorted_weights, 0, data_size );
/* Gather the candidate weights and their total. */
3511 for( i = 0; i < num; i++ )
3513 index = icvGetIdxAt( idx, i );
3514 sorted_weights[i] = *((float*) (wdata + index * wstep));
3515 sum_weights += sorted_weights[i];
3518 icvSort_32f( sorted_weights, num, 0 );
/* Budget of weight mass that may be discarded. */
3520 sum_weights *= (1.0F - factor);
3523 do { sum_weights -= sorted_weights[++i]; }
3524 while( sum_weights > 0.0F && i < (num - 1) );
3526 threshold = sorted_weights[i];
/* Step back over ties so every weight equal to the threshold is kept. */
3528 while( i > 0 && sorted_weights[i-1] == threshold ) i--;
3530 if( i > 0 || ( idx != NULL && CV_MAT_TYPE( idx->type ) != CV_32FC1 ) )
3532 CV_CALL( ptr = cvCreateMat( 1, num - i, CV_32FC1 ) );
3534 for( i = 0; i < num; i++ )
3536 index = icvGetIdxAt( idx, i );
3537 if( *((float*) (wdata + index * wstep)) >= threshold )
3539 CV_MAT_ELEM( *ptr, float, 0, count ) = (float) index;
/* Entries at or above the threshold are exactly the num - i kept ones. */
3544 assert( count == ptr->cols );
3546 cvFree( &sorted_weights );
/*
 * cvReadTrainData: reads a training set in the text format that
 * cvWriteTrainData() produces:
 *   m n                      (number of samples, number of features)
 *   m records of n feature values followed by the class label
 *
 * The matrices are created here and handed back through the out-parameters:
 *   - *trainData becomes an m x n CV_32FC1 matrix when
 *     CV_IS_ROW_SAMPLE(flags), and n x m otherwise;
 *   - *trainClasses becomes a 1 x m CV_32FC1 row of labels.
 *
 * NOTE(review): the `trainData` parameter line (3557) and the local
 * declarations are not in this listing.
 * NOTE(review): no fscanf() result is checked, and m and n from the file are
 * not validated before cvCreateMat(). A short file silently reuses the
 * previous `val`.
 * NOTE(review): no fclose() is visible in this listing -- confirm the FILE
 * is closed on every path.
 */
3556 void cvReadTrainData( const char* filename, int flags,
3558 CvMat** trainClasses )
3561 CV_FUNCNAME( "cvReadTrainData" );
3570 if( filename == NULL )
3572 CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3574 if( trainData == NULL )
3576 CV_ERROR( CV_StsNullPtr, "trainData must be not NULL" );
3578 if( trainClasses == NULL )
3580 CV_ERROR( CV_StsNullPtr, "trainClasses must be not NULL" );
3584 *trainClasses = NULL;
3585 file = fopen( filename, "r" );
3588 CV_ERROR( CV_StsError, "Unable to open file" );
/* Header: sample count, feature count. */
3591 fscanf( file, "%d %d", &m, &n );
3593 if( CV_IS_ROW_SAMPLE( flags ) )
3595 CV_CALL( *trainData = cvCreateMat( m, n, CV_32FC1 ) );
3599 CV_CALL( *trainData = cvCreateMat( n, m, CV_32FC1 ) );
3602 CV_CALL( *trainClasses = cvCreateMat( 1, m, CV_32FC1 ) );
/* One record per sample: n features, then the label. */
3604 for( i = 0; i < m; i++ )
3606 for( j = 0; j < n; j++ )
3608 fscanf( file, "%f", &val );
3609 if( CV_IS_ROW_SAMPLE( flags ) )
3611 CV_MAT_ELEM( **trainData, float, i, j ) = val;
3615 CV_MAT_ELEM( **trainData, float, j, i ) = val;
3618 fscanf( file, "%f", &val );
3619 CV_MAT_ELEM( **trainClasses, float, 0, i ) = val;
/*
 * cvWriteTrainData: writes trainData/trainClasses, optionally only the
 * samples listed in sampleIdx, in the text format that cvReadTrainData()
 * reads:
 *   count n
 *   per sample: n feature values ("%g "), then the class label ("%g\n")
 *
 * Requirements:
 *   - trainData and trainClasses must be CV_32FC1;
 *   - trainClasses must be a row or column vector with one label per
 *     sample;
 *   - samples are rows of trainData when CV_IS_ROW_SAMPLE(flags), and
 *     columns otherwise.
 *
 * NOTE(review): indices read from sampleIdx are converted with an (int)
 * cast. They are used as matrix indices without a range check against m in
 * the visible lines.
 * NOTE(review): how `count` and `idx` are set when sampleIdx is NULL is on
 * lines missing from this listing (presumably count = m, idx = i) --
 * confirm.
 * NOTE(review): fprintf() results are not checked in the visible lines.
 */
3629 void cvWriteTrainData( const char* filename, int flags,
3630 CvMat* trainData, CvMat* trainClasses, CvMat* sampleIdx )
3632 CV_FUNCNAME( "cvWriteTrainData" );
3644 if( filename == NULL )
3646 CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3648 if( trainData == NULL || CV_MAT_TYPE( trainData->type ) != CV_32FC1 )
3650 CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainData" );
/* m = number of samples, n = number of features. */
3652 if( CV_IS_ROW_SAMPLE( flags ) )
3654 m = trainData->rows;
3655 n = trainData->cols;
3659 n = trainData->rows;
3660 m = trainData->cols;
3662 if( trainClasses == NULL || CV_MAT_TYPE( trainClasses->type ) != CV_32FC1 ||
3663 MIN( trainClasses->rows, trainClasses->cols ) != 1 )
3665 CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainClasses" );
/* Labels may be laid out as a row or as a column. */
3667 clsrow = (trainClasses->rows == 1);
3668 if( m != ( (clsrow) ? trainClasses->cols : trainClasses->rows ) )
3670 CV_ERROR( CV_StsUnmatchedSizes, "Incorrect trainData and trainClasses sizes" );
3673 if( sampleIdx != NULL )
3675 count = (sampleIdx->rows == 1) ? sampleIdx->cols : sampleIdx->rows;
3683 file = fopen( filename, "w" );
3686 CV_ERROR( CV_StsError, "Unable to create file" );
3689 fprintf( file, "%d %d\n", count, n );
3691 for( i = 0; i < count; i++ )
/* sampleIdx may be a row or a column vector. */
3695 if( sampleIdx->rows == 1 )
3697 sc = cvGet2D( sampleIdx, 0, i );
3701 sc = cvGet2D( sampleIdx, i, 0 );
3703 idx = (int) sc.val[0];
3709 for( j = 0; j < n; j++ )
3711 fprintf( file, "%g ", ( (CV_IS_ROW_SAMPLE( flags ))
3712 ? CV_MAT_ELEM( *trainData, float, idx, j )
3713 : CV_MAT_ELEM( *trainData, float, j, idx ) ) );
3715 fprintf( file, "%g\n", ( (clsrow)
3716 ? CV_MAT_ELEM( *trainClasses, float, 0, idx )
3717 : CV_MAT_ELEM( *trainClasses, float, idx, 0 ) ) );
/*
 * ICV_RAND_SHUFFLE(suffix, type) defines
 *   void icvRandShuffle_<suffix>( uchar* data, size_t step, int num )
 * which shuffles, in place, `num` elements of `type` stored `step` bytes
 * apart.
 *
 * For each i < num - 1 it swaps element i with element
 * i + (int)(rn * (num - i)), where rn = cvRandNext() / (1.0F + UINT_MAX).
 * This is the forward Fisher-Yates shuffle provided that rn < 1.
 *
 * NOTE(review): rn is computed in single precision, which allows rn to be
 * exactly 1.0F:
 *   - with a 32-bit unsigned int, (float) of a value near UINT_MAX rounds
 *     up to 2^32;
 *   - 1.0F + UINT_MAX also evaluates to 2^32 in float.
 * The swap index then becomes num, one element past the end. Computing rn
 * in double, e.g. cvRandNext( &state ) / (1.0 + UINT_MAX), keeps it
 * strictly below 1. This assumes cvRandNext returns a 32-bit unsigned
 * value -- confirm.
 * NOTE(review): the declarations and initialization of `seed`, `i`, `rn`
 * and the CV_SWAP temporary, plus the closing lines of the macro, are
 * missing from this listing.
 *
 * Instantiated below for uchar (8U), short (16S), int (32S) and float (32F).
 */
3726 #define ICV_RAND_SHUFFLE( suffix, type ) \
3727 void icvRandShuffle_##suffix( uchar* data, size_t step, int num ) \
3729 CvRandState state; \
3737 cvRandInit( &state, (double) 0, (double) 0, (int)seed ); \
3738 for( i = 0; i < (num-1); i++ ) \
3740 rn = ((float) cvRandNext( &state )) / (1.0F + UINT_MAX); \
3741 CV_SWAP( *((type*)(data + i * step)), \
3742 *((type*)(data + ( i + (int)( rn * (num - i ) ) )* step)), \
3747 ICV_RAND_SHUFFLE( 8U, uchar )
3749 ICV_RAND_SHUFFLE( 16S, short )
3751 ICV_RAND_SHUFFLE( 32S, int )
3753 ICV_RAND_SHUFFLE( 32F, float )
/*
 * cvRandShuffleVec: shuffles the elements of a row or column vector in
 * place. It dispatches on CV_MAT_TYPE to the icvRandShuffle_8U/16S/32S/32F
 * helpers.
 *
 * CV_StsUnsupportedFormat is raised for:
 *   - a NULL or non-CvMat argument;
 *   - a matrix whose smaller dimension is not 1;
 *   - an element type without a helper.
 *
 * NOTE(review): CV_FUNCNAME is "cvRandShuffle", not "cvRandShuffleVec", so
 * errors are reported under that other name.
 * NOTE(review): the case labels, the `break`s and the declarations of
 * data/step/num are on lines missing from this listing. The labels are
 * presumably CV_8UC1, CV_16SC1, CV_32SC1 and CV_32FC1 -- confirm.
 */
3756 void cvRandShuffleVec( CvMat* mat )
3758 CV_FUNCNAME( "cvRandShuffle" );
3766 if( (mat == NULL) || !CV_IS_MAT( mat ) || MIN( mat->rows, mat->cols ) != 1 )
3768 CV_ERROR( CV_StsUnsupportedFormat, "" );
/* View the vector as (data pointer, byte step, element count). */
3771 CV_MAT2VEC( *mat, data, step, num );
3772 switch( CV_MAT_TYPE( mat->type ) )
3775 icvRandShuffle_8U( data, step, num);
3778 icvRandShuffle_16S( data, step, num);
3781 icvRandShuffle_32S( data, step, num);
3784 icvRandShuffle_32F( data, step, num);
/* Any other element type is rejected. */
3787 CV_ERROR( CV_StsUnsupportedFormat, "" );