1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
// CvForestTree — a single decision tree inside a CvRTrees random forest.
// NOTE(review): this listing is a partial dump; the embedded line numbers
// jump (43, 49, 55, ...), so constructor/destructor bodies, braces and
// several statements are elided.  Code below is left byte-identical.
43 CvForestTree::CvForestTree()
49 CvForestTree::~CvForestTree()
// Trains this tree on a bootstrap subsample of the shared training data.
// The full parameter list (presumably including the owning forest) is not
// visible in this dump — TODO confirm against the complete source.
55 bool CvForestTree::train( CvDTreeTrainData* _data,
56 const CvMat* _subsample_idx,
// Delegates the actual tree growing to CvDTree::do_train on the subsample.
64 return do_train(_subsample_idx);
// Generic CvStatModel-style train overloads: presumably stubs, since a
// forest tree is only trained through the forest (bodies not visible —
// TODO confirm).
69 CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
70 const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
78 CvForestTree::train( CvDTreeTrainData*, const CvMat* )
// Finds the best split for a tree node while restricting the search to a
// random subset of variables (the "active" variables) — the random-subspace
// step of Breiman-style random forests.  Interior lines are elided in this
// dump; code left byte-identical.
85 CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
89 CvDTreeSplit *best_split = 0;
91 CvMat* active_var_mask = 0;
// RNG and active-variable mask are shared state owned by the forest.
95 CvRNG* rng = forest->get_rng();
97 active_var_mask = forest->get_active_var_mask();
98 var_count = active_var_mask->cols;
100 CV_Assert( var_count == data->var_count );
// Shuffle the mask by swapping random pairs of bytes, so each node sees a
// different random subset of the nactive_vars enabled variables.
102 for( vi = 0; vi < var_count; vi++ )
105 int i1 = cvRandInt(rng) % var_count;
106 int i2 = cvRandInt(rng) % var_count;
107 CV_SWAP( active_var_mask->data.ptr[i1],
108 active_var_mask->data.ptr[i2], temp );
// Per-thread scratch/best split buffers so the variable scan can run in
// parallel without locking.
111 int maxNumThreads = 1;
113 maxNumThreads = cv::getNumThreads();
115 vector<CvDTreeSplit*> splits(maxNumThreads);
116 vector<CvDTreeSplit*> bestSplits(maxNumThreads);
117 vector<int> canSplit(maxNumThreads);
118 CvDTreeSplit **splitsPtr = &splits[0], ** bestSplitsPtr = &bestSplits[0];
119 int* canSplitPtr = &canSplit[0];
120 for (int i = 0; i < maxNumThreads; i++)
122 splits[i] = data->new_split_cat( 0, -1.0f );
123 bestSplits[i] = data->new_split_cat( 0, -1.0f );
// Parallel scan over all variables; each thread tracks its own best split.
128 #pragma omp parallel for num_threads(maxNumThreads) schedule(dynamic)
130 for( vi = 0; vi < data->var_count; vi++ )
132 CvDTreeSplit *res, *t;
133 int threadIdx = cv::getThreadNum();
134 int ci = data->var_type->data.i[vi];
// Skip variables with <= 1 valid sample at this node and variables that
// the shuffled mask disabled.
135 if( node->num_valid[vi] <= 1
136 || (active_var_mask && !active_var_mask->data.ptr[vi]) )
// Dispatch on task (classification/regression) x variable type
// (categorical/ordered); the surrounding if/else scaffolding is elided.
139 if( data->is_classifier )
142 res = find_split_cat_class( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
144 res = find_split_ord_class( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
149 res = find_split_cat_reg( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
151 res = find_split_ord_reg( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
// Keep the higher-quality split per thread (swap best <-> scratch).
156 canSplitPtr[threadIdx] = 1;
157 if( bestSplits[threadIdx]->quality < splits[threadIdx]->quality )
158 CV_SWAP( bestSplits[threadIdx], splits[threadIdx], t );
// Reduction: first thread that found any split seeds best_split ...
162 for( ; ti < maxNumThreads; ti++ )
164 if( canSplitPtr[ti] )
166 best_split = bestSplitsPtr[ti];
// ... then take the maximum-quality split over the remaining threads.
170 for( ; ti < maxNumThreads; ti++ )
172 if( best_split->quality < bestSplitsPtr[ti]->quality )
173 best_split = bestSplitsPtr[ti];
// Return per-thread scratch splits to the shared split heap; only the
// winning best_split survives.
175 for(int i = 0; i < maxNumThreads; i++)
177 cvSetRemoveByPtr( data->split_heap, splits[i] );
178 if( bestSplits[i] != best_split )
179 cvSetRemoveByPtr( data->split_heap, bestSplits[i] );
// Deserializes a tree that belongs to a forest: the owning-forest pointer
// is presumably stored first (that statement is elided in this dump), then
// node reading is deferred to CvDTree::read.
185 void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
187 CvDTree::read( fs, fnode, _data );
// Overload without forest/shared data — body not visible; presumably an
// unsupported-operation stub.  TODO confirm.
192 void CvForestTree::read( CvFileStorage*, CvFileNode* )
// Overload taking only the shared training data; forwards to CvDTree::read.
197 void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
198 CvDTreeTrainData* _data )
200 CvDTree::read( _fs, _node, _data );
204 //////////////////////////////////////////////////////////////////////////////////////////
206 //////////////////////////////////////////////////////////////////////////////////////////
// CvRTrees constructor body (the signature line is elided in this dump):
// null out the owned matrices and seed the RNG with a fixed value, so
// training is deterministic across runs.
215 active_var_mask = NULL;
216 var_importance = NULL;
217 rng = cvRNG(0xffffffff);
218 default_model_name = "my_random_trees";
// Releases every tree (the loop body is elided) and the owned matrices.
222 void CvRTrees::clear()
225 for( k = 0; k < ntrees; k++ )
232 cvReleaseMat( &active_var_mask );
233 cvReleaseMat( &var_importance );
// Destructor — presumably just calls clear(); body not visible.
238 CvRTrees::~CvRTrees()
// Accessor: mask of variables currently enabled for node splits.
244 CvMat* CvRTrees::get_active_var_mask()
246 return active_var_mask;
// Accessor: the forest-wide RNG (return statement elided in this dump).
250 CvRNG* CvRTrees::get_rng()
// Trains the forest: builds the shared CvDTreeTrainData, normalizes
// params.nactive_vars, allocates the active-variable mask and (optionally)
// the variable-importance accumulator, then grows the trees.
// Returns whatever grow_forest() reports.
255 bool CvRTrees::train( const CvMat* _train_data, int _tflag,
256 const CvMat* _responses, const CvMat* _var_idx,
257 const CvMat* _sample_idx, const CvMat* _var_type,
258 const CvMat* _missing_mask, CvRTParams params )
// Per-tree parameters derived from CvRTParams; the 'false' argument
// disables one of the CvDTreeParams options (which one is not visible
// from the call site — see CvDTreeParams ctor ordering).
262 CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
263 params.regression_accuracy, params.use_surrogates, params.max_categories,
264 params.cv_folds, params.use_1se_rule, false, params.priors );
266 data = new CvDTreeTrainData();
267 data->set_data( _train_data, _tflag, _responses, _var_idx,
268 _sample_idx, _var_type, _missing_mask, tree_params, true);
// nactive_vars: clamp to var_count; 0 selects the sqrt(var_count)
// default; negative values are rejected.
270 int var_count = data->var_count;
271 if( params.nactive_vars > var_count )
272 params.nactive_vars = var_count;
273 else if( params.nactive_vars == 0 )
274 params.nactive_vars = (int)sqrt((double)var_count);
275 else if( params.nactive_vars < 0 )
276 CV_Error( CV_StsBadArg, "<nactive_vars> must be non-negative" );
278 // Create mask of active variables at the tree nodes
279 active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
280 if( params.calc_var_importance )
282 var_importance = cvCreateMat( 1, var_count, CV_32FC1 );
283 cvZero(var_importance);
// First nactive_vars entries set to 1; the zeroing of submask2 is
// elided in this dump — TODO confirm it is cvZero(&submask2).
285 { // initialize active variables mask
286 CvMat submask1, submask2;
287 cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
288 cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
289 cvSet( &submask1, cvScalar(1) );
293 return grow_forest( params.term_crit );
// Convenience overload: extracts all matrices from a CvMLData container
// and forwards to the main train() with row-ordered samples.
296 bool CvRTrees::train( CvMLData* data, CvRTParams params )
298 const CvMat* values = data->get_values();
299 const CvMat* response = data->get_responses();
300 const CvMat* missing = data->get_missing();
301 const CvMat* var_types = data->get_var_types();
302 const CvMat* train_sidx = data->get_train_sample_idx();
303 const CvMat* var_idx = data->get_var_idx();
305 return train( values, CV_ROW_SAMPLE, response, var_idx,
306 train_sidx, var_types, missing, params );
// Grows up to term_crit.max_iter trees from bootstrap samples, tracking the
// out-of-bag (OOB) error and, when requested, permutation-based variable
// importance.  Terminates early when the OOB error drops below
// term_crit.epsilon (unless the criterion is iteration-count only).
// Interior lines (braces, some declarations, increments of
// oob_samples_count, the loop break) are elided in this dump.
309 bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
311 CvMat* sample_idx_mask_for_tree = 0;
312 CvMat* sample_idx_for_tree = 0;
314 const int max_ntrees = term_crit.max_iter;
315 const double max_oob_err = term_crit.epsilon;
317 const int dims = data->var_count;
318 float maximal_response = 0;
320 CvMat* oob_sample_votes = 0;
321 CvMat* oob_responses = 0;
323 float* oob_samples_perm_ptr= 0;
325 float* samples_ptr = 0;
326 uchar* missing_ptr = 0;
327 float* true_resp_ptr = 0;
// NOTE(review): `&&` binds tighter than `||`, so this parses as
// ((max_oob_err > 0 && type != CV_TERMCRIT_ITER) || var_importance).
// That appears to be the intent (importance always needs OOB bookkeeping),
// but explicit parentheses would silence -Wparentheses and remove doubt.
328 bool is_oob_or_vimportance = (max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER) || var_importance;
330 // oob_predictions_sum[i] = sum of predicted values for the i-th sample
331 // oob_num_of_predictions[i] = number of summands
332 // (number of predictions for the i-th sample)
333 // initialize these variable to avoid warning C4701
334 CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
335 CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
337 nsamples = data->sample_count;
338 nclasses = data->get_num_classes();
// OOB bookkeeping buffers: per-class vote counts for classification,
// running sum/count of predictions for regression.
340 if ( is_oob_or_vimportance )
342 if( data->is_classifier )
344 oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 );
345 cvZero(oob_sample_votes);
349 // oob_responses[0,i] = oob_predictions_sum[i]
350 // = sum of predicted values for the i-th sample
351 // oob_responses[1,i] = oob_num_of_predictions[i]
352 // = number of summands (number of predictions for the i-th sample)
353 oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 );
354 cvZero(oob_responses);
355 cvGetRow( oob_responses, &oob_predictions_sum, 0 );
356 cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
// Dense copies of samples/missing/responses for fast per-row prediction
// and for the permutation buffer used by variable importance.
359 oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims );
360 samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims );
361 missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims );
362 true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples );
364 data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr );
// Largest absolute response, used to scale regression residuals into a
// comparable range for the exp(-r^2) "correctness" score below.
366 double minval, maxval;
367 CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
368 cvMinMaxLoc( &responses, &minval, &maxval );
369 maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
372 trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
373 memset( trees, 0, sizeof(trees[0])*max_ntrees );
375 sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 );
376 sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 );
// Main loop: one tree per iteration until max_ntrees or early stop.
379 while( ntrees < max_ntrees )
381 int i, oob_samples_count = 0;
382 double ncorrect_responses = 0; // used for estimation of variable importance
383 CvForestTree* tree = 0;
// Bootstrap: sample nsamples indices with replacement; the 0xFF mask
// marks in-bag samples (unmarked entries are the OOB set).
385 cvZero( sample_idx_mask_for_tree );
386 for(i = 0; i < nsamples; i++ ) //form sample for creation one tree
388 int idx = cvRandInt( &rng ) % nsamples;
389 sample_idx_for_tree->data.i[i] = idx;
390 sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
393 trees[ntrees] = new CvForestTree();
394 tree = trees[ntrees];
395 tree->train( data, sample_idx_for_tree, this );
// OOB evaluation of the freshly trained tree.
397 if ( is_oob_or_vimportance )
399 CvMat sample, missing;
400 // form array of OOB samples indices and get these samples
401 sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
402 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
405 for( i = 0; i < nsamples; i++,
406 sample.data.fl += dims, missing.data.ptr += dims )
408 CvDTreeNode* predicted_node = 0;
409 // check if the sample is OOB
// In-bag samples (mask set) are presumably skipped here via an elided
// `continue` — only OOB samples contribute below.
410 if( sample_idx_mask_for_tree->data.ptr[i] )
413 // predict oob samples
414 if( !predicted_node )
415 predicted_node = tree->predict(&sample, &missing, true);
417 if( !data->is_classifier ) //regression
// Regression OOB error: mean squared error of the per-sample running
// average prediction vs. the true response.
419 double avg_resp, resp = predicted_node->value;
420 oob_predictions_sum.data.fl[i] += (float)resp;
421 oob_num_of_predictions.data.fl[i] += 1;
424 avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
425 avg_resp -= true_resp_ptr[i];
426 oob_error += avg_resp*avg_resp;
427 resp = (resp - true_resp_ptr[i])/maximal_response;
428 ncorrect_responses += exp( -resp*resp );
430 else //classification
// Classification OOB error: misclassification rate of the running
// majority vote (max_loc.x = winning class after cvMinMaxLoc).
436 cvGetRow(oob_sample_votes, &votes, i);
437 votes.data.i[predicted_node->class_idx]++;
440 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
442 prdct_resp = data->cat_map->data.i[max_loc.x];
443 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
445 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
449 if( oob_samples_count > 0 )
450 oob_error /= (double)oob_samples_count;
452 // estimate variable importance
// Permutation importance: for each variable m, shuffle its values among
// OOB samples, re-predict, and accumulate the drop in "correctness".
453 if( var_importance && oob_samples_count > 0 )
457 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
458 for( m = 0; m < dims; m++ )
460 double ncorrect_responses_permuted = 0;
461 // randomly permute values of the m-th variable in the oob samples
462 float* mth_var_ptr = oob_samples_perm_ptr + m;
464 for( i = 0; i < nsamples; i++ )
// NOTE(review): the comment says "not OOB" while the same mask test means
// "in-bag" above — the elided control flow around these lines decides
// which rows actually get swapped; confirm against the full source.
469 if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
471 i1 = cvRandInt( &rng ) % nsamples;
472 i2 = cvRandInt( &rng ) % nsamples;
473 CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
475 // turn values of (m-1)-th variable, that were permuted
476 // at the previous iteration, untouched
478 oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
481 // predict "permuted" cases and calculate the number of votes for the
482 // correct class in the variable-m-permuted oob data
483 sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
484 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
485 for( i = 0; i < nsamples; i++,
486 sample.data.fl += dims, missing.data.ptr += dims )
488 double predct_resp, true_resp;
490 if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
493 predct_resp = tree->predict(&sample, &missing, true)->value;
494 true_resp = true_resp_ptr[i];
495 if( data->is_classifier )
496 ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
499 true_resp = (true_resp - predct_resp)/maximal_response;
500 ncorrect_responses_permuted += exp( -true_resp*true_resp );
// Importance of variable m grows with the loss of accuracy caused by
// permuting it (accumulated across all trees, normalized at the end).
503 var_importance->data.fl[m] += (float)(ncorrect_responses
504 - ncorrect_responses_permuted);
// Early stop on OOB error when the criterion is not iteration-only
// (the `break` is elided in this dump).
509 if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
// Final importance post-processing: clip negatives to 0, L1-normalize.
515 for ( int vi = 0; vi < var_importance->cols; vi++ )
516 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
517 var_importance->data.fl[vi] : 0;
518 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
// Cleanup of all temporary buffers (cvFree/cvReleaseMat tolerate NULL).
521 cvFree( &oob_samples_perm_ptr );
522 cvFree( &samples_ptr );
523 cvFree( &missing_ptr );
524 cvFree( &true_resp_ptr );
526 cvReleaseMat( &sample_idx_mask_for_tree );
527 cvReleaseMat( &sample_idx_for_tree );
529 cvReleaseMat( &oob_sample_votes );
530 cvReleaseMat( &oob_responses );
// Accessor: accumulated (clipped, L1-normalized) variable importance,
// or NULL when importance was not requested at train time.
536 const CvMat* CvRTrees::get_var_importance()
538 return var_importance;
// Proximity of two samples: fraction of trees in which both samples land
// in a leaf with the same predicted node (predict() equality), in [0, 1].
542 float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
543 const CvMat* missing1, const CvMat* missing2 ) const
547 for( int i = 0; i < ntrees; i++ )
548 result += trees[i]->predict( sample1, missing1 ) ==
549 trees[i]->predict( sample2, missing2 ) ? 1 : 0;
550 result = result/(float)ntrees;
// Computes train/test error on a CvMLData set: misclassification
// percentage for classification, mean squared error for regression.
// Optionally fills *resp with the per-sample predictions.
// Interior lines (CvMat sample/miss declarations, err accumulation,
// pred_resp stores) are elided in this dump.
555 float CvRTrees::calc_error( CvMLData* _data, int type , vector<float> *resp )
558 const CvMat* values = _data->get_values();
559 const CvMat* response = _data->get_responses();
560 const CvMat* missing = _data->get_missing();
// type selects which index set to evaluate: test vs. train samples.
561 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
562 const CvMat* var_types = _data->get_var_types();
563 int* sidx = sample_idx ? sample_idx->data.i : 0;
// Step between consecutive responses, in elements (handles non-contiguous
// response matrices).
564 int r_step = CV_IS_MAT_CONT(response->type) ?
565 1 : response->step / CV_ELEM_SIZE(response->type);
// The last var_types entry describes the response: categorical =>
// classification problem.
566 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
567 int sample_count = sample_idx ? sample_idx->cols : 0;
568 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
569 float* pred_resp = 0;
570 if( resp && (sample_count > 0) )
572 resp->resize( sample_count );
573 pred_resp = &((*resp)[0]);
// Classification branch: count mismatches (d = 1 on error), report as %.
577 for( int i = 0; i < sample_count; i++ )
580 int si = sidx ? sidx[i] : i;
581 cvGetRow( values, &sample, si );
583 cvGetRow( missing, &miss, si );
// Missing-value row is only passed when a missing mask exists.
584 float r = (float)predict( &sample, missing ? &miss : 0 );
587 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
590 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
// Regression branch: accumulate squared residuals, report the mean.
594 for( int i = 0; i < sample_count; i++ )
597 int si = sidx ? sidx[i] : i;
598 cvGetRow( values, &sample, si );
600 cvGetRow( missing, &miss, si );
601 float r = (float)predict( &sample, missing ? &miss : 0 );
604 float d = r - response->data.fl[si*r_step];
607 err = sample_count ? err / (float)sample_count : -FLT_MAX;
// Misclassification rate on the full training set (classification only;
// regression raises CV_StsBadArg).  Pulls dense copies of the training
// vectors, re-predicts every sample, and counts mismatches.
612 float CvRTrees::get_train_error()
616 int sample_count = data->sample_count;
617 int var_count = data->var_count;
619 float *values_ptr = (float*)cvAlloc( sizeof(float)*sample_count*var_count );
620 uchar *missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*sample_count*var_count );
621 float *responses_ptr = (float*)cvAlloc( sizeof(float)*sample_count );
623 data->get_vectors( 0, values_ptr, missing_ptr, responses_ptr);
625 if (data->is_classifier)
// Walk sample-by-sample through the dense buffers (row stride var_count).
628 float *vp = values_ptr;
629 uchar *mp = missing_ptr;
630 for (int si = 0; si < sample_count; si++, vp += var_count, mp += var_count)
632 CvMat sample = cvMat( 1, var_count, CV_32FC1, vp );
633 CvMat missing = cvMat( 1, var_count, CV_8UC1, mp );
634 float r = predict( &sample, &missing );
// FLT_EPSILON comparison: class labels are stored as floats.
635 if (fabs(r - responses_ptr[si]) >= FLT_EPSILON)
638 err = (float)err_count / (float)sample_count;
641 CV_Error( CV_StsBadArg, "This method is not supported for regression problems" );
643 cvFree( &values_ptr );
644 cvFree( &missing_ptr );
645 cvFree( &responses_ptr );
// Forest prediction: majority vote over trees for classification
// (nclasses > 0), mean of tree outputs for regression.
651 float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
656 if( nclasses > 0 ) //classification
// NOTE(review): alloca() puts the vote array on the stack, sized by
// nclasses at runtime — fine for typical class counts, but unchecked.
659 int* votes = (int*)alloca( sizeof(int)*nclasses );
660 memset( votes, 0, sizeof(*votes)*nclasses );
661 for( k = 0; k < ntrees; k++ )
663 CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
665 int class_idx = predicted_node->class_idx;
666 CV_Assert( 0 <= class_idx && class_idx < nclasses );
// Track the running plurality winner; ties keep the earlier leader.
668 nvotes = ++votes[class_idx];
669 if( nvotes > max_nvotes )
672 result = predicted_node->value;
// Regression: average the per-tree predicted values.
679 for( k = 0; k < ntrees; k++ )
680 result += trees[k]->predict( sample, missing )->value;
681 result /= (double)ntrees;
684 return (float)result;
// Probability-style output for *binary* classification only: the fraction
// of trees voting for class index 1.  Any other problem type raises
// CV_StsBadArg.
687 float CvRTrees::predict_prob( const CvMat* sample, const CvMat* missing) const
692 if( nclasses == 2 ) //classification
// Same alloca-backed voting as predict(); see the note there.
695 int* votes = (int*)alloca( sizeof(int)*nclasses );
696 memset( votes, 0, sizeof(*votes)*nclasses );
697 for( k = 0; k < ntrees; k++ )
699 CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
701 int class_idx = predicted_node->class_idx;
702 CV_Assert( 0 <= class_idx && class_idx < nclasses );
704 nvotes = ++votes[class_idx];
705 if( nvotes > max_nvotes )
708 result = predicted_node->value;
// Fraction of trees that voted for class 1.
712 return float(votes[1])/ntrees;
715 CV_Error(CV_StsBadArg, "This function works for binary classification problems only...");
// Serializes the trained forest to file storage: scalar metadata, optional
// variable importance, shared training parameters, then each tree as a
// map node inside a "trees" sequence.
720 void CvRTrees::write( CvFileStorage* fs, const char* name ) const
724 if( ntrees < 1 || !trees || nsamples < 1 )
725 CV_Error( CV_StsBadArg, "Invalid CvRTrees object" );
727 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
729 cvWriteInt( fs, "nclasses", nclasses );
730 cvWriteInt( fs, "nsamples", nsamples );
// nactive_vars is recovered from the mask (count of 1-entries) rather
// than stored separately.
731 cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
732 cvWriteReal( fs, "oob_error", oob_error );
735 cvWrite( fs, "var_importance", var_importance );
737 cvWriteInt( fs, "ntrees", ntrees );
739 data->write_params( fs );
741 cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
743 for( k = 0; k < ntrees; k++ )
745 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
746 trees[k]->write( fs );
747 cvEndWriteStruct( fs );
750 cvEndWriteStruct( fs ); //trees
751 cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
// Deserializes a forest previously written by write(): validates the
// scalar metadata, reads shared training params, reconstructs each tree
// from the "trees" sequence, and rebuilds the active-variable mask.
755 void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
757 int nactive_vars, var_count, k;
759 CvFileNode* trees_fnode = 0;
// Defaults of -1 let the sanity check below detect missing tags.
763 nclasses = cvReadIntByName( fs, fnode, "nclasses", -1 );
764 nsamples = cvReadIntByName( fs, fnode, "nsamples" );
765 nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
766 oob_error = cvReadRealByName(fs, fnode, "oob_error", -1 );
767 ntrees = cvReadIntByName( fs, fnode, "ntrees", -1 );
769 var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
771 if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
772 CV_Error( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
773 "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
777 trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
778 memset( trees, 0, sizeof(trees[0])*ntrees );
780 data = new CvDTreeTrainData();
781 data->read_params( fs, fnode );
784 trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
785 if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
786 CV_Error( CV_StsParseError, "<trees> tag is missing" );
788 cvStartReadSeq( trees_fnode->data.seq, &reader );
789 if( reader.seq->total != ntrees )
790 CV_Error( CV_StsParseError,
791 "<ntrees> is not equal to the number of trees saved in file" );
// Each tree is read with a back-pointer to this forest and the shared data.
793 for( k = 0; k < ntrees; k++ )
795 trees[k] = new CvForestTree();
796 trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data );
797 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
// Rebuild the active-variable mask: first nactive_vars entries set to 1
// (the zeroing of submask2 is elided in this dump — TODO confirm).
800 var_count = data->var_count;
801 active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
803 // initialize active variables mask
804 CvMat submask1, submask2;
805 cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
806 cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
807 cvSet( &submask1, cvScalar(1) );
// Accessor: number of trained trees (return statement elided in this dump).
813 int CvRTrees::get_tree_count() const
// Bounds-checked tree accessor; the unsigned cast rejects negative i in a
// single comparison.  Returns 0 when i is out of range.
818 CvForestTree* CvRTrees::get_tree(int i) const
820 return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
// C++ (cv::Mat) wrappers: convert headers to CvMat and forward to the
// C-API overloads; empty Mats become NULL pointers (data.ptr checks).
825 bool CvRTrees::train( const Mat& _train_data, int _tflag,
826 const Mat& _responses, const Mat& _var_idx,
827 const Mat& _sample_idx, const Mat& _var_type,
828 const Mat& _missing_mask, CvRTParams _params )
830 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
831 sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
832 return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
833 sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
834 mmask.data.ptr ? &mmask : 0, _params);
838 float CvRTrees::predict( const Mat& _sample, const Mat& _missing ) const
840 CvMat sample = _sample, mmask = _missing;
841 return predict(&sample, mmask.data.ptr ? &mmask : 0);
844 float CvRTrees::predict_prob( const Mat& _sample, const Mat& _missing) const
846 CvMat sample = _sample, mmask = _missing;
847 return predict_prob(&sample, mmask.data.ptr ? &mmask : 0);