Update to 2.0.0 tree from current Fremantle build
[opencv] / src / ml / mlrtrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 CvForestTree::CvForestTree()
44 {
45     forest = NULL;
46 }
47
48
49 CvForestTree::~CvForestTree()
50 {
51     clear();
52 }
53
54
55 bool CvForestTree::train( CvDTreeTrainData* _data,
56                           const CvMat* _subsample_idx,
57                           CvRTrees* _forest )
58 {
59     clear();
60     forest = _forest;
61
62     data = _data;
63     data->shared = true;
64     return do_train(_subsample_idx);
65 }
66
67
68 bool
69 CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
70                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
71 {
72     assert(0);
73     return false;
74 }
75
76
77 bool
78 CvForestTree::train( CvDTreeTrainData*, const CvMat* )
79 {
80     assert(0);
81     return false;
82 }
83
84
85 CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
86 {
87     int vi;
88
89     CvDTreeSplit *best_split = 0;
90
91     CvMat* active_var_mask = 0;
92     if( forest )
93     {
94         int var_count;
95         CvRNG* rng = forest->get_rng();
96
97         active_var_mask = forest->get_active_var_mask();
98         var_count = active_var_mask->cols;
99
100         CV_Assert( var_count == data->var_count );
101
102         for( vi = 0; vi < var_count; vi++ )
103         {
104             uchar temp;
105             int i1 = cvRandInt(rng) % var_count;
106             int i2 = cvRandInt(rng) % var_count;
107             CV_SWAP( active_var_mask->data.ptr[i1],
108                 active_var_mask->data.ptr[i2], temp );
109         }
110     }
111     int maxNumThreads = 1;
112 #ifdef _OPENMP
113     maxNumThreads = cv::getNumThreads();
114 #endif
115     vector<CvDTreeSplit*> splits(maxNumThreads);
116     vector<CvDTreeSplit*> bestSplits(maxNumThreads);
117     vector<int> canSplit(maxNumThreads);
118     CvDTreeSplit **splitsPtr = &splits[0], ** bestSplitsPtr = &bestSplits[0];
119     int* canSplitPtr = &canSplit[0];
120     for (int i = 0; i < maxNumThreads; i++)
121     {
122         splits[i] = data->new_split_cat( 0, -1.0f );
123         bestSplits[i] = data->new_split_cat( 0, -1.0f );
124         canSplitPtr[i] = 0;
125     }
126
127 #ifdef _OPENMP
128 #pragma omp parallel for num_threads(maxNumThreads) schedule(dynamic)
129 #endif
130     for( vi = 0; vi < data->var_count; vi++ )
131     {
132         CvDTreeSplit *res, *t;
133         int threadIdx = cv::getThreadNum();
134         int ci = data->var_type->data.i[vi];
135         if( node->num_valid[vi] <= 1
136             || (active_var_mask && !active_var_mask->data.ptr[vi]) )
137             continue;
138
139         if( data->is_classifier )
140         {
141             if( ci >= 0 )
142                 res = find_split_cat_class( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
143             else
144                 res = find_split_ord_class( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
145         }
146         else
147         {
148             if( ci >= 0 )
149                 res = find_split_cat_reg( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
150             else
151                 res = find_split_ord_reg( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
152         }
153
154         if( res )
155         {
156             canSplitPtr[threadIdx] = 1;
157             if( bestSplits[threadIdx]->quality < splits[threadIdx]->quality )
158                 CV_SWAP( bestSplits[threadIdx], splits[threadIdx], t );
159         }
160     }
161     int ti = 0;
162     for( ; ti < maxNumThreads; ti++ )
163     {
164         if( canSplitPtr[ti] )
165         {
166             best_split = bestSplitsPtr[ti];
167             break;
168         }
169     }
170     for( ; ti < maxNumThreads; ti++ )
171     {
172         if( best_split->quality < bestSplitsPtr[ti]->quality )
173             best_split = bestSplitsPtr[ti];
174     }
175     for(int i = 0; i < maxNumThreads; i++)
176     {
177         cvSetRemoveByPtr( data->split_heap, splits[i] );
178         if( bestSplits[i] != best_split )
179             cvSetRemoveByPtr( data->split_heap, bestSplits[i] );
180     }
181     return best_split;
182 }
183
184
185 void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
186 {
187     CvDTree::read( fs, fnode, _data );
188     forest = _forest;
189 }
190
191
192 void CvForestTree::read( CvFileStorage*, CvFileNode* )
193 {
194     assert(0);
195 }
196
197 void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
198                          CvDTreeTrainData* _data )
199 {
200     CvDTree::read( _fs, _node, _data );
201 }
202
203
204 //////////////////////////////////////////////////////////////////////////////////////////
205 //                                  Random trees                                        //
206 //////////////////////////////////////////////////////////////////////////////////////////
207
208 CvRTrees::CvRTrees()
209 {
210     nclasses         = 0;
211     oob_error        = 0;
212     ntrees           = 0;
213     trees            = NULL;
214     data             = NULL;
215     active_var_mask  = NULL;
216     var_importance   = NULL;
217     rng = cvRNG(0xffffffff);
218     default_model_name = "my_random_trees";
219 }
220
221
222 void CvRTrees::clear()
223 {
224     int k;
225     for( k = 0; k < ntrees; k++ )
226         delete trees[k];
227     cvFree( &trees );
228
229     delete data;
230     data = 0;
231
232     cvReleaseMat( &active_var_mask );
233     cvReleaseMat( &var_importance );
234     ntrees = 0;
235 }
236
237
238 CvRTrees::~CvRTrees()
239 {
240     clear();
241 }
242
243
244 CvMat* CvRTrees::get_active_var_mask()
245 {
246     return active_var_mask;
247 }
248
249
250 CvRNG* CvRTrees::get_rng()
251 {
252     return &rng;
253 }
254
255 bool CvRTrees::train( const CvMat* _train_data, int _tflag,
256                         const CvMat* _responses, const CvMat* _var_idx,
257                         const CvMat* _sample_idx, const CvMat* _var_type,
258                         const CvMat* _missing_mask, CvRTParams params )
259 {
260     clear();
261
262     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
263         params.regression_accuracy, params.use_surrogates, params.max_categories,
264         params.cv_folds, params.use_1se_rule, false, params.priors );
265
266     data = new CvDTreeTrainData();
267     data->set_data( _train_data, _tflag, _responses, _var_idx,
268         _sample_idx, _var_type, _missing_mask, tree_params, true);
269
270     int var_count = data->var_count;
271     if( params.nactive_vars > var_count )
272         params.nactive_vars = var_count;
273     else if( params.nactive_vars == 0 )
274         params.nactive_vars = (int)sqrt((double)var_count);
275     else if( params.nactive_vars < 0 )
276         CV_Error( CV_StsBadArg, "<nactive_vars> must be non-negative" );
277
278     // Create mask of active variables at the tree nodes
279     active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
280     if( params.calc_var_importance )
281     {
282         var_importance  = cvCreateMat( 1, var_count, CV_32FC1 );
283         cvZero(var_importance);
284     }
285     { // initialize active variables mask
286         CvMat submask1, submask2;
287         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
288         cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
289         cvSet( &submask1, cvScalar(1) );
290         cvZero( &submask2 );
291     }
292
293     return grow_forest( params.term_crit );
294 }
295
296 bool CvRTrees::train( CvMLData* data, CvRTParams params )
297 {
298     const CvMat* values = data->get_values();
299     const CvMat* response = data->get_responses();
300     const CvMat* missing = data->get_missing();
301     const CvMat* var_types = data->get_var_types();
302     const CvMat* train_sidx = data->get_train_sample_idx();
303     const CvMat* var_idx = data->get_var_idx();
304
305     return train( values, CV_ROW_SAMPLE, response, var_idx,
306                   train_sidx, var_types, missing, params );
307 }
308
309 bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
310 {
311     CvMat* sample_idx_mask_for_tree = 0;
312     CvMat* sample_idx_for_tree      = 0;
313
314     const int max_ntrees = term_crit.max_iter;
315     const double max_oob_err = term_crit.epsilon;
316
317     const int dims = data->var_count;
318     float maximal_response = 0;
319
320     CvMat* oob_sample_votes        = 0;
321     CvMat* oob_responses       = 0;
322
323     float* oob_samples_perm_ptr= 0;
324
325     float* samples_ptr     = 0;
326     uchar* missing_ptr     = 0;
327     float* true_resp_ptr   = 0;
328     bool is_oob_or_vimportance = (max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER) || var_importance;
329
330     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
331     // oob_num_of_predictions[i] = number of summands
332     //                            (number of predictions for the i-th sample)
333     // initialize these variable to avoid warning C4701
334     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
335     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
336      
337     nsamples = data->sample_count;
338     nclasses = data->get_num_classes();
339
340     if ( is_oob_or_vimportance )
341     {
342         if( data->is_classifier )
343         {
344             oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 );
345             cvZero(oob_sample_votes);
346         }
347         else
348         {
349             // oob_responses[0,i] = oob_predictions_sum[i]
350             //    = sum of predicted values for the i-th sample
351             // oob_responses[1,i] = oob_num_of_predictions[i]
352             //    = number of summands (number of predictions for the i-th sample)
353             oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 );
354             cvZero(oob_responses);
355             cvGetRow( oob_responses, &oob_predictions_sum, 0 );
356             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
357         }
358         
359         oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims );
360         samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims );
361         missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims );
362         true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples );            
363
364         data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr );
365         
366         double minval, maxval;
367         CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
368         cvMinMaxLoc( &responses, &minval, &maxval );
369         maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
370     }
371
372     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
373     memset( trees, 0, sizeof(trees[0])*max_ntrees );
374
375     sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 );
376     sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 );
377
378     ntrees = 0;
379     while( ntrees < max_ntrees )
380     {
381         int i, oob_samples_count = 0;
382         double ncorrect_responses = 0; // used for estimation of variable importance
383         CvForestTree* tree = 0;
384
385         cvZero( sample_idx_mask_for_tree );
386         for(i = 0; i < nsamples; i++ ) //form sample for creation one tree
387         {
388             int idx = cvRandInt( &rng ) % nsamples;
389             sample_idx_for_tree->data.i[i] = idx;
390             sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
391         }
392
393         trees[ntrees] = new CvForestTree();
394         tree = trees[ntrees];
395         tree->train( data, sample_idx_for_tree, this );
396
397         if ( is_oob_or_vimportance )
398         {
399             CvMat sample, missing;
400             // form array of OOB samples indices and get these samples
401             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
402             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
403
404             oob_error = 0;
405             for( i = 0; i < nsamples; i++,
406                 sample.data.fl += dims, missing.data.ptr += dims )
407             {
408                 CvDTreeNode* predicted_node = 0;
409                 // check if the sample is OOB
410                 if( sample_idx_mask_for_tree->data.ptr[i] )
411                     continue;
412
413                 // predict oob samples
414                 if( !predicted_node )
415                     predicted_node = tree->predict(&sample, &missing, true);
416
417                 if( !data->is_classifier ) //regression
418                 {
419                     double avg_resp, resp = predicted_node->value;
420                     oob_predictions_sum.data.fl[i] += (float)resp;
421                     oob_num_of_predictions.data.fl[i] += 1;
422
423                     // compute oob error
424                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
425                     avg_resp -= true_resp_ptr[i];
426                     oob_error += avg_resp*avg_resp;
427                     resp = (resp - true_resp_ptr[i])/maximal_response;
428                     ncorrect_responses += exp( -resp*resp );
429                 }
430                 else //classification
431                 {
432                     double prdct_resp;
433                     CvPoint max_loc;
434                     CvMat votes;
435
436                     cvGetRow(oob_sample_votes, &votes, i);
437                     votes.data.i[predicted_node->class_idx]++;
438
439                     // compute oob error
440                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
441
442                     prdct_resp = data->cat_map->data.i[max_loc.x];
443                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
444
445                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
446                 }
447                 oob_samples_count++;
448             }
449             if( oob_samples_count > 0 )
450                 oob_error /= (double)oob_samples_count;
451
452             // estimate variable importance
453             if( var_importance && oob_samples_count > 0 )
454             {
455                 int m;
456
457                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
458                 for( m = 0; m < dims; m++ )
459                 {
460                     double ncorrect_responses_permuted = 0;
461                     // randomly permute values of the m-th variable in the oob samples
462                     float* mth_var_ptr = oob_samples_perm_ptr + m;
463
464                     for( i = 0; i < nsamples; i++ )
465                     {
466                         int i1, i2;
467                         float temp;
468
469                         if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
470                             continue;
471                         i1 = cvRandInt( &rng ) % nsamples;
472                         i2 = cvRandInt( &rng ) % nsamples;
473                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
474
475                         // turn values of (m-1)-th variable, that were permuted
476                         // at the previous iteration, untouched
477                         if( m > 1 )
478                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
479                     }
480
481                     // predict "permuted" cases and calculate the number of votes for the
482                     // correct class in the variable-m-permuted oob data
483                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
484                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
485                     for( i = 0; i < nsamples; i++,
486                         sample.data.fl += dims, missing.data.ptr += dims )
487                     {
488                         double predct_resp, true_resp;
489
490                         if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
491                             continue;
492
493                         predct_resp = tree->predict(&sample, &missing, true)->value;
494                         true_resp   = true_resp_ptr[i];
495                         if( data->is_classifier )
496                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
497                         else
498                         {
499                             true_resp = (true_resp - predct_resp)/maximal_response;
500                             ncorrect_responses_permuted += exp( -true_resp*true_resp );
501                         }
502                     }
503                     var_importance->data.fl[m] += (float)(ncorrect_responses
504                         - ncorrect_responses_permuted);
505                 }
506             }
507         }
508         ntrees++;
509         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
510             break;
511     }
512
513     if( var_importance )
514     {
515         for ( int vi = 0; vi < var_importance->cols; vi++ )
516                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
517                     var_importance->data.fl[vi] : 0;
518         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
519     }
520
521     cvFree( &oob_samples_perm_ptr );
522     cvFree( &samples_ptr );
523     cvFree( &missing_ptr );
524     cvFree( &true_resp_ptr );
525     
526     cvReleaseMat( &sample_idx_mask_for_tree );
527     cvReleaseMat( &sample_idx_for_tree );
528
529     cvReleaseMat( &oob_sample_votes );
530     cvReleaseMat( &oob_responses );
531
532     return true;
533 }
534
535
536 const CvMat* CvRTrees::get_var_importance()
537 {
538     return var_importance;
539 }
540
541
542 float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
543                               const CvMat* missing1, const CvMat* missing2 ) const
544 {
545     float result = 0;
546
547     for( int i = 0; i < ntrees; i++ )
548         result += trees[i]->predict( sample1, missing1 ) ==
549         trees[i]->predict( sample2, missing2 ) ?  1 : 0;
550     result = result/(float)ntrees;
551
552     return result;
553 }
554
555 float CvRTrees::calc_error( CvMLData* _data, int type , vector<float> *resp )
556 {
557     float err = 0;
558     const CvMat* values = _data->get_values();
559     const CvMat* response = _data->get_responses();
560     const CvMat* missing = _data->get_missing();
561     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
562     const CvMat* var_types = _data->get_var_types();
563     int* sidx = sample_idx ? sample_idx->data.i : 0;
564     int r_step = CV_IS_MAT_CONT(response->type) ?
565                 1 : response->step / CV_ELEM_SIZE(response->type);
566     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
567     int sample_count = sample_idx ? sample_idx->cols : 0;
568     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
569     float* pred_resp = 0;
570     if( resp && (sample_count > 0) )
571     {
572         resp->resize( sample_count );
573         pred_resp = &((*resp)[0]);
574     }
575     if ( is_classifier )
576     {
577         for( int i = 0; i < sample_count; i++ )
578         {
579             CvMat sample, miss;
580             int si = sidx ? sidx[i] : i;
581             cvGetRow( values, &sample, si ); 
582             if( missing ) 
583                 cvGetRow( missing, &miss, si );             
584             float r = (float)predict( &sample, missing ? &miss : 0 );
585             if( pred_resp )
586                 pred_resp[i] = r;
587             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
588             err += d;
589         }
590         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
591     }
592     else
593     {
594         for( int i = 0; i < sample_count; i++ )
595         {
596             CvMat sample, miss;
597             int si = sidx ? sidx[i] : i;
598             cvGetRow( values, &sample, si );
599             if( missing ) 
600                 cvGetRow( missing, &miss, si );             
601             float r = (float)predict( &sample, missing ? &miss : 0 );
602             if( pred_resp )
603                 pred_resp[i] = r;
604             float d = r - response->data.fl[si*r_step];
605             err += d*d;
606         }
607         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
608     }
609     return err;
610 }
611
612 float CvRTrees::get_train_error()
613 {
614     float err = -1;
615
616     int sample_count = data->sample_count;
617     int var_count = data->var_count;
618
619     float *values_ptr = (float*)cvAlloc( sizeof(float)*sample_count*var_count );
620     uchar *missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*sample_count*var_count );
621     float *responses_ptr = (float*)cvAlloc( sizeof(float)*sample_count );
622
623     data->get_vectors( 0, values_ptr, missing_ptr, responses_ptr);
624     
625     if (data->is_classifier)
626     {
627         int err_count = 0;
628         float *vp = values_ptr;
629         uchar *mp = missing_ptr;    
630         for (int si = 0; si < sample_count; si++, vp += var_count, mp += var_count)
631         {
632             CvMat sample = cvMat( 1, var_count, CV_32FC1, vp );
633             CvMat missing = cvMat( 1, var_count, CV_8UC1,  mp );
634             float r = predict( &sample, &missing );
635             if (fabs(r - responses_ptr[si]) >= FLT_EPSILON)
636                 err_count++;
637         }
638         err = (float)err_count / (float)sample_count;
639     }
640     else
641         CV_Error( CV_StsBadArg, "This method is not supported for regression problems" );
642     
643     cvFree( &values_ptr );
644     cvFree( &missing_ptr );
645     cvFree( &responses_ptr ); 
646
647     return err;
648 }
649
650
651 float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
652 {
653     double result = -1;
654     int k;
655
656     if( nclasses > 0 ) //classification
657     {
658         int max_nvotes = 0;
659         int* votes = (int*)alloca( sizeof(int)*nclasses );
660         memset( votes, 0, sizeof(*votes)*nclasses );
661         for( k = 0; k < ntrees; k++ )
662         {
663             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
664             int nvotes;
665             int class_idx = predicted_node->class_idx;
666             CV_Assert( 0 <= class_idx && class_idx < nclasses );
667
668             nvotes = ++votes[class_idx];
669             if( nvotes > max_nvotes )
670             {
671                 max_nvotes = nvotes;
672                 result = predicted_node->value;
673             }
674         }
675     }
676     else // regression
677     {
678         result = 0;
679         for( k = 0; k < ntrees; k++ )
680             result += trees[k]->predict( sample, missing )->value;
681         result /= (double)ntrees;
682     }
683
684     return (float)result;
685 }
686
687 float CvRTrees::predict_prob( const CvMat* sample, const CvMat* missing) const
688 {
689     double result = -1;
690     int k;
691         
692         if( nclasses == 2 ) //classification
693     {
694         int max_nvotes = 0;
695         int* votes = (int*)alloca( sizeof(int)*nclasses );
696         memset( votes, 0, sizeof(*votes)*nclasses );
697         for( k = 0; k < ntrees; k++ )
698         {
699             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
700             int nvotes;
701             int class_idx = predicted_node->class_idx;
702             CV_Assert( 0 <= class_idx && class_idx < nclasses );
703                         
704             nvotes = ++votes[class_idx];
705             if( nvotes > max_nvotes )
706             {
707                 max_nvotes = nvotes;
708                 result = predicted_node->value;
709             }
710         }
711                 
712                 return float(votes[1])/ntrees;
713     }
714     else // regression
715                 CV_Error(CV_StsBadArg, "This function works for binary classification problems only...");
716         
717     return -1;
718 }
719
720 void CvRTrees::write( CvFileStorage* fs, const char* name ) const
721 {
722     int k;
723
724     if( ntrees < 1 || !trees || nsamples < 1 )
725         CV_Error( CV_StsBadArg, "Invalid CvRTrees object" );
726
727     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
728
729     cvWriteInt( fs, "nclasses", nclasses );
730     cvWriteInt( fs, "nsamples", nsamples );
731     cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
732     cvWriteReal( fs, "oob_error", oob_error );
733
734     if( var_importance )
735         cvWrite( fs, "var_importance", var_importance );
736
737     cvWriteInt( fs, "ntrees", ntrees );
738
739     data->write_params( fs );
740
741     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
742
743     for( k = 0; k < ntrees; k++ )
744     {
745         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
746         trees[k]->write( fs );
747         cvEndWriteStruct( fs );
748     }
749
750     cvEndWriteStruct( fs ); //trees
751     cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
752 }
753
754
755 void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
756 {
757     int nactive_vars, var_count, k;
758     CvSeqReader reader;
759     CvFileNode* trees_fnode = 0;
760
761     clear();
762
763     nclasses     = cvReadIntByName( fs, fnode, "nclasses", -1 );
764     nsamples     = cvReadIntByName( fs, fnode, "nsamples" );
765     nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
766     oob_error    = cvReadRealByName(fs, fnode, "oob_error", -1 );
767     ntrees       = cvReadIntByName( fs, fnode, "ntrees", -1 );
768
769     var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
770
771     if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
772         CV_Error( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
773         "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
774
775     rng = CvRNG( -1 );
776
777     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
778     memset( trees, 0, sizeof(trees[0])*ntrees );
779
780     data = new CvDTreeTrainData();
781     data->read_params( fs, fnode );
782     data->shared = true;
783
784     trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
785     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
786         CV_Error( CV_StsParseError, "<trees> tag is missing" );
787
788     cvStartReadSeq( trees_fnode->data.seq, &reader );
789     if( reader.seq->total != ntrees )
790         CV_Error( CV_StsParseError,
791         "<ntrees> is not equal to the number of trees saved in file" );
792
793     for( k = 0; k < ntrees; k++ )
794     {
795         trees[k] = new CvForestTree();
796         trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data );
797         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
798     }
799
800     var_count = data->var_count;
801     active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
802     {
803         // initialize active variables mask
804         CvMat submask1, submask2;
805         cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
806         cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
807         cvSet( &submask1, cvScalar(1) );
808         cvZero( &submask2 );
809     }
810 }
811
812
813 int CvRTrees::get_tree_count() const
814 {
815     return ntrees;
816 }
817
818 CvForestTree* CvRTrees::get_tree(int i) const
819 {
820     return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
821 }
822
823 using namespace cv;
824
825 bool CvRTrees::train( const Mat& _train_data, int _tflag,
826                      const Mat& _responses, const Mat& _var_idx,
827                      const Mat& _sample_idx, const Mat& _var_type,
828                      const Mat& _missing_mask, CvRTParams _params )
829 {
830     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
831     sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
832     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
833                  sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
834                  mmask.data.ptr ? &mmask : 0, _params);
835 }
836
837
838 float CvRTrees::predict( const Mat& _sample, const Mat& _missing ) const
839 {
840     CvMat sample = _sample, mmask = _missing;
841     return predict(&sample, mmask.data.ptr ? &mmask : 0);
842 }
843
844 float CvRTrees::predict_prob( const Mat& _sample, const Mat& _missing) const
845 {
846     CvMat sample = _sample, mmask = _missing;
847     return predict_prob(&sample, mmask.data.ptr ? &mmask : 0);
848 }
849
850
851 // End of file.