Update the changelog
[opencv] / ml / src / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 static const float ord_nan = FLT_MAX*0.5f;
44 static const int min_block_size = 1 << 16;
45 static const int block_size_delta = 1 << 10;
46
47 CvDTreeTrainData::CvDTreeTrainData()
48 {
49     var_idx = var_type = cat_count = cat_ofs = cat_map =
50         priors = priors_mult = counts = buf = direction = split_buf = 0;
51     tree_storage = temp_storage = 0;
52
53     clear();
54 }
55
56
57 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
58                       const CvMat* _responses, const CvMat* _var_idx,
59                       const CvMat* _sample_idx, const CvMat* _var_type,
60                       const CvMat* _missing_mask, const CvDTreeParams& _params,
61                       bool _shared, bool _add_labels )
62 {
63     var_idx = var_type = cat_count = cat_ofs = cat_map =
64         priors = priors_mult = counts = buf = direction = split_buf = 0;
65     tree_storage = temp_storage = 0;
66     
67     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
68               _var_type, _missing_mask, _params, _shared, _add_labels );
69 }
70
71
72 CvDTreeTrainData::~CvDTreeTrainData()
73 {
74     clear();
75 }
76
77
78 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
79 {
80     bool ok = false;
81     
82     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
83
84     __BEGIN__;
85
86     // set parameters
87     params = _params;
88
89     if( params.max_categories < 2 )
90         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
91     params.max_categories = MIN( params.max_categories, 15 );
92
93     if( params.max_depth < 0 )
94         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
95     params.max_depth = MIN( params.max_depth, 25 );
96
97     params.min_sample_count = MAX(params.min_sample_count,1);
98
99     if( params.cv_folds < 0 )
100         CV_ERROR( CV_StsOutOfRange,
101         "params.cv_folds should be =0 (the tree is not pruned) "
102         "or n>0 (tree is pruned using n-fold cross-validation)" );
103
104     if( params.cv_folds == 1 )
105         params.cv_folds = 0;
106
107     if( params.regression_accuracy < 0 )
108         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
109
110     ok = true;
111
112     __END__;
113
114     return ok;
115 }
116
117
118 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
119 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
120 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
121
122 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
123 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
124
125 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
126     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
127     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
128     bool _shared, bool _add_labels, bool _update_data )
129 {
130     CvMat* sample_idx = 0;
131     CvMat* var_type0 = 0;
132     CvMat* tmp_map = 0;
133     int** int_ptr = 0;
134     CvDTreeTrainData* data = 0;
135
136     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
137
138     __BEGIN__;
139
140     int sample_all = 0, r_type = 0, cv_n;
141     int total_c_count = 0;
142     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
143     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
144     int vi, i;
145     char err[100];
146     const int *sidx = 0, *vidx = 0;
147
148     if( _update_data && data_root )
149     {
150         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
151             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
152         
153         // compare new and old train data
154         if( !(data->var_count == var_count &&
155             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
156             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
157             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
158             CV_ERROR( CV_StsBadArg,
159             "The new training data must have the same types and the input and output variables "
160             "and the same categories for categorical variables" );
161
162         cvReleaseMat( &priors );
163         cvReleaseMat( &priors_mult );
164         cvReleaseMat( &buf );
165         cvReleaseMat( &direction );
166         cvReleaseMat( &split_buf );
167         cvReleaseMemStorage( &temp_storage );
168
169         priors = data->priors; data->priors = 0;
170         priors_mult = data->priors_mult; data->priors_mult = 0;
171         buf = data->buf; data->buf = 0;
172         buf_count = data->buf_count; buf_size = data->buf_size;
173         sample_count = data->sample_count;
174
175         direction = data->direction; data->direction = 0;
176         split_buf = data->split_buf; data->split_buf = 0;
177         temp_storage = data->temp_storage; data->temp_storage = 0;
178         nv_heap = data->nv_heap; cv_heap = data->cv_heap;
179
180         data_root = new_node( 0, sample_count, 0, 0 );
181         EXIT;
182     }
183
184     clear();
185
186     var_all = 0;
187     rng = cvRNG(-1);
188
189     CV_CALL( set_params( _params ));
190
191     // check parameter types and sizes
192     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
193     if( _tflag == CV_ROW_SAMPLE )
194     {
195         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
196         dv_step = 1;
197         if( _missing_mask )
198             ms_step = _missing_mask->step, mv_step = 1;
199     }
200     else
201     {
202         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
203         ds_step = 1;
204         if( _missing_mask )
205             mv_step = _missing_mask->step, ms_step = 1;
206     }
207
208     sample_count = sample_all;
209     var_count = var_all;
210
211     if( _sample_idx )
212     {
213         CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
214         sidx = sample_idx->data.i;
215         sample_count = sample_idx->rows + sample_idx->cols - 1;
216     }
217
218     if( _var_idx )
219     {
220         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
221         vidx = var_idx->data.i;
222         var_count = var_idx->rows + var_idx->cols - 1;
223     }
224
225     if( !CV_IS_MAT(_responses) ||
226         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
227          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
228         _responses->rows != 1 && _responses->cols != 1 ||
229         _responses->rows + _responses->cols - 1 != sample_all )
230         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
231                   "floating-point vector containing as many elements as "
232                   "the total number of samples in the training data matrix" );
233
234     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
235     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
236
237     cat_var_count = 0;
238     ord_var_count = -1;
239
240     is_classifier = r_type == CV_VAR_CATEGORICAL;
241
242     // step 0. calc the number of categorical vars
243     for( vi = 0; vi < var_count; vi++ )
244     {
245         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
246             cat_var_count++ : ord_var_count--;
247     }
248
249     ord_var_count = ~ord_var_count;
250     cv_n = params.cv_folds;
251     // set the two last elements of var_type array to be able
252     // to locate responses and cross-validation labels using
253     // the corresponding get_* functions.
254     var_type->data.i[var_count] = cat_var_count;
255     var_type->data.i[var_count+1] = cat_var_count+1;
256
257     // in case of single ordered predictor we need dummy cv_labels
258     // for safe split_node_data() operation
259     have_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0 || _add_labels;
260
261     buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
262     shared = _shared;
263     buf_count = shared ? 3 : 2;
264     CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
265     CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
266     CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
267     CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
268
269     // now calculate the maximum size of split,
270     // create memory storage that will keep nodes and splits of the decision tree
271     // allocate root node and the buffer for the whole training data
272     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
273         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
274     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
275     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
276     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
277     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
278
279     nv_size = var_count*sizeof(int);
280     nv_size = MAX( nv_size, (int)sizeof(CvSetElem) );
281
282     temp_block_size = nv_size;
283     
284     if( cv_n )
285     {
286         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
287             CV_ERROR( CV_StsOutOfRange,
288                 "The many folds in cross-validation for such a small dataset" );
289
290         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
291         temp_block_size = MAX(temp_block_size, cv_size);
292     }
293
294     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
295     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
296     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
297     if( cv_size )
298         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
299
300     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
301     CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
302
303     max_c_count = 1;
304
305     // transform the training data to convenient representation
306     for( vi = 0; vi <= var_count; vi++ )
307     {
308         int ci;
309         const uchar* mask = 0;
310         int m_step = 0, step;
311         const int* idata = 0;
312         const float* fdata = 0;
313         int num_valid = 0;
314
315         if( vi < var_count ) // analyze i-th input variable
316         {
317             int vi0 = vidx ? vidx[vi] : vi;
318             ci = get_var_type(vi);
319             step = ds_step; m_step = ms_step;
320             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
321                 idata = _train_data->data.i + vi0*dv_step;
322             else
323                 fdata = _train_data->data.fl + vi0*dv_step;
324             if( _missing_mask )
325                 mask = _missing_mask->data.ptr + vi0*mv_step;
326         }
327         else // analyze _responses
328         {
329             ci = cat_var_count;
330             step = CV_IS_MAT_CONT(_responses->type) ?
331                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
332             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
333                 idata = _responses->data.i;
334             else
335                 fdata = _responses->data.fl;
336         }
337
338         if( vi < var_count && ci >= 0 ||
339             vi == var_count && is_classifier ) // process categorical variable or response
340         {
341             int c_count, prev_label;
342             int* c_map, *dst = get_cat_var_data( data_root, vi );
343
344             // copy data
345             for( i = 0; i < sample_count; i++ )
346             {
347                 int val = INT_MAX, si = sidx ? sidx[i] : i;
348                 if( !mask || !mask[si*m_step] )
349                 {
350                     if( idata )
351                         val = idata[si*step];
352                     else
353                     {
354                         float t = fdata[si*step];
355                         val = cvRound(t);
356                         if( val != t )
357                         {
358                             sprintf( err, "%d-th value of %d-th (categorical) "
359                                 "variable is not an integer", i, vi );
360                             CV_ERROR( CV_StsBadArg, err );
361                         }
362                     }
363
364                     if( val == INT_MAX )
365                     {
366                         sprintf( err, "%d-th value of %d-th (categorical) "
367                             "variable is too large", i, vi );
368                         CV_ERROR( CV_StsBadArg, err );
369                     }
370                     num_valid++;
371                 }
372                 dst[i] = val;
373                 int_ptr[i] = dst + i;
374             }
375
376             // sort all the values, including the missing measurements
377             // that should all move to the end
378             icvSortIntPtr( int_ptr, sample_count, 0 );
379             //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
380
381             c_count = num_valid > 0;
382
383             // count the categories
384             for( i = 1; i < num_valid; i++ )
385                 c_count += *int_ptr[i] != *int_ptr[i-1];
386
387             if( vi > 0 )
388                 max_c_count = MAX( max_c_count, c_count );
389             cat_count->data.i[ci] = c_count;
390             cat_ofs->data.i[ci] = total_c_count;
391
392             // resize cat_map, if need
393             if( cat_map->cols < total_c_count + c_count )
394             {
395                 tmp_map = cat_map;
396                 CV_CALL( cat_map = cvCreateMat( 1,
397                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
398                 for( i = 0; i < total_c_count; i++ )
399                     cat_map->data.i[i] = tmp_map->data.i[i];
400                 cvReleaseMat( &tmp_map );
401             }
402
403             c_map = cat_map->data.i + total_c_count;
404             total_c_count += c_count;
405
406             // compact the class indices and build the map
407             prev_label = ~*int_ptr[0];
408             c_count = -1;
409
410             for( i = 0; i < num_valid; i++ )
411             {
412                 int cur_label = *int_ptr[i];
413                 if( cur_label != prev_label )
414                     c_map[++c_count] = prev_label = cur_label;
415                 *int_ptr[i] = c_count;
416             }
417
418             // replace labels for missing values with -1
419             for( ; i < sample_count; i++ )
420                 *int_ptr[i] = -1;
421         }
422         else if( ci < 0 ) // process ordered variable
423         {
424             CvPair32s32f* dst = get_ord_var_data( data_root, vi );
425
426             for( i = 0; i < sample_count; i++ )
427             {
428                 float val = ord_nan;
429                 int si = sidx ? sidx[i] : i;
430                 if( !mask || !mask[si*m_step] )
431                 {
432                     if( idata )
433                         val = (float)idata[si*step];
434                     else
435                         val = fdata[si*step];
436
437                     if( fabs(val) >= ord_nan )
438                     {
439                         sprintf( err, "%d-th value of %d-th (ordered) "
440                             "variable (=%g) is too large", i, vi, val );
441                         CV_ERROR( CV_StsBadArg, err );
442                     }
443                     num_valid++;
444                 }
445                 dst[i].i = i;
446                 dst[i].val = val;
447             }
448
449             icvSortPairs( dst, sample_count, 0 );
450         }
451         else // special case: process ordered response,
452              // it will be stored similarly to categorical vars (i.e. no pairs)
453         {
454             float* dst = get_ord_responses( data_root );
455
456             for( i = 0; i < sample_count; i++ )
457             {
458                 float val = ord_nan;
459                 int si = sidx ? sidx[i] : i;
460                 if( idata )
461                     val = (float)idata[si*step];
462                 else
463                     val = fdata[si*step];
464
465                 if( fabs(val) >= ord_nan )
466                 {
467                     sprintf( err, "%d-th value of %d-th (ordered) "
468                         "variable (=%g) is out of range", i, vi, val );
469                     CV_ERROR( CV_StsBadArg, err );
470                 }
471                 dst[i] = val;
472             }
473
474             cat_count->data.i[cat_var_count] = 0;
475             cat_ofs->data.i[cat_var_count] = total_c_count;
476             num_valid = sample_count;
477         }
478
479         if( vi < var_count )
480             data_root->set_num_valid(vi, num_valid);
481     }
482
483     if( cv_n )
484     {
485         int* dst = get_labels(data_root);
486         CvRNG* r = &rng;
487
488         for( i = vi = 0; i < sample_count; i++ )
489         {
490             dst[i] = vi++;
491             vi &= vi < cv_n ? -1 : 0;
492         }
493
494         for( i = 0; i < sample_count; i++ )
495         {
496             int a = cvRandInt(r) % sample_count;
497             int b = cvRandInt(r) % sample_count;
498             CV_SWAP( dst[a], dst[b], vi );
499         }
500     }
501
502     cat_map->cols = MAX( total_c_count, 1 );
503
504     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
505         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
506     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
507
508     have_priors = is_classifier && params.priors;
509     if( is_classifier )
510     {
511         int m = get_num_classes();
512         double sum = 0;
513         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
514         for( i = 0; i < m; i++ )
515         {
516             double val = have_priors ? params.priors[i] : 1.;
517             if( val <= 0 )
518                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
519             priors->data.db[i] = val;
520             sum += val;
521         }
522         
523         // normalize weights
524         if( have_priors )
525             cvScale( priors, priors, 1./sum );
526         
527         CV_CALL( priors_mult = cvCloneMat( priors ));
528         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
529     }
530
531     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
532     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
533
534     __END__;
535
536     if( data )
537         delete data;
538
539     cvFree( &int_ptr );
540     cvReleaseMat( &sample_idx );
541     cvReleaseMat( &var_type0 );
542     cvReleaseMat( &tmp_map );
543 }
544
545
546 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
547 {
548     CvDTreeNode* root = 0;
549     CvMat* isubsample_idx = 0;
550     CvMat* subsample_co = 0;
551     
552     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
553
554     __BEGIN__;
555
556     if( !data_root )
557         CV_ERROR( CV_StsError, "No training data has been set" );
558     
559     if( _subsample_idx )
560         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
561
562     if( !isubsample_idx )
563     {
564         // make a copy of the root node
565         CvDTreeNode temp;
566         int i;
567         root = new_node( 0, 1, 0, 0 );
568         temp = *root;
569         *root = *data_root;
570         root->num_valid = temp.num_valid;
571         if( root->num_valid )
572         {
573             for( i = 0; i < var_count; i++ )
574                 root->num_valid[i] = data_root->num_valid[i];
575         }
576         root->cv_Tn = temp.cv_Tn;
577         root->cv_node_risk = temp.cv_node_risk;
578         root->cv_node_error = temp.cv_node_error;
579     }
580     else
581     {
582         int* sidx = isubsample_idx->data.i;
583         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
584         int* co, cur_ofs = 0;
585         int vi, i, total = data_root->sample_count;
586         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
587         int work_var_count = get_work_var_count();
588         root = new_node( 0, count, 1, 0 );
589
590         CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
591         cvZero( subsample_co );
592         co = subsample_co->data.i;
593         for( i = 0; i < count; i++ )
594             co[sidx[i]*2]++;
595         for( i = 0; i < total; i++ )
596         {
597             if( co[i*2] )
598             {
599                 co[i*2+1] = cur_ofs;
600                 cur_ofs += co[i*2];
601             }
602             else
603                 co[i*2+1] = -1;
604         }
605
606         for( vi = 0; vi < work_var_count; vi++ )
607         {
608             int ci = get_var_type(vi);
609
610             if( ci >= 0 || vi >= var_count )
611             {
612                 const int* src = get_cat_var_data( data_root, vi );
613                 int* dst = get_cat_var_data( root, vi );
614                 int num_valid = 0;
615
616                 for( i = 0; i < count; i++ )
617                 {
618                     int val = src[sidx[i]];
619                     dst[i] = val;
620                     num_valid += val >= 0;
621                 }
622
623                 if( vi < var_count )
624                     root->set_num_valid(vi, num_valid);
625             }
626             else
627             {
628                 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
629                 CvPair32s32f* dst = get_ord_var_data( root, vi );
630                 int j = 0, idx, count_i;
631                 int num_valid = data_root->get_num_valid(vi);
632
633                 for( i = 0; i < num_valid; i++ )
634                 {
635                     idx = src[i].i;
636                     count_i = co[idx*2];
637                     if( count_i )
638                     {
639                         float val = src[i].val;
640                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
641                         {
642                             dst[j].val = val;
643                             dst[j].i = cur_ofs;
644                         }
645                     }
646                 }
647
648                 root->set_num_valid(vi, j);
649
650                 for( ; i < total; i++ )
651                 {
652                     idx = src[i].i;
653                     count_i = co[idx*2];
654                     if( count_i )
655                     {
656                         float val = src[i].val;
657                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
658                         {
659                             dst[j].val = val;
660                             dst[j].i = cur_ofs;
661                         }
662                     }
663                 }
664             }
665         }
666     }
667
668     __END__;
669
670     cvReleaseMat( &isubsample_idx );
671     cvReleaseMat( &subsample_co );
672
673     return root;
674 }
675
676
677 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
678                                     float* values, uchar* missing,
679                                     float* responses, bool get_class_idx )
680 {
681     CvMat* subsample_idx = 0;
682     CvMat* subsample_co = 0;
683     
684     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
685
686     __BEGIN__;
687
688     int i, vi, total = sample_count, count = total, cur_ofs = 0;
689     int* sidx = 0;
690     int* co = 0;
691
692     if( _subsample_idx )
693     {
694         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
695         sidx = subsample_idx->data.i;
696         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
697         co = subsample_co->data.i;
698         cvZero( subsample_co );
699         count = subsample_idx->cols + subsample_idx->rows - 1;
700         for( i = 0; i < count; i++ )
701             co[sidx[i]*2]++;
702         for( i = 0; i < total; i++ )
703         {
704             int count_i = co[i*2];
705             if( count_i )
706             {
707                 co[i*2+1] = cur_ofs*var_count;
708                 cur_ofs += count_i;
709             }
710         }
711     }
712
713     if( missing )
714         memset( missing, 1, count*var_count );
715
716     for( vi = 0; vi < var_count; vi++ )
717     {
718         int ci = get_var_type(vi);
719         if( ci >= 0 ) // categorical
720         {
721             float* dst = values + vi;
722             uchar* m = missing ? missing + vi : 0;
723             const int* src = get_cat_var_data(data_root, vi);
724
725             for( i = 0; i < count; i++, dst += var_count )
726             {
727                 int idx = sidx ? sidx[i] : i;
728                 int val = src[idx];
729                 *dst = (float)val;
730                 if( m )
731                 {
732                     *m = val < 0;
733                     m += var_count;
734                 }
735             }
736         }
737         else // ordered
738         {
739             float* dst = values + vi;
740             uchar* m = missing ? missing + vi : 0;
741             const CvPair32s32f* src = get_ord_var_data(data_root, vi);
742             int count1 = data_root->get_num_valid(vi);
743
744             for( i = 0; i < count1; i++ )
745             {
746                 int idx = src[i].i;
747                 int count_i = 1;
748                 if( co )
749                 {
750                     count_i = co[idx*2];
751                     cur_ofs = co[idx*2+1];
752                 }
753                 else
754                     cur_ofs = idx*var_count;
755                 if( count_i )
756                 {
757                     float val = src[i].val;
758                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
759                     {
760                         dst[cur_ofs] = val;
761                         if( m )
762                             m[cur_ofs] = 0;
763                     }
764                 }
765             }
766         }
767     }
768
769     // copy responses
770     if( responses )
771     {
772         if( is_classifier )
773         {
774             const int* src = get_class_labels(data_root);
775             for( i = 0; i < count; i++ )
776             {
777                 int idx = sidx ? sidx[i] : i;
778                 int val = get_class_idx ? src[idx] :
779                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
780                 responses[i] = (float)val;
781             }
782         }
783         else
784         {
785             const float* src = get_ord_responses(data_root);
786             for( i = 0; i < count; i++ )
787             {
788                 int idx = sidx ? sidx[i] : i;
789                 responses[i] = src[idx];
790             }
791         }
792     }
793
794     __END__;
795
796     cvReleaseMat( &subsample_idx );
797     cvReleaseMat( &subsample_co );
798 }
799
800
801 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
802                                          int storage_idx, int offset )
803 {
804     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
805
806     node->sample_count = count;
807     node->depth = parent ? parent->depth + 1 : 0;
808     node->parent = parent;
809     node->left = node->right = 0;
810     node->split = 0;
811     node->value = 0;
812     node->class_idx = 0;
813     node->maxlr = 0.;
814
815     node->buf_idx = storage_idx;
816     node->offset = offset;
817     if( nv_heap )
818         node->num_valid = (int*)cvSetNew( nv_heap );
819     else
820         node->num_valid = 0;
821     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
822     node->complexity = 0;
823
824     if( params.cv_folds > 0 && cv_heap )
825     {
826         int cv_n = params.cv_folds;
827         node->Tn = INT_MAX;
828         node->cv_Tn = (int*)cvSetNew( cv_heap );
829         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
830         node->cv_node_error = node->cv_node_risk + cv_n;
831     }
832     else
833     {
834         node->Tn = 0;
835         node->cv_Tn = 0;
836         node->cv_node_risk = 0;
837         node->cv_node_error = 0;
838     }
839
840     return node;
841 }
842
843
844 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
845                 int split_point, int inversed, float quality )
846 {
847     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
848     split->var_idx = vi;
849     split->ord.c = cmp_val;
850     split->ord.split_point = split_point;
851     split->inversed = inversed;
852     split->quality = quality;
853     split->next = 0;
854
855     return split;
856 }
857
858
859 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
860 {
861     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
862     int i, n = (max_c_count + 31)/32;
863
864     split->var_idx = vi;
865     split->inversed = 0;
866     split->quality = quality;
867     for( i = 0; i < n; i++ )
868         split->subset[i] = 0;
869     split->next = 0;
870
871     return split;
872 }
873
874
875 void CvDTreeTrainData::free_node( CvDTreeNode* node )
876 {
877     CvDTreeSplit* split = node->split;
878     free_node_data( node );
879     while( split )
880     {
881         CvDTreeSplit* next = split->next;
882         cvSetRemoveByPtr( split_heap, split );
883         split = next;
884     }
885     node->split = 0;
886     cvSetRemoveByPtr( node_heap, node );
887 }
888
889
890 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
891 {
892     if( node->num_valid )
893     {
894         cvSetRemoveByPtr( nv_heap, node->num_valid );
895         node->num_valid = 0;
896     }
897     // do not free cv_* fields, as all the cross-validation related data is released at once.
898 }
899
900
901 void CvDTreeTrainData::free_train_data()
902 {
903     cvReleaseMat( &counts );
904     cvReleaseMat( &buf );
905     cvReleaseMat( &direction );
906     cvReleaseMat( &split_buf );
907     cvReleaseMemStorage( &temp_storage );
908     cv_heap = nv_heap = 0;
909 }
910
911
912 void CvDTreeTrainData::clear()
913 {
914     free_train_data();
915
916     cvReleaseMemStorage( &tree_storage );
917
918     cvReleaseMat( &var_idx );
919     cvReleaseMat( &var_type );
920     cvReleaseMat( &cat_count );
921     cvReleaseMat( &cat_ofs );
922     cvReleaseMat( &cat_map );
923     cvReleaseMat( &priors );
924     cvReleaseMat( &priors_mult );
925
926     node_heap = split_heap = 0;
927
928     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
929     have_labels = have_priors = is_classifier = false;
930
931     buf_count = buf_size = 0;
932     shared = false;
933
934     data_root = 0;
935
936     rng = cvRNG(-1);
937 }
938
939
940 int CvDTreeTrainData::get_num_classes() const
941 {
942     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
943 }
944
945
946 int CvDTreeTrainData::get_var_type(int vi) const
947 {
948     return var_type->data.i[vi];
949 }
950
951
952 int CvDTreeTrainData::get_work_var_count() const
953 {
954     return var_count + 1 + (have_labels ? 1 : 0);
955 }
956
957 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
958 {
959     int oi = ~get_var_type(vi);
960     assert( 0 <= oi && oi < ord_var_count );
961     return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
962                            n->offset + oi*n->sample_count*2);
963 }
964
965
966 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
967 {
968     return get_cat_var_data( n, var_count );
969 }
970
971
972 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
973 {
974     return (float*)get_cat_var_data( n, var_count );
975 }
976
977
978 int* CvDTreeTrainData::get_labels( CvDTreeNode* n )
979 {
980     return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0;
981 }
982
983
984 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
985 {
986     int ci = get_var_type(vi);
987     assert( 0 <= ci && ci <= cat_var_count + 1 );
988     return buf->data.i + n->buf_idx*buf->cols + n->offset +
989            (ord_var_count*2 + ci)*n->sample_count;
990 }
991
992
993 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
994 {
995     int idx = n->buf_idx + 1;
996     if( idx >= buf_count )
997         idx = shared ? 1 : 0;
998     return idx;
999 }
1000
1001
1002 void CvDTreeTrainData::write_params( CvFileStorage* fs )
1003 {
1004     CV_FUNCNAME( "CvDTreeTrainData::write_params" );
1005
1006     __BEGIN__;
1007
1008     int vi, vcount = var_count;
1009
1010     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
1011     cvWriteInt( fs, "var_all", var_all );
1012     cvWriteInt( fs, "var_count", var_count );
1013     cvWriteInt( fs, "ord_var_count", ord_var_count );
1014     cvWriteInt( fs, "cat_var_count", cat_var_count );
1015
1016     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
1017     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
1018
1019     if( is_classifier )
1020     {
1021         cvWriteInt( fs, "max_categories", params.max_categories );
1022     }
1023     else
1024     {
1025         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
1026     }
1027
1028     cvWriteInt( fs, "max_depth", params.max_depth );
1029     cvWriteInt( fs, "min_sample_count", params.min_sample_count );
1030     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
1031     
1032     if( params.cv_folds > 1 )
1033     {
1034         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
1035         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
1036     }
1037
1038     if( priors )
1039         cvWrite( fs, "priors", priors );
1040
1041     cvEndWriteStruct( fs );
1042
1043     if( var_idx )
1044         cvWrite( fs, "var_idx", var_idx );
1045     
1046     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
1047
1048     for( vi = 0; vi < vcount; vi++ )
1049         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
1050
1051     cvEndWriteStruct( fs );
1052
1053     if( cat_count && (cat_var_count > 0 || is_classifier) )
1054     {
1055         CV_ASSERT( cat_count != 0 );
1056         cvWrite( fs, "cat_count", cat_count );
1057         cvWrite( fs, "cat_map", cat_map );
1058     }
1059
1060     __END__;
1061 }
1062
1063
1064 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1065 {
1066     CV_FUNCNAME( "CvDTreeTrainData::read_params" );
1067
1068     __BEGIN__;
1069     
1070     CvFileNode *tparams_node, *vartype_node;
1071     CvSeqReader reader;
1072     int vi, max_split_size, tree_block_size;
1073
1074     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1075     var_all = cvReadIntByName( fs, node, "var_all" );
1076     var_count = cvReadIntByName( fs, node, "var_count", var_all );
1077     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1078     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1079
1080     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1081
1082     if( tparams_node ) // training parameters are not necessary
1083     {
1084         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
1085
1086         if( is_classifier )
1087         {
1088             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
1089         }
1090         else
1091         {
1092             params.regression_accuracy =
1093                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
1094         }
1095
1096         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
1097         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
1098         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
1099     
1100         if( params.cv_folds > 1 )
1101         {
1102             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
1103             params.truncate_pruned_tree =
1104                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
1105         }
1106
1107         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
1108         if( priors )
1109         {
1110             if( !CV_IS_MAT(priors) )
1111                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
1112             priors_mult = cvCloneMat( priors );
1113         }
1114     }
1115
1116     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1117     if( var_idx )
1118     {
1119         if( !CV_IS_MAT(var_idx) ||
1120             var_idx->cols != 1 && var_idx->rows != 1 ||
1121             var_idx->cols + var_idx->rows - 1 != var_count ||
1122             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
1123             CV_ERROR( CV_StsParseError,
1124                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
1125
1126         for( vi = 0; vi < var_count; vi++ )
1127             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
1128                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
1129     }
1130     
1131     ////// read var type
1132     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
1133
1134     cat_var_count = 0;
1135     ord_var_count = -1;
1136     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1137
1138     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
1139         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
1140     else
1141     {
1142         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
1143             vartype_node->data.seq->total != var_count )
1144             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1145
1146         cvStartReadSeq( vartype_node->data.seq, &reader );
1147     
1148         for( vi = 0; vi < var_count; vi++ )
1149         {
1150             CvFileNode* n = (CvFileNode*)reader.ptr;
1151             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
1152                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1153             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
1154             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1155         }
1156     }
1157     var_type->data.i[var_count] = cat_var_count;
1158
1159     ord_var_count = ~ord_var_count;
1160     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
1161         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
1162     //////
1163
1164     if( cat_var_count > 0 || is_classifier )
1165     {
1166         int ccount, total_c_count = 0;
1167         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1168         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1169
1170         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
1171             cat_count->cols != 1 && cat_count->rows != 1 ||
1172             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
1173             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
1174             cat_map->cols != 1 && cat_map->rows != 1 ||
1175             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
1176             CV_ERROR( CV_StsParseError,
1177             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
1178
1179         ccount = cat_var_count + is_classifier;
1180
1181         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
1182         cat_ofs->data.i[0] = 0;
1183         max_c_count = 1;
1184
1185         for( vi = 0; vi < ccount; vi++ )
1186         {
1187             int val = cat_count->data.i[vi];
1188             if( val <= 0 )
1189                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
1190             max_c_count = MAX( max_c_count, val );
1191             cat_ofs->data.i[vi+1] = total_c_count += val;
1192         }
1193
1194         if( cat_map->cols + cat_map->rows - 1 != total_c_count )
1195             CV_ERROR( CV_StsBadSize,
1196             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
1197     }
1198
1199     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
1200         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
1201
1202     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
1203     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
1204     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
1205     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
1206             sizeof(CvDTreeNode), tree_storage ));
1207     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
1208             max_split_size, tree_storage ));
1209
1210     __END__;
1211 }
1212
1213
1214 /////////////////////// Decision Tree /////////////////////////
1215
1216 CvDTree::CvDTree()
1217 {
1218     data = 0;
1219     var_importance = 0;
1220     default_model_name = "my_tree";
1221
1222     clear();
1223 }
1224
1225
1226 void CvDTree::clear()
1227 {
1228     cvReleaseMat( &var_importance );
1229     if( data )
1230     {
1231         if( !data->shared )
1232             delete data;
1233         else
1234             free_tree();
1235         data = 0;
1236     }
1237     root = 0;
1238     pruned_tree_idx = -1;
1239 }
1240
1241
1242 CvDTree::~CvDTree()
1243 {
1244     clear();
1245 }
1246
1247
1248 const CvDTreeNode* CvDTree::get_root() const
1249 {
1250     return root;
1251 }
1252
1253
1254 int CvDTree::get_pruned_tree_idx() const
1255 {
1256     return pruned_tree_idx;
1257 }
1258
1259
1260 CvDTreeTrainData* CvDTree::get_data()
1261 {
1262     return data;
1263 }
1264
1265
1266 bool CvDTree::train( const CvMat* _train_data, int _tflag,
1267                      const CvMat* _responses, const CvMat* _var_idx,
1268                      const CvMat* _sample_idx, const CvMat* _var_type,
1269                      const CvMat* _missing_mask, CvDTreeParams _params )
1270 {
1271     bool result = false;
1272
1273     CV_FUNCNAME( "CvDTree::train" );
1274
1275     __BEGIN__;
1276
1277     clear();
1278     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1279                                  _var_idx, _sample_idx, _var_type,
1280                                  _missing_mask, _params, false );
1281     CV_CALL( result = do_train(0));
1282
1283     __END__;
1284
1285     return result;
1286 }
1287
1288
1289 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1290 {
1291     bool result = false;
1292
1293     CV_FUNCNAME( "CvDTree::train" );
1294
1295     __BEGIN__;
1296
1297     clear();
1298     data = _data;
1299     data->shared = true;
1300     CV_CALL( result = do_train(_subsample_idx));
1301
1302     __END__;
1303
1304     return result;
1305 }
1306
1307
1308 bool CvDTree::do_train( const CvMat* _subsample_idx )
1309 {
1310     bool result = false;
1311
1312     CV_FUNCNAME( "CvDTree::do_train" );
1313
1314     __BEGIN__;
1315
1316     root = data->subsample_data( _subsample_idx );
1317
1318     CV_CALL( try_split_node(root));
1319     
1320     if( data->params.cv_folds > 0 )
1321         CV_CALL( prune_cv());
1322
1323     if( !data->shared )
1324         data->free_train_data();
1325
1326     result = true;
1327
1328     __END__;
1329
1330     return result;
1331 }
1332
1333
1334 void CvDTree::try_split_node( CvDTreeNode* node )
1335 {
1336     CvDTreeSplit* best_split = 0;
1337     int i, n = node->sample_count, vi;
1338     bool can_split = true;
1339     double quality_scale;
1340
1341     calc_node_value( node );
1342
1343     if( node->sample_count <= data->params.min_sample_count ||
1344         node->depth >= data->params.max_depth )
1345         can_split = false;
1346
1347     if( can_split && data->is_classifier )
1348     {
1349         // check if we have a "pure" node,
1350         // we assume that cls_count is filled by calc_node_value()
1351         int* cls_count = data->counts->data.i;
1352         int nz = 0, m = data->get_num_classes();
1353         for( i = 0; i < m; i++ )
1354             nz += cls_count[i] != 0;
1355         if( nz == 1 ) // there is only one class
1356             can_split = false;
1357     }
1358     else if( can_split )
1359     {
1360         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
1361             can_split = false;
1362     }
1363
1364     if( can_split )
1365     {
1366         best_split = find_best_split(node);
1367         // TODO: check the split quality ...
1368         node->split = best_split;
1369     }
1370
1371     if( !can_split || !best_split )
1372     {
1373         data->free_node_data(node);
1374         return;
1375     }
1376
1377     quality_scale = calc_node_dir( node );
1378
1379     if( data->params.use_surrogates )
1380     {
1381         // find all the surrogate splits
1382         // and sort them by their similarity to the primary one
1383         for( vi = 0; vi < data->var_count; vi++ )
1384         {
1385             CvDTreeSplit* split;
1386             int ci = data->get_var_type(vi);
1387
1388             if( vi == best_split->var_idx )
1389                 continue;
1390
1391             if( ci >= 0 )
1392                 split = find_surrogate_split_cat( node, vi );
1393             else
1394                 split = find_surrogate_split_ord( node, vi );
1395
1396             if( split )
1397             {
1398                 // insert the split
1399                 CvDTreeSplit* prev_split = node->split;
1400                 split->quality = (float)(split->quality*quality_scale);
1401
1402                 while( prev_split->next &&
1403                        prev_split->next->quality > split->quality )
1404                     prev_split = prev_split->next;
1405                 split->next = prev_split->next;
1406                 prev_split->next = split;
1407             }
1408         }
1409     }
1410
1411     split_node_data( node );
1412     try_split_node( node->left );
1413     try_split_node( node->right );
1414 }
1415
1416
1417 // calculate direction (left(-1),right(1),missing(0))
1418 // for each sample using the best split
1419 // the function returns scale coefficients for surrogate split quality factors.
1420 // the scale is applied to normalize surrogate split quality relatively to the
1421 // best (primary) split quality. That is, if a surrogate split is absolutely
1422 // identical to the primary split, its quality will be set to the maximum value = 
1423 // quality of the primary split; otherwise, it will be lower.
1424 // besides, the function compute node->maxlr,
1425 // minimum possible quality (w/o considering the above mentioned scale)
1426 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1427 // are not discarded.
1428 double CvDTree::calc_node_dir( CvDTreeNode* node )
1429 {
1430     char* dir = (char*)data->direction->data.ptr;
1431     int i, n = node->sample_count, vi = node->split->var_idx;
1432     double L, R;
1433
1434     assert( !node->split->inversed );
1435
1436     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1437     {
1438         const int* labels = data->get_cat_var_data(node,vi);
1439         const int* subset = node->split->subset;
1440
1441         if( !data->have_priors )
1442         {
1443             int sum = 0, sum_abs = 0;
1444
1445             for( i = 0; i < n; i++ )
1446             {
1447                 int idx = labels[i];
1448                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1449                 sum += d; sum_abs += d & 1;
1450                 dir[i] = (char)d;
1451             }
1452
1453             R = (sum_abs + sum) >> 1;
1454             L = (sum_abs - sum) >> 1;
1455         }
1456         else
1457         {
1458             const int* responses = data->get_class_labels(node);
1459             const double* priors = data->priors_mult->data.db;
1460             double sum = 0, sum_abs = 0;
1461
1462             for( i = 0; i < n; i++ )
1463             {
1464                 int idx = labels[i];
1465                 double w = priors[responses[i]];
1466                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1467                 sum += d*w; sum_abs += (d & 1)*w;
1468                 dir[i] = (char)d;
1469             }
1470
1471             R = (sum_abs + sum) * 0.5;
1472             L = (sum_abs - sum) * 0.5;
1473         }
1474     }
1475     else // split on ordered var
1476     {
1477         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1478         int split_point = node->split->ord.split_point;
1479         int n1 = node->get_num_valid(vi);
1480
1481         assert( 0 <= split_point && split_point < n1-1 );
1482
1483         if( !data->have_priors )
1484         {
1485             for( i = 0; i <= split_point; i++ )
1486                 dir[sorted[i].i] = (char)-1;
1487             for( ; i < n1; i++ )
1488                 dir[sorted[i].i] = (char)1;
1489             for( ; i < n; i++ )
1490                 dir[sorted[i].i] = (char)0;
1491
1492             L = split_point-1;
1493             R = n1 - split_point + 1;
1494         }
1495         else
1496         {
1497             const int* responses = data->get_class_labels(node);
1498             const double* priors = data->priors_mult->data.db;
1499             L = R = 0;
1500
1501             for( i = 0; i <= split_point; i++ )
1502             {
1503                 int idx = sorted[i].i;
1504                 double w = priors[responses[idx]];
1505                 dir[idx] = (char)-1;
1506                 L += w;
1507             }
1508
1509             for( ; i < n1; i++ )
1510             {
1511                 int idx = sorted[i].i;
1512                 double w = priors[responses[idx]];
1513                 dir[idx] = (char)1;
1514                 R += w;
1515             }
1516
1517             for( ; i < n; i++ )
1518                 dir[sorted[i].i] = (char)0;
1519         }
1520     }
1521
1522     node->maxlr = MAX( L, R );
1523     return node->split->quality/(L + R);
1524 }
1525
1526
1527 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1528 {
1529     int vi;
1530     CvDTreeSplit *best_split = 0, *split = 0, *t;
1531
1532     for( vi = 0; vi < data->var_count; vi++ )
1533     {
1534         int ci = data->get_var_type(vi);
1535         if( node->get_num_valid(vi) <= 1 )
1536             continue;
1537
1538         if( data->is_classifier )
1539         {
1540             if( ci >= 0 )
1541                 split = find_split_cat_class( node, vi );
1542             else
1543                 split = find_split_ord_class( node, vi );
1544         }
1545         else
1546         {
1547             if( ci >= 0 )
1548                 split = find_split_cat_reg( node, vi );
1549             else
1550                 split = find_split_ord_reg( node, vi );
1551         }
1552
1553         if( split )
1554         {
1555             if( !best_split || best_split->quality < split->quality )
1556                 CV_SWAP( best_split, split, t );
1557             if( split )
1558                 cvSetRemoveByPtr( data->split_heap, split );
1559         }
1560     }
1561
1562     return best_split;
1563 }
1564
1565
1566 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
1567 {
1568     const float epsilon = FLT_EPSILON*2;
1569     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1570     const int* responses = data->get_class_labels(node);
1571     int n = node->sample_count;
1572     int n1 = node->get_num_valid(vi);
1573     int m = data->get_num_classes();
1574     const int* rc0 = data->counts->data.i;
1575     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
1576     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
1577     int i, best_i = -1;
1578     double lsum2 = 0, rsum2 = 0, best_val = 0;
1579     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1580
1581     // init arrays of class instance counters on both sides of the split
1582     for( i = 0; i < m; i++ )
1583     {
1584         lc[i] = 0;
1585         rc[i] = rc0[i];
1586     }
1587
1588     // compensate for missing values
1589     for( i = n1; i < n; i++ )
1590         rc[responses[sorted[i].i]]--;
1591
1592     if( !priors )
1593     {
1594         int L = 0, R = n1;
1595
1596         for( i = 0; i < m; i++ )
1597             rsum2 += (double)rc[i]*rc[i];
1598
1599         for( i = 0; i < n1 - 1; i++ )
1600         {
1601             int idx = responses[sorted[i].i];
1602             int lv, rv;
1603             L++; R--;
1604             lv = lc[idx]; rv = rc[idx];
1605             lsum2 += lv*2 + 1;
1606             rsum2 -= rv*2 - 1;
1607             lc[idx] = lv + 1; rc[idx] = rv - 1;
1608
1609             if( sorted[i].val + epsilon < sorted[i+1].val )
1610             {
1611                 double val = (lsum2*R + rsum2*L)/((double)L*R);
1612                 if( best_val < val )
1613                 {
1614                     best_val = val;
1615                     best_i = i;
1616                 }
1617             }
1618         }
1619     }
1620     else
1621     {
1622         double L = 0, R = 0;
1623         for( i = 0; i < m; i++ )
1624         {
1625             double wv = rc[i]*priors[i];
1626             R += wv;
1627             rsum2 += wv*wv;
1628         }
1629
1630         for( i = 0; i < n1 - 1; i++ )
1631         {
1632             int idx = responses[sorted[i].i];
1633             int lv, rv;
1634             double p = priors[idx], p2 = p*p;
1635             L += p; R -= p;
1636             lv = lc[idx]; rv = rc[idx];
1637             lsum2 += p2*(lv*2 + 1);
1638             rsum2 -= p2*(rv*2 - 1);
1639             lc[idx] = lv + 1; rc[idx] = rv - 1;
1640
1641             if( sorted[i].val + epsilon < sorted[i+1].val )
1642             {
1643                 double val = (lsum2*R + rsum2*L)/((double)L*R);
1644                 if( best_val < val )
1645                 {
1646                     best_val = val;
1647                     best_i = i;
1648                 }
1649             }
1650         }
1651     }
1652
1653     return best_i >= 0 ? data->new_split_ord( vi,
1654         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1655         0, (float)best_val ) : 0;
1656 }
1657
1658
1659 void CvDTree::cluster_categories( const int* vectors, int n, int m,
1660                                 int* csums, int k, int* labels )
1661 {
1662     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
1663     int iters = 0, max_iters = 100;
1664     int i, j, idx;
1665     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1666     double *v_weights = buf, *c_weights = buf + k;
1667     bool modified = true;
1668     CvRNG* r = &data->rng;
1669
1670     // assign labels randomly
1671     for( i = idx = 0; i < n; i++ )
1672     {
1673         int sum = 0;
1674         const int* v = vectors + i*m;
1675         labels[i] = idx++;
1676         idx &= idx < k ? -1 : 0;
1677
1678         // compute weight of each vector
1679         for( j = 0; j < m; j++ )
1680             sum += v[j];
1681         v_weights[i] = sum ? 1./sum : 0.;
1682     }
1683
1684     for( i = 0; i < n; i++ )
1685     {
1686         int i1 = cvRandInt(r) % n;
1687         int i2 = cvRandInt(r) % n;
1688         CV_SWAP( labels[i1], labels[i2], j );
1689     }
1690
1691     for( iters = 0; iters <= max_iters; iters++ )
1692     {
1693         // calculate csums
1694         for( i = 0; i < k; i++ )
1695         {
1696             for( j = 0; j < m; j++ )
1697                 csums[i*m + j] = 0;
1698         }
1699
1700         for( i = 0; i < n; i++ )
1701         {
1702             const int* v = vectors + i*m;
1703             int* s = csums + labels[i]*m;
1704             for( j = 0; j < m; j++ )
1705                 s[j] += v[j];
1706         }
1707
1708         // exit the loop here, when we have up-to-date csums
1709         if( iters == max_iters || !modified )
1710             break;
1711
1712         modified = false;
1713
1714         // calculate weight of each cluster
1715         for( i = 0; i < k; i++ )
1716         {
1717             const int* s = csums + i*m;
1718             int sum = 0;
1719             for( j = 0; j < m; j++ )
1720                 sum += s[j];
1721             c_weights[i] = sum ? 1./sum : 0;
1722         }
1723
1724         // now for each vector determine the closest cluster
1725         for( i = 0; i < n; i++ )
1726         {
1727             const int* v = vectors + i*m;
1728             double alpha = v_weights[i];
1729             double min_dist2 = DBL_MAX;
1730             int min_idx = -1;
1731
1732             for( idx = 0; idx < k; idx++ )
1733             {
1734                 const int* s = csums + idx*m;
1735                 double dist2 = 0., beta = c_weights[idx];
1736                 for( j = 0; j < m; j++ )
1737                 {
1738                     double t = v[j]*alpha - s[j]*beta;
1739                     dist2 += t*t;
1740                 }
1741                 if( min_dist2 > dist2 )
1742                 {
1743                     min_dist2 = dist2;
1744                     min_idx = idx;
1745                 }
1746             }
1747
1748             if( min_idx != labels[i] )
1749                 modified = true;
1750             labels[i] = min_idx;
1751         }
1752     }
1753 }
1754
1755
1756 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
1757 {
1758     CvDTreeSplit* split;
1759     const int* labels = data->get_cat_var_data(node, vi);
1760     const int* responses = data->get_class_labels(node);
1761     int ci = data->get_var_type(vi);
1762     int n = node->sample_count;
1763     int m = data->get_num_classes();
1764     int _mi = data->cat_count->data.i[ci], mi = _mi;
1765     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
1766     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
1767     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
1768     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1769     int* cluster_labels = 0;
1770     int** int_ptr = 0;
1771     int i, j, k, idx;
1772     double L = 0, R = 0;
1773     double best_val = 0;
1774     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1775     const double* priors = data->priors_mult->data.db;
1776
1777     // init array of counters:
1778     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
1779     for( j = -1; j < mi; j++ )
1780         for( k = 0; k < m; k++ )
1781             cjk[j*m + k] = 0;
1782
1783     for( i = 0; i < n; i++ )
1784     {
1785         j = labels[i];
1786         k = responses[i];
1787         cjk[j*m + k]++;
1788     }
1789
1790     if( m > 2 )
1791     {
1792         if( mi > data->params.max_categories )
1793         {
1794             mi = MIN(data->params.max_categories, n);
1795             cjk += _mi*m;
1796             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
1797             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
1798         }
1799         subset_i = 1;
1800         subset_n = 1 << mi;
1801     }
1802     else
1803     {
1804         assert( m == 2 );
1805         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1806         for( j = 0; j < mi; j++ )
1807             int_ptr[j] = cjk + j*2 + 1;
1808         icvSortIntPtr( int_ptr, mi, 0 );
1809         subset_i = 0;
1810         subset_n = mi;
1811     }
1812
1813     for( k = 0; k < m; k++ )
1814     {
1815         int sum = 0;
1816         for( j = 0; j < mi; j++ )
1817             sum += cjk[j*m + k];
1818         rc[k] = sum;
1819         lc[k] = 0;
1820     }
1821
1822     for( j = 0; j < mi; j++ )
1823     {
1824         double sum = 0;
1825         for( k = 0; k < m; k++ )
1826             sum += cjk[j*m + k]*priors[k];
1827         c_weights[j] = sum;
1828         R += c_weights[j];
1829     }
1830
1831     for( ; subset_i < subset_n; subset_i++ )
1832     {
1833         double weight;
1834         int* crow;
1835         double lsum2 = 0, rsum2 = 0;
1836
1837         if( m == 2 )
1838             idx = (int)(int_ptr[subset_i] - cjk)/2;
1839         else
1840         {
1841             int graycode = (subset_i>>1)^subset_i;
1842             int diff = graycode ^ prevcode;
1843
1844             // determine index of the changed bit.
1845             Cv32suf u;
1846             idx = diff >= (1 << 16) ? 16 : 0;
1847             u.f = (float)(((diff >> 16) | diff) & 65535);
1848             idx += (u.i >> 23) - 127;
1849             subtract = graycode < prevcode;
1850             prevcode = graycode;
1851         }
1852
1853         crow = cjk + idx*m;
1854         weight = c_weights[idx];
1855         if( weight < FLT_EPSILON )
1856             continue;
1857
1858         if( !subtract )
1859         {
1860             for( k = 0; k < m; k++ )
1861             {
1862                 int t = crow[k];
1863                 int lval = lc[k] + t;
1864                 int rval = rc[k] - t;
1865                 double p = priors[k], p2 = p*p;
1866                 lsum2 += p2*lval*lval;
1867                 rsum2 += p2*rval*rval;
1868                 lc[k] = lval; rc[k] = rval;
1869             }
1870             L += weight;
1871             R -= weight;
1872         }
1873         else
1874         {
1875             for( k = 0; k < m; k++ )
1876             {
1877                 int t = crow[k];
1878                 int lval = lc[k] - t;
1879                 int rval = rc[k] + t;
1880                 double p = priors[k], p2 = p*p;
1881                 lsum2 += p2*lval*lval;
1882                 rsum2 += p2*rval*rval;
1883                 lc[k] = lval; rc[k] = rval;
1884             }
1885             L -= weight;
1886             R += weight;
1887         }
1888
1889         if( L > FLT_EPSILON && R > FLT_EPSILON )
1890         {
1891             double val = (lsum2*R + rsum2*L)/((double)L*R);
1892             if( best_val < val )
1893             {
1894                 best_val = val;
1895                 best_subset = subset_i;
1896             }
1897         }
1898     }
1899
1900     if( best_subset < 0 )
1901         return 0;
1902
1903     split = data->new_split_cat( vi, (float)best_val );
1904
1905     if( m == 2 )
1906     {
1907         for( i = 0; i <= best_subset; i++ )
1908         {
1909             idx = (int)(int_ptr[i] - cjk) >> 1;
1910             split->subset[idx >> 5] |= 1 << (idx & 31);
1911         }
1912     }
1913     else
1914     {
1915         for( i = 0; i < _mi; i++ )
1916         {
1917             idx = cluster_labels ? cluster_labels[i] : i;
1918             if( best_subset & (1 << idx) )
1919                 split->subset[i >> 5] |= 1 << (i & 31);
1920         }
1921     }
1922
1923     return split;
1924 }
1925
1926
1927 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1928 {
1929     const float epsilon = FLT_EPSILON*2;
1930     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1931     const float* responses = data->get_ord_responses(node);
1932     int n = node->sample_count;
1933     int n1 = node->get_num_valid(vi);
1934     int i, best_i = -1;
1935     double best_val = 0, lsum = 0, rsum = node->value*n;
1936     int L = 0, R = n1;
1937
1938     // compensate for missing values
1939     for( i = n1; i < n; i++ )
1940         rsum -= responses[sorted[i].i];
1941
1942     // find the optimal split
1943     for( i = 0; i < n1 - 1; i++ )
1944     {
1945         float t = responses[sorted[i].i];
1946         L++; R--;
1947         lsum += t;
1948         rsum -= t;
1949
1950         if( sorted[i].val + epsilon < sorted[i+1].val )
1951         {
1952             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1953             if( best_val < val )
1954             {
1955                 best_val = val;
1956                 best_i = i;
1957             }
1958         }
1959     }
1960
1961     return best_i >= 0 ? data->new_split_ord( vi,
1962         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1963         0, (float)best_val ) : 0;
1964 }
1965
1966
1967 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1968 {
1969     CvDTreeSplit* split;
1970     const int* labels = data->get_cat_var_data(node, vi);
1971     const float* responses = data->get_ord_responses(node);
1972     int ci = data->get_var_type(vi);
1973     int n = node->sample_count;
1974     int mi = data->cat_count->data.i[ci];
1975     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1976     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1977     double** sum_ptr = 0;
1978     int i, L = 0, R = 0;
1979     double best_val = 0, lsum = 0, rsum = 0;
1980     int best_subset = -1, subset_i;
1981
1982     for( i = -1; i < mi; i++ )
1983         sum[i] = counts[i] = 0;
1984
1985     // calculate sum response and weight of each category of the input var
1986     for( i = 0; i < n; i++ )
1987     {
1988         int idx = labels[i];
1989         double s = sum[idx] + responses[i];
1990         int nc = counts[idx] + 1;
1991         sum[idx] = s;
1992         counts[idx] = nc;
1993     }
1994
1995     // calculate average response in each category
1996     for( i = 0; i < mi; i++ )
1997     {
1998         R += counts[i];
1999         rsum += sum[i];
2000         sum[i] /= MAX(counts[i],1);
2001         sum_ptr[i] = sum + i;
2002     }
2003
2004     icvSortDblPtr( sum_ptr, mi, 0 );
2005
2006     // revert back to unnormalized sums
2007     // (there should be a very little loss of accuracy)
2008     for( i = 0; i < mi; i++ )
2009         sum[i] *= counts[i];
2010
2011     for( subset_i = 0; subset_i < mi-1; subset_i++ )
2012     {
2013         int idx = (int)(sum_ptr[subset_i] - sum);
2014         int ni = counts[idx];
2015
2016         if( ni )
2017         {
2018             double s = sum[idx];
2019             lsum += s; L += ni;
2020             rsum -= s; R -= ni;
2021             
2022             if( L && R )
2023             {
2024                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2025                 if( best_val < val )
2026                 {
2027                     best_val = val;
2028                     best_subset = subset_i;
2029                 }
2030             }
2031         }
2032     }
2033
2034     if( best_subset < 0 )
2035         return 0;
2036
2037     split = data->new_split_cat( vi, (float)best_val );
2038     for( i = 0; i <= best_subset; i++ )
2039     {
2040         int idx = (int)(sum_ptr[i] - sum);
2041         split->subset[idx >> 5] |= 1 << (idx & 31);
2042     }
2043
2044     return split;
2045 }
2046
2047
2048 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
2049 {
2050     const float epsilon = FLT_EPSILON*2;
2051     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2052     const char* dir = (char*)data->direction->data.ptr;
2053     int n1 = node->get_num_valid(vi);
2054     // LL - number of samples that both the primary and the surrogate splits send to the left
2055     // LR - ... primary split sends to the left and the surrogate split sends to the right
2056     // RL - ... primary split sends to the right and the surrogate split sends to the left
2057     // RR - ... both send to the right
2058     int i, best_i = -1, best_inversed = 0;
2059     double best_val; 
2060
2061     if( !data->have_priors )
2062     {
2063         int LL = 0, RL = 0, LR, RR;
2064         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2065         int sum = 0, sum_abs = 0;
2066         
2067         for( i = 0; i < n1; i++ )
2068         {
2069             int d = dir[sorted[i].i];
2070             sum += d; sum_abs += d & 1;
2071         }
2072
2073         // sum_abs = R + L; sum = R - L
2074         RR = (sum_abs + sum) >> 1;
2075         LR = (sum_abs - sum) >> 1;
2076
2077         // initially all the samples are sent to the right by the surrogate split,
2078         // LR of them are sent to the left by primary split, and RR - to the right.
2079         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2080         for( i = 0; i < n1 - 1; i++ )
2081         {
2082             int d = dir[sorted[i].i];
2083
2084             if( d < 0 )
2085             {
2086                 LL++; LR--;
2087                 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
2088                 {
2089                     best_val = LL + RR;
2090                     best_i = i; best_inversed = 0;
2091                 }
2092             }
2093             else if( d > 0 )
2094             {
2095                 RL++; RR--;
2096                 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
2097                 {
2098                     best_val = RL + LR;
2099                     best_i = i; best_inversed = 1;
2100                 }
2101             }
2102         }
2103         best_val = _best_val;
2104     }
2105     else
2106     {
2107         double LL = 0, RL = 0, LR, RR;
2108         double worst_val = node->maxlr;
2109         double sum = 0, sum_abs = 0;
2110         const double* priors = data->priors_mult->data.db;
2111         const int* responses = data->get_class_labels(node);
2112         best_val = worst_val;
2113         
2114         for( i = 0; i < n1; i++ )
2115         {
2116             int idx = sorted[i].i;
2117             double w = priors[responses[idx]];
2118             int d = dir[idx];
2119             sum += d*w; sum_abs += (d & 1)*w;
2120         }
2121
2122         // sum_abs = R + L; sum = R - L
2123         RR = (sum_abs + sum)*0.5;
2124         LR = (sum_abs - sum)*0.5;
2125
2126         // initially all the samples are sent to the right by the surrogate split,
2127         // LR of them are sent to the left by primary split, and RR - to the right.
2128         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2129         for( i = 0; i < n1 - 1; i++ )
2130         {
2131             int idx = sorted[i].i;
2132             double w = priors[responses[idx]];
2133             int d = dir[idx];
2134
2135             if( d < 0 )
2136             {
2137                 LL += w; LR -= w;
2138                 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
2139                 {
2140                     best_val = LL + RR;
2141                     best_i = i; best_inversed = 0;
2142                 }
2143             }
2144             else if( d > 0 )
2145             {
2146                 RL += w; RR -= w;
2147                 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
2148                 {
2149                     best_val = RL + LR;
2150                     best_i = i; best_inversed = 1;
2151                 }
2152             }
2153         }
2154     }
2155
2156     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2157         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
2158         best_inversed, (float)best_val ) : 0;
2159 }
2160
2161
2162 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
2163 {
2164     const int* labels = data->get_cat_var_data(node, vi);
2165     const char* dir = (char*)data->direction->data.ptr;
2166     int n = node->sample_count;
2167     // LL - number of samples that both the primary and the surrogate splits send to the left
2168     // LR - ... primary split sends to the left and the surrogate split sends to the right
2169     // RL - ... primary split sends to the right and the surrogate split sends to the left
2170     // RR - ... both send to the right
2171     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2172     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2173     double best_val = 0;
2174     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
2175     double* rc = lc + mi + 1;
2176     
2177     for( i = -1; i < mi; i++ )
2178         lc[i] = rc[i] = 0;
2179
2180     // for each category calculate the weight of samples
2181     // sent to the left (lc) and to the right (rc) by the primary split
2182     if( !data->have_priors )
2183     {
2184         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
2185         int* _rc = _lc + mi + 1;
2186
2187         for( i = -1; i < mi; i++ )
2188             _lc[i] = _rc[i] = 0;
2189
2190         for( i = 0; i < n; i++ )
2191         {
2192             int idx = labels[i];
2193             int d = dir[i];
2194             int sum = _lc[idx] + d;
2195             int sum_abs = _rc[idx] + (d & 1);
2196             _lc[idx] = sum; _rc[idx] = sum_abs;
2197         }
2198
2199         for( i = 0; i < mi; i++ )
2200         {
2201             int sum = _lc[i];
2202             int sum_abs = _rc[i];
2203             lc[i] = (sum_abs - sum) >> 1;
2204             rc[i] = (sum_abs + sum) >> 1;
2205         }
2206     }
2207     else
2208     {
2209         const double* priors = data->priors_mult->data.db;
2210         const int* responses = data->get_class_labels(node);
2211
2212         for( i = 0; i < n; i++ )
2213         {
2214             int idx = labels[i];
2215             double w = priors[responses[i]];
2216             int d = dir[i];
2217             double sum = lc[idx] + d*w;
2218             double sum_abs = rc[idx] + (d & 1)*w;
2219             lc[idx] = sum; rc[idx] = sum_abs;
2220         }
2221
2222         for( i = 0; i < mi; i++ )
2223         {
2224             double sum = lc[i];
2225             double sum_abs = rc[i];
2226             lc[i] = (sum_abs - sum) * 0.5;
2227             rc[i] = (sum_abs + sum) * 0.5;
2228         }
2229     }
2230
2231     // 2. now form the split.
2232     // in each category send all the samples to the same direction as majority
2233     for( i = 0; i < mi; i++ )
2234     {
2235         double lval = lc[i], rval = rc[i];
2236         if( lval > rval )
2237         {
2238             split->subset[i >> 5] |= 1 << (i & 31);
2239             best_val += lval;
2240             l_win++;
2241         }
2242         else
2243             best_val += rval;
2244     }
2245
2246     split->quality = (float)best_val;
2247     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2248         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2249
2250     return split;
2251 }
2252
2253
2254 void CvDTree::calc_node_value( CvDTreeNode* node )
2255 {
2256     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2257     const int* cv_labels = data->get_labels(node);
2258
2259     if( data->is_classifier )
2260     {
2261         // in case of classification tree:
2262         //  * node value is the label of the class that has the largest weight in the node.
2263         //  * node risk is the weighted number of misclassified samples,
2264         //  * j-th cross-validation fold value and risk are calculated as above,
2265         //    but using the samples with cv_labels(*)!=j.
2266         //  * j-th cross-validation fold error is calculated as the weighted number of
2267         //    misclassified samples with cv_labels(*)==j.
2268
2269         // compute the number of instances of each class
2270         int* cls_count = data->counts->data.i;
2271         const int* responses = data->get_class_labels(node);
2272         int m = data->get_num_classes();
2273         int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
2274         double max_val = -1, total_weight = 0;
2275         int max_k = -1;
2276         double* priors = data->priors_mult->data.db;
2277
2278         for( k = 0; k < m; k++ )
2279             cls_count[k] = 0;
2280
2281         if( cv_n == 0 )
2282         {
2283             for( i = 0; i < n; i++ )
2284                 cls_count[responses[i]]++;
2285         }
2286         else
2287         {
2288             for( j = 0; j < cv_n; j++ )
2289                 for( k = 0; k < m; k++ )
2290                     cv_cls_count[j*m + k] = 0;
2291
2292             for( i = 0; i < n; i++ )
2293             {
2294                 j = cv_labels[i]; k = responses[i];
2295                 cv_cls_count[j*m + k]++;
2296             }
2297
2298             for( j = 0; j < cv_n; j++ )
2299                 for( k = 0; k < m; k++ )
2300                     cls_count[k] += cv_cls_count[j*m + k];
2301         }
2302
2303         if( data->have_priors && node->parent == 0 )
2304         {
2305             // compute priors_mult from priors, take the sample ratio into account.
2306             double sum = 0;
2307             for( k = 0; k < m; k++ )
2308             {
2309                 int n_k = cls_count[k];
2310                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2311                 sum += priors[k];
2312             }
2313             sum = 1./sum;
2314             for( k = 0; k < m; k++ )
2315                 priors[k] *= sum;
2316         }
2317
2318         for( k = 0; k < m; k++ )
2319         {
2320             double val = cls_count[k]*priors[k];
2321             total_weight += val;
2322             if( max_val < val )
2323             {
2324                 max_val = val;
2325                 max_k = k;
2326             }
2327         }
2328
2329         node->class_idx = max_k;
2330         node->value = data->cat_map->data.i[
2331             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2332         node->node_risk = total_weight - max_val;
2333
2334         for( j = 0; j < cv_n; j++ )
2335         {
2336             double sum_k = 0, sum = 0, max_val_k = 0;
2337             max_val = -1; max_k = -1;
2338
2339             for( k = 0; k < m; k++ )
2340             {
2341                 double w = priors[k];
2342                 double val_k = cv_cls_count[j*m + k]*w;
2343                 double val = cls_count[k]*w - val_k;
2344                 sum_k += val_k;
2345                 sum += val;
2346                 if( max_val < val )
2347                 {
2348                     max_val = val;
2349                     max_val_k = val_k;
2350                     max_k = k;
2351                 }
2352             }
2353
2354             node->cv_Tn[j] = INT_MAX;
2355             node->cv_node_risk[j] = sum - max_val;
2356             node->cv_node_error[j] = sum_k - max_val_k;
2357         }
2358     }
2359     else
2360     {
2361         // in case of regression tree:
2362         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2363         //    n is the number of samples in the node.
2364         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2365         //  * j-th cross-validation fold value and risk are calculated as above,
2366         //    but using the samples with cv_labels(*)!=j.
2367         //  * j-th cross-validation fold error is calculated
2368         //    using samples with cv_labels(*)==j as the test subset:
2369         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2370         //    where node_value_j is the node value calculated
2371         //    as described in the previous bullet, and summation is done
2372         //    over the samples with cv_labels(*)==j.
2373
2374         double sum = 0, sum2 = 0;
2375         const float* values = data->get_ord_responses(node);
2376         double *cv_sum = 0, *cv_sum2 = 0;
2377         int* cv_count = 0;
2378         
2379         if( cv_n == 0 )
2380         {
2381             for( i = 0; i < n; i++ )
2382             {
2383                 double t = values[i];
2384                 sum += t;
2385                 sum2 += t*t;
2386             }
2387         }
2388         else
2389         {
2390             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2391             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2392             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2393
2394             for( j = 0; j < cv_n; j++ )
2395             {
2396                 cv_sum[j] = cv_sum2[j] = 0.;
2397                 cv_count[j] = 0;
2398             }
2399
2400             for( i = 0; i < n; i++ )
2401             {
2402                 j = cv_labels[i];
2403                 double t = values[i];
2404                 double s = cv_sum[j] + t;
2405                 double s2 = cv_sum2[j] + t*t;
2406                 int nc = cv_count[j] + 1;
2407                 cv_sum[j] = s;
2408                 cv_sum2[j] = s2;
2409                 cv_count[j] = nc;
2410             }
2411
2412             for( j = 0; j < cv_n; j++ )
2413             {
2414                 sum += cv_sum[j];
2415                 sum2 += cv_sum2[j];
2416             }
2417         }
2418
2419         node->node_risk = sum2 - (sum/n)*sum;
2420         node->value = sum/n;
2421
2422         for( j = 0; j < cv_n; j++ )
2423         {
2424             double s = cv_sum[j], si = sum - s;
2425             double s2 = cv_sum2[j], s2i = sum2 - s2;
2426             int c = cv_count[j], ci = n - c;
2427             double r = si/MAX(ci,1);
2428             node->cv_node_risk[j] = s2i - r*r*ci;
2429             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2430             node->cv_Tn[j] = INT_MAX;
2431         }
2432     }
2433 }
2434
2435
2436 void CvDTree::complete_node_dir( CvDTreeNode* node )
2437 {
2438     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2439     int nz = n - node->get_num_valid(node->split->var_idx);
2440     char* dir = (char*)data->direction->data.ptr;
2441
2442     // try to complete direction using surrogate splits
2443     if( nz && data->params.use_surrogates )
2444     {
2445         CvDTreeSplit* split = node->split->next;
2446         for( ; split != 0 && nz; split = split->next )
2447         {
2448             int inversed_mask = split->inversed ? -1 : 0;
2449             vi = split->var_idx;
2450
2451             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2452             {
2453                 const int* labels = data->get_cat_var_data(node, vi);
2454                 const int* subset = split->subset;
2455
2456                 for( i = 0; i < n; i++ )
2457                 {
2458                     int idx;
2459                     if( !dir[i] && (idx = labels[i]) >= 0 )
2460                     {
2461                         int d = CV_DTREE_CAT_DIR(idx,subset);
2462                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2463                         if( --nz )
2464                             break;
2465                     }
2466                 }
2467             }
2468             else // split on ordered var
2469             {
2470                 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2471                 int split_point = split->ord.split_point;
2472                 int n1 = node->get_num_valid(vi);
2473
2474                 assert( 0 <= split_point && split_point < n-1 );
2475
2476                 for( i = 0; i < n1; i++ )
2477                 {
2478                     int idx = sorted[i].i;
2479                     if( !dir[idx] )
2480                     {
2481                         int d = i <= split_point ? -1 : 1;
2482                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2483                         if( --nz )
2484                             break;
2485                     }
2486                 }
2487             }
2488         }
2489     }
2490
2491     // find the default direction for the rest
2492     if( nz )
2493     {
2494         for( i = nr = 0; i < n; i++ )
2495             nr += dir[i] > 0;
2496         nl = n - nr - nz;
2497         d0 = nl > nr ? -1 : nr > nl;
2498     }
2499
2500     // make sure that every sample is directed either to the left or to the right
2501     for( i = 0; i < n; i++ )
2502     {
2503         int d = dir[i];
2504         if( !d )
2505         {
2506             d = d0;
2507             if( !d )
2508                 d = d1, d1 = -d1;
2509         }
2510         d = d > 0;
2511         dir[i] = (char)d; // remap (-1,1) to (0,1)
2512     }
2513 }
2514
2515
2516 void CvDTree::split_node_data( CvDTreeNode* node )
2517 {
2518     int vi, i, n = node->sample_count, nl, nr;
2519     char* dir = (char*)data->direction->data.ptr;
2520     CvDTreeNode *left = 0, *right = 0;
2521     int* new_idx = data->split_buf->data.i;
2522     int new_buf_idx = data->get_child_buf_idx( node );
2523     int work_var_count = data->get_work_var_count();
2524
2525     // speedup things a little, especially for tree ensembles with a lots of small trees:
2526     //   do not physically split the input data between the left and right child nodes
2527     //   when we are not going to split them further,
2528     //   as calc_node_value() does not requires input features anyway.
2529     bool split_input_data;
2530
2531     complete_node_dir(node);
2532
2533     for( i = nl = nr = 0; i < n; i++ )
2534     {
2535         int d = dir[i];
2536         // initialize new indices for splitting ordered variables
2537         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2538         nr += d;
2539         nl += d^1;
2540     }
2541
2542     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2543     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2544         (data->ord_var_count + work_var_count)*nl );
2545
2546     split_input_data = node->depth + 1 < data->params.max_depth &&
2547         (node->left->sample_count > data->params.min_sample_count ||
2548         node->right->sample_count > data->params.min_sample_count);
2549
2550     // split ordered variables, keep both halves sorted.
2551     for( vi = 0; vi < data->var_count; vi++ )
2552     {
2553         int ci = data->get_var_type(vi);
2554         int n1 = node->get_num_valid(vi);
2555         CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
2556         CvPair32s32f tl, tr;
2557
2558         if( ci >= 0 || !split_input_data )
2559             continue;
2560
2561         src = data->get_ord_var_data(node, vi);
2562         ldst0 = ldst = data->get_ord_var_data(left, vi);
2563         rdst0 = rdst = data->get_ord_var_data(right, vi);
2564         tl = ldst0[nl]; tr = rdst0[nr];
2565
2566         // split sorted
2567         for( i = 0; i < n1; i++ )
2568         {
2569             int idx = src[i].i;
2570             float val = src[i].val;
2571             int d = dir[idx];
2572             idx = new_idx[idx];
2573             ldst->i = rdst->i = idx;
2574             ldst->val = rdst->val = val;
2575             ldst += d^1;
2576             rdst += d;
2577         }
2578
2579         left->set_num_valid(vi, (int)(ldst - ldst0));
2580         right->set_num_valid(vi, (int)(rdst - rdst0));
2581
2582         // split missing
2583         for( ; i < n; i++ )
2584         {
2585             int idx = src[i].i;
2586             int d = dir[idx];
2587             idx = new_idx[idx];
2588             ldst->i = rdst->i = idx;
2589             ldst->val = rdst->val = ord_nan;
2590             ldst += d^1;
2591             rdst += d;
2592         }
2593
2594         ldst0[nl] = tl; rdst0[nr] = tr;
2595     }
2596
2597     // split categorical vars, responses and cv_labels using new_idx relocation table
2598     for( vi = 0; vi < work_var_count; vi++ )
2599     {
2600         int ci = data->get_var_type(vi);
2601         int n1 = node->get_num_valid(vi), nr1 = 0;
2602         int *src, *ldst0, *rdst0, *ldst, *rdst;
2603         int tl, tr;
2604
2605         if( ci < 0 || (vi < data->var_count && !split_input_data) )
2606             continue;
2607
2608         src = data->get_cat_var_data(node, vi);
2609         ldst0 = ldst = data->get_cat_var_data(left, vi);
2610         rdst0 = rdst = data->get_cat_var_data(right, vi);
2611         tl = ldst0[nl]; tr = rdst0[nr];
2612
2613         for( i = 0; i < n; i++ )
2614         {
2615             int d = dir[i];
2616             int val = src[i];
2617             *ldst = *rdst = val;
2618             ldst += d^1;
2619             rdst += d;
2620             nr1 += (val >= 0)&d;
2621         }
2622
2623         if( vi < data->var_count )
2624         {
2625             left->set_num_valid(vi, n1 - nr1);
2626             right->set_num_valid(vi, nr1);
2627         }
2628
2629         ldst0[nl] = tl; rdst0[nr] = tr;
2630     }
2631
2632     // deallocate the parent node data that is not needed anymore
2633     data->free_node_data(node);
2634 }
2635
2636
2637 void CvDTree::prune_cv()
2638 {
2639     CvMat* ab = 0;
2640     CvMat* temp = 0;
2641     CvMat* err_jk = 0;
2642     
2643     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2644     // 2. choose the best tree index (if need, apply 1SE rule).
2645     // 3. store the best index and cut the branches.
2646
2647     CV_FUNCNAME( "CvDTree::prune_cv" );
2648
2649     __BEGIN__;
2650
2651     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2652     // currently, 1SE for regression is not implemented
2653     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2654     double* err;
2655     double min_err = 0, min_err_se = 0;
2656     int min_idx = -1;
2657     
2658     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2659
2660     // build the main tree sequence, calculate alpha's
2661     for(;;tree_count++)
2662     {
2663         double min_alpha = update_tree_rnc(tree_count, -1);
2664         if( cut_tree(tree_count, -1, min_alpha) )
2665             break;
2666
2667         if( ab->cols <= tree_count )
2668         {
2669             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2670             for( ti = 0; ti < ab->cols; ti++ )
2671                 temp->data.db[ti] = ab->data.db[ti];
2672             cvReleaseMat( &ab );
2673             ab = temp;
2674             temp = 0;
2675         }
2676
2677         ab->data.db[tree_count] = min_alpha;
2678     }
2679
2680     ab->data.db[0] = 0.;
2681     
2682     if( tree_count > 0 )
2683     {
2684         for( ti = 1; ti < tree_count-1; ti++ )
2685             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2686         ab->data.db[tree_count-1] = DBL_MAX*0.5;
2687
2688         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2689         err = err_jk->data.db;
2690
2691         for( j = 0; j < cv_n; j++ )
2692         {
2693             int tj = 0, tk = 0;
2694             for( ; tk < tree_count; tj++ )
2695             {
2696                 double min_alpha = update_tree_rnc(tj, j);
2697                 if( cut_tree(tj, j, min_alpha) )
2698                     min_alpha = DBL_MAX;
2699
2700                 for( ; tk < tree_count; tk++ )
2701                 {
2702                     if( ab->data.db[tk] > min_alpha )
2703                         break;
2704                     err[j*tree_count + tk] = root->tree_error;
2705                 }
2706             }
2707         }
2708
2709         for( ti = 0; ti < tree_count; ti++ )
2710         {
2711             double sum_err = 0;
2712             for( j = 0; j < cv_n; j++ )
2713                 sum_err += err[j*tree_count + ti];
2714             if( ti == 0 || sum_err < min_err )
2715             {
2716                 min_err = sum_err;
2717                 min_idx = ti;
2718                 if( use_1se )
2719                     min_err_se = sqrt( sum_err*(n - sum_err) );
2720             }
2721             else if( sum_err < min_err + min_err_se )
2722                 min_idx = ti;
2723         }
2724     }
2725
2726     pruned_tree_idx = min_idx;
2727     free_prune_data(data->params.truncate_pruned_tree != 0);
2728
2729     __END__;
2730
2731     cvReleaseMat( &err_jk );
2732     cvReleaseMat( &ab );
2733     cvReleaseMat( &temp );
2734 }
2735
2736
2737 double CvDTree::update_tree_rnc( int T, int fold )
2738 {
2739     CvDTreeNode* node = root;
2740     double min_alpha = DBL_MAX;
2741     
2742     for(;;)
2743     {
2744         CvDTreeNode* parent;
2745         for(;;)
2746         {
2747             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2748             if( t <= T || !node->left )
2749             {
2750                 node->complexity = 1;
2751                 node->tree_risk = node->node_risk;
2752                 node->tree_error = 0.;
2753                 if( fold >= 0 )
2754                 {
2755                     node->tree_risk = node->cv_node_risk[fold];
2756                     node->tree_error = node->cv_node_error[fold];
2757                 }
2758                 break;
2759             }
2760             node = node->left;
2761         }
2762         
2763         for( parent = node->parent; parent && parent->right == node;
2764             node = parent, parent = parent->parent )
2765         {
2766             parent->complexity += node->complexity;
2767             parent->tree_risk += node->tree_risk;
2768             parent->tree_error += node->tree_error;
2769
2770             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2771                 - parent->tree_risk)/(parent->complexity - 1);
2772             min_alpha = MIN( min_alpha, parent->alpha );
2773         }
2774
2775         if( !parent )
2776             break;
2777
2778         parent->complexity = node->complexity;
2779         parent->tree_risk = node->tree_risk;
2780         parent->tree_error = node->tree_error;
2781         node = parent->right;
2782     }
2783
2784     return min_alpha;
2785 }
2786
2787
2788 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2789 {
2790     CvDTreeNode* node = root;
2791     if( !node->left )
2792         return 1;
2793
2794     for(;;)
2795     {
2796         CvDTreeNode* parent;
2797         for(;;)
2798         {
2799             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2800             if( t <= T || !node->left )
2801                 break;
2802             if( node->alpha <= min_alpha + FLT_EPSILON )
2803             {
2804                 if( fold >= 0 )
2805                     node->cv_Tn[fold] = T;
2806                 else
2807                     node->Tn = T;
2808                 if( node == root )
2809                     return 1;
2810                 break;
2811             }
2812             node = node->left;
2813         }
2814         
2815         for( parent = node->parent; parent && parent->right == node;
2816             node = parent, parent = parent->parent )
2817             ;
2818
2819         if( !parent )
2820             break;
2821
2822         node = parent->right;
2823     }
2824
2825     return 0;
2826 }
2827
2828
2829 void CvDTree::free_prune_data(bool cut_tree)
2830 {
2831     CvDTreeNode* node = root;
2832     
2833     for(;;)
2834     {
2835         CvDTreeNode* parent;
2836         for(;;)
2837         {
2838             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2839             // as we will clear the whole cross-validation heap at the end
2840             node->cv_Tn = 0;
2841             node->cv_node_error = node->cv_node_risk = 0;
2842             if( !node->left )
2843                 break;
2844             node = node->left;
2845         }
2846         
2847         for( parent = node->parent; parent && parent->right == node;
2848             node = parent, parent = parent->parent )
2849         {
2850             if( cut_tree && parent->Tn <= pruned_tree_idx )
2851             {
2852                 data->free_node( parent->left );
2853                 data->free_node( parent->right );
2854                 parent->left = parent->right = 0;
2855             }
2856         }
2857
2858         if( !parent )
2859             break;
2860
2861         node = parent->right;
2862     }
2863
2864     if( data->cv_heap )
2865         cvClearSet( data->cv_heap );
2866 }
2867
2868
2869 void CvDTree::free_tree()
2870 {
2871     if( root && data && data->shared )
2872     {
2873         pruned_tree_idx = INT_MIN;
2874         free_prune_data(true);
2875         data->free_node(root);
2876         root = 0;
2877     }
2878 }
2879
2880
2881 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2882     const CvMat* _missing, bool preprocessed_input ) const
2883 {
2884     CvDTreeNode* result = 0;
2885     int* catbuf = 0;
2886
2887     CV_FUNCNAME( "CvDTree::predict" );
2888
2889     __BEGIN__;
2890
2891     int i, step, mstep = 0;
2892     const float* sample;
2893     const uchar* m = 0;
2894     CvDTreeNode* node = root;
2895     const int* vtype;
2896     const int* vidx;
2897     const int* cmap;
2898     const int* cofs;
2899
2900     if( !node )
2901         CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2902
2903     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2904         _sample->cols != 1 && _sample->rows != 1 ||
2905         _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2906         _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2907             CV_ERROR( CV_StsBadArg,
2908         "the input sample must be 1d floating-point vector with the same "
2909         "number of elements as the total number of variables used for training" );
2910
2911     sample = _sample->data.fl;
2912     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
2913
2914     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2915     {
2916         int n = data->cat_count->cols;
2917         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2918         for( i = 0; i < n; i++ )
2919             catbuf[i] = -1;
2920     }
2921
2922     if( _missing )
2923     {
2924         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2925         !CV_ARE_SIZES_EQ(_missing, _sample) )
2926             CV_ERROR( CV_StsBadArg,
2927         "the missing data mask must be 8-bit vector of the same size as input sample" );
2928         m = _missing->data.ptr;
2929         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2930     }
2931
2932     vtype = data->var_type->data.i;
2933     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2934     cmap = data->cat_map ? data->cat_map->data.i : 0;
2935     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
2936
2937     while( node->Tn > pruned_tree_idx && node->left )
2938     {
2939         CvDTreeSplit* split = node->split;
2940         int dir = 0;
2941         for( ; !dir && split != 0; split = split->next )
2942         {
2943             int vi = split->var_idx;
2944             int ci = vtype[vi];
2945             i = vidx ? vidx[vi] : vi;
2946             float val = sample[i*step];
2947             if( m && m[i*mstep] )
2948                 continue;
2949             if( ci < 0 ) // ordered
2950                 dir = val <= split->ord.c ? -1 : 1;
2951             else // categorical
2952             {
2953                 int c;
2954                 if( preprocessed_input )
2955                     c = cvRound(val);
2956                 else
2957                 {
2958                     c = catbuf[ci];
2959                     if( c < 0 )
2960                     {
2961                         int a = c = cofs[ci];
2962                         int b = cofs[ci+1];
2963                         int ival = cvRound(val);
2964                         if( ival != val )
2965                             CV_ERROR( CV_StsBadArg,
2966                             "one of input categorical variable is not an integer" );
2967
2968                         while( a < b )
2969                         {
2970                             c = (a + b) >> 1;
2971                             if( ival < cmap[c] )
2972                                 b = c;
2973                             else if( ival > cmap[c] )
2974                                 a = c+1;
2975                             else
2976                                 break;
2977                         }
2978
2979                         if( c < 0 || ival != cmap[c] )
2980                             continue;
2981
2982                         catbuf[ci] = c -= cofs[ci];
2983                     }
2984                 }
2985                 dir = CV_DTREE_CAT_DIR(c, split->subset);
2986             }
2987
2988             if( split->inversed )
2989                 dir = -dir;
2990         }
2991
2992         if( !dir )
2993         {
2994             double diff = node->right->sample_count - node->left->sample_count;
2995             dir = diff < 0 ? -1 : 1;
2996         }
2997         node = dir < 0 ? node->left : node->right;
2998     }
2999
3000     result = node;
3001
3002     __END__;
3003
3004     return result;
3005 }
3006
3007
3008 const CvMat* CvDTree::get_var_importance()
3009 {
3010     if( !var_importance )
3011     {
3012         CvDTreeNode* node = root;
3013         double* importance;
3014         if( !node )
3015             return 0;
3016         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3017         cvZero( var_importance );
3018         importance = var_importance->data.db;
3019
3020         for(;;)
3021         {
3022             CvDTreeNode* parent;
3023             for( ;; node = node->left )
3024             {
3025                 CvDTreeSplit* split = node->split;
3026                 
3027                 if( !node->left || node->Tn <= pruned_tree_idx )
3028                     break;
3029                 
3030                 for( ; split != 0; split = split->next )
3031                     importance[split->var_idx] += split->quality;
3032             }
3033
3034             for( parent = node->parent; parent && parent->right == node;
3035                 node = parent, parent = parent->parent )
3036                 ;
3037
3038             if( !parent )
3039                 break;
3040
3041             node = parent->right;
3042         }
3043
3044         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3045     }
3046
3047     return var_importance;
3048 }
3049
3050
3051 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
3052 {
3053     int ci;
3054     
3055     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3056     cvWriteInt( fs, "var", split->var_idx );
3057     cvWriteReal( fs, "quality", split->quality );
3058
3059     ci = data->get_var_type(split->var_idx);
3060     if( ci >= 0 ) // split on a categorical var
3061     {
3062         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3063         for( i = 0; i < n; i++ )
3064             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3065
3066         // ad-hoc rule when to use inverse categorical split notation
3067         // to achieve more compact and clear representation
3068         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3069         
3070         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3071                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3072
3073         for( i = 0; i < n; i++ )
3074         {
3075             int dir = CV_DTREE_CAT_DIR(i,split->subset);
3076             if( dir*default_dir < 0 )
3077                 cvWriteInt( fs, 0, i );
3078         }
3079         cvEndWriteStruct( fs );
3080     }
3081     else
3082         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3083
3084     cvEndWriteStruct( fs );
3085 }
3086
3087
3088 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
3089 {
3090     CvDTreeSplit* split;
3091     
3092     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3093
3094     cvWriteInt( fs, "depth", node->depth );
3095     cvWriteInt( fs, "sample_count", node->sample_count );
3096     cvWriteReal( fs, "value", node->value );
3097     
3098     if( data->is_classifier )
3099         cvWriteInt( fs, "norm_class_idx", node->class_idx );
3100
3101     cvWriteInt( fs, "Tn", node->Tn );
3102     cvWriteInt( fs, "complexity", node->complexity );
3103     cvWriteReal( fs, "alpha", node->alpha );
3104     cvWriteReal( fs, "node_risk", node->node_risk );
3105     cvWriteReal( fs, "tree_risk", node->tree_risk );
3106     cvWriteReal( fs, "tree_error", node->tree_error );
3107
3108     if( node->left )
3109     {
3110         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3111
3112         for( split = node->split; split != 0; split = split->next )
3113             write_split( fs, split );
3114
3115         cvEndWriteStruct( fs );
3116     }
3117
3118     cvEndWriteStruct( fs );
3119 }
3120
3121
3122 void CvDTree::write_tree_nodes( CvFileStorage* fs )
3123 {
3124     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3125
3126     __BEGIN__;
3127
3128     CvDTreeNode* node = root;
3129
3130     // traverse the tree and save all the nodes in depth-first order
3131     for(;;)
3132     {
3133         CvDTreeNode* parent;
3134         for(;;)
3135         {
3136             write_node( fs, node );
3137             if( !node->left )
3138                 break;
3139             node = node->left;
3140         }
3141         
3142         for( parent = node->parent; parent && parent->right == node;
3143             node = parent, parent = parent->parent )
3144             ;
3145
3146         if( !parent )
3147             break;
3148
3149         node = parent->right;
3150     }
3151
3152     __END__;
3153 }
3154
3155
3156 void CvDTree::write( CvFileStorage* fs, const char* name )
3157 {
3158     //CV_FUNCNAME( "CvDTree::write" );
3159
3160     __BEGIN__;
3161
3162     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3163
3164     get_var_importance();
3165     data->write_params( fs );
3166     if( var_importance )
3167         cvWrite( fs, "var_importance", var_importance );
3168     write( fs );
3169
3170     cvEndWriteStruct( fs );
3171
3172     __END__;
3173 }
3174
3175
3176 void CvDTree::write( CvFileStorage* fs )
3177 {
3178     //CV_FUNCNAME( "CvDTree::write" );
3179
3180     __BEGIN__;
3181
3182     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3183
3184     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3185     write_tree_nodes( fs );
3186     cvEndWriteStruct( fs );
3187
3188     __END__;
3189 }
3190
3191
3192 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3193 {
3194     CvDTreeSplit* split = 0;
3195     
3196     CV_FUNCNAME( "CvDTree::read_split" );
3197
3198     __BEGIN__;
3199
3200     int vi, ci;
3201     
3202     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3203         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3204
3205     vi = cvReadIntByName( fs, fnode, "var", -1 );
3206     if( (unsigned)vi >= (unsigned)data->var_count )
3207         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3208
3209     ci = data->get_var_type(vi);
3210     if( ci >= 0 ) // split on categorical var
3211     {
3212         int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3213         CvSeqReader reader;
3214         CvFileNode* inseq;
3215         split = data->new_split_cat( vi, 0 );
3216         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3217         if( !inseq )
3218         {
3219             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3220             inversed = 1;
3221         }
3222         if( !inseq ||
3223             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3224             CV_ERROR( CV_StsParseError,
3225             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3226
3227         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3228         {
3229             val = inseq->data.i;
3230             if( (unsigned)val >= (unsigned)n )
3231                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3232
3233             split->subset[val >> 5] |= 1 << (val & 31);
3234         }
3235         else
3236         {
3237             cvStartReadSeq( inseq->data.seq, &reader );
3238
3239             for( i = 0; i < reader.seq->total; i++ )
3240             {
3241                 CvFileNode* inode = (CvFileNode*)reader.ptr;
3242                 val = inode->data.i;
3243                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3244                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3245
3246                 split->subset[val >> 5] |= 1 << (val & 31);
3247                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3248             }
3249         }
3250
3251         // for categorical splits we do not use inversed splits,
3252         // instead we inverse the variable set in the split
3253         if( inversed )
3254             for( i = 0; i < (n + 31) >> 5; i++ )
3255                 split->subset[i] ^= -1;
3256     }
3257     else
3258     {
3259         CvFileNode* cmp_node;
3260         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3261
3262         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3263         if( !cmp_node )
3264         {
3265             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3266             split->inversed = 1;
3267         }
3268
3269         split->ord.c = (float)cvReadReal( cmp_node );
3270     }
3271         
3272     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3273
3274     __END__;
3275     
3276     return split;
3277 }
3278
3279
3280 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3281 {
3282     CvDTreeNode* node = 0;
3283     
3284     CV_FUNCNAME( "CvDTree::read_node" );
3285
3286     __BEGIN__;
3287
3288     CvFileNode* splits;
3289     int i, depth;
3290
3291     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3292         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3293
3294     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3295     depth = cvReadIntByName( fs, fnode, "depth", -1 );
3296     if( depth != node->depth )
3297         CV_ERROR( CV_StsParseError, "incorrect node depth" );
3298
3299     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3300     node->value = cvReadRealByName( fs, fnode, "value" );
3301     if( data->is_classifier )
3302         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3303
3304     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3305     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3306     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3307     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3308     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3309     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3310
3311     splits = cvGetFileNodeByName( fs, fnode, "splits" );
3312     if( splits )
3313     {
3314         CvSeqReader reader;
3315         CvDTreeSplit* last_split = 0;
3316
3317         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3318             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3319
3320         cvStartReadSeq( splits->data.seq, &reader );
3321         for( i = 0; i < reader.seq->total; i++ )
3322         {
3323             CvDTreeSplit* split;
3324             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3325             if( !last_split )
3326                 node->split = last_split = split;
3327             else
3328                 last_split = last_split->next = split;
3329
3330             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3331         }
3332     }
3333
3334     __END__;
3335     
3336     return node;
3337 }
3338
3339
3340 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3341 {
3342     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3343
3344     __BEGIN__;
3345
3346     CvSeqReader reader;
3347     CvDTreeNode _root;
3348     CvDTreeNode* parent = &_root;
3349     int i;
3350     parent->left = parent->right = parent->parent = 0;
3351
3352     cvStartReadSeq( fnode->data.seq, &reader );
3353
3354     for( i = 0; i < reader.seq->total; i++ )
3355     {
3356         CvDTreeNode* node;
3357         
3358         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3359         if( !parent->left )
3360             parent->left = node;
3361         else
3362             parent->right = node;
3363         if( node->split )
3364             parent = node;
3365         else
3366         {
3367             while( parent && parent->right )
3368                 parent = parent->parent;
3369         }
3370
3371         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3372     }
3373
3374     root = _root.left;
3375
3376     __END__;
3377 }
3378
3379
3380 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3381 {
3382     CvDTreeTrainData* _data = new CvDTreeTrainData();
3383     _data->read_params( fs, fnode );
3384
3385     read( fs, fnode, _data );
3386     get_var_importance();
3387 }
3388
3389
3390 // a special entry point for reading weak decision trees from the tree ensembles
3391 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
3392 {
3393     CV_FUNCNAME( "CvDTree::read" );
3394
3395     __BEGIN__;
3396
3397     CvFileNode* tree_nodes;
3398
3399     clear();
3400     data = _data;
3401
3402     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
3403     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3404         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3405
3406     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
3407     read_tree_nodes( fs, tree_nodes );
3408
3409     __END__;
3410 }
3411
3412 /* End of file. */