--- /dev/null
+/*M///////////////////////////////////////////////////////////////////////////////////////
+
+ IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+
+ By downloading, copying, installing or using the software you agree to this license.
+ If you do not agree to this license, do not download, install,
+ copy or use the software.
+
+
+ Intel License Agreement
+
+ Copyright (C) 2000, Intel Corporation, all rights reserved.
+ Third party copyrights are property of their respective owners.
+
+ Redistribution and use in source and binary forms, with or without modification,
+ are permitted provided that the following conditions are met:
+
+ * Redistribution's of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+
+ * Redistribution's in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * The name of Intel Corporation may not be used to endorse or promote products
+ derived from this software without specific prior written permission.
+
+ This software is provided by the copyright holders and contributors "as is" and
+ any express or implied warranties, including, but not limited to, the implied
+ warranties of merchantability and fitness for a particular purpose are disclaimed.
+ In no event shall the Intel Corporation or contributors be liable for any direct,
+ indirect, incidental, special, exemplary, or consequential damages
+ (including, but not limited to, procurement of substitute goods or services;
+ loss of use, data, or profits; or business interruption) however caused
+ and on any theory of liability, whether in contract, strict liability,
+ or tort (including negligence or otherwise) arising in any way out of
+ the use of this software, even if advised of the possibility of such damage.
+
+M*/
+
+#include "_ml.h"
+
// Sentinel for missing/invalid ordered (numeric) values: any |value| >= ord_nan
// is rejected by set_data() as "too large", so ord_nan itself marks "missing".
static const float ord_nan = FLT_MAX*0.5f;
// Tuning for the CvMemStorage pools created in set_data(): minimum block size
// and the slack added on top of the computed per-node/per-split block size.
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;

// icvSortIntPtr: sorts an array of int* by the pointed-to values (ascending).
// Used for the 32-bit work-buffer representation.
#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )

// icvSortPairs: sorts (unsigned short*, int*) pointer pairs by the pointed-to
// int key (ascending). Used for the 16-bit work-buffer representation.
#define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
+
+///
+
+void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
+ const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
+ const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
+ bool _shared, bool _add_labels, bool _update_data )
+{
+ CvMat* sample_indices = 0;
+ CvMat* var_type0 = 0;
+ CvMat* tmp_map = 0;
+ int** int_ptr = 0;
+ CvPair16u32s* pair16u32s_ptr = 0;
+ CvDTreeTrainData* data = 0;
+ float *_fdst = 0;
+ int *_idst = 0;
+ unsigned short* udst = 0;
+ int* idst = 0;
+
+ CV_FUNCNAME( "CvERTreeTrainData::set_data" );
+
+ __BEGIN__;
+
+ int sample_all = 0, r_type = 0, cv_n;
+ int total_c_count = 0;
+ int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
+ int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
+ int vi, i, size;
+ char err[100];
+ const int *sidx = 0, *vidx = 0;
+
+ if ( _params.use_surrogates )
+ CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
+
+ if( _update_data && data_root )
+ {
+ CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
+ }
+
+ clear();
+
+ var_all = 0;
+ rng = cvRNG(-1);
+
+ CV_CALL( set_params( _params ));
+
+ // check parameter types and sizes
+ CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
+
+ train_data = _train_data;
+ responses = _responses;
+ missing_mask = _missing_mask;
+
+ if( _tflag == CV_ROW_SAMPLE )
+ {
+ ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
+ dv_step = 1;
+ if( _missing_mask )
+ ms_step = _missing_mask->step, mv_step = 1;
+ }
+ else
+ {
+ dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
+ ds_step = 1;
+ if( _missing_mask )
+ mv_step = _missing_mask->step, ms_step = 1;
+ }
+ tflag = _tflag;
+
+ sample_count = sample_all;
+ var_count = var_all;
+
+ if( _sample_idx )
+ {
+ CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
+ sidx = sample_indices->data.i;
+ sample_count = sample_indices->rows + sample_indices->cols - 1;
+ }
+
+ if( _var_idx )
+ {
+ CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
+ vidx = var_idx->data.i;
+ var_count = var_idx->rows + var_idx->cols - 1;
+ }
+
+ if( !CV_IS_MAT(_responses) ||
+ (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
+ CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
+ (_responses->rows != 1 && _responses->cols != 1) ||
+ _responses->rows + _responses->cols - 1 != sample_all )
+ CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
+ "floating-point vector containing as many elements as "
+ "the total number of samples in the training data matrix" );
+
+ is_buf_16u = false;
+ if ( sample_count < 65536 )
+ is_buf_16u = true;
+
+
+ CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
+
+ CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
+
+
+ cat_var_count = 0;
+ ord_var_count = -1;
+
+ is_classifier = r_type == CV_VAR_CATEGORICAL;
+
+ // step 0. calc the number of categorical vars
+ for( vi = 0; vi < var_count; vi++ )
+ {
+ var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
+ cat_var_count++ : ord_var_count--;
+ }
+
+ ord_var_count = ~ord_var_count;
+ cv_n = params.cv_folds;
+ // set the two last elements of var_type array to be able
+ // to locate responses and cross-validation labels using
+ // the corresponding get_* functions.
+ var_type->data.i[var_count] = cat_var_count;
+ var_type->data.i[var_count+1] = cat_var_count+1;
+
+ // in case of single ordered predictor we need dummy cv_labels
+ // for safe split_node_data() operation
+ have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
+
+ work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
+ buf_size = (work_var_count + 1)*sample_count;
+ shared = _shared;
+ buf_count = shared ? 2 : 1;
+
+ if ( is_buf_16u )
+ {
+ CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
+ CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
+ }
+ else
+ {
+ CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
+ CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
+ }
+
+ size = is_classifier ? cat_var_count+1 : cat_var_count;
+ size = !size ? 1 : size;
+ CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
+ CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
+
+ size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
+ size = !size ? 1 : size;
+ CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
+
+ // now calculate the maximum size of split,
+ // create memory storage that will keep nodes and splits of the decision tree
+ // allocate root node and the buffer for the whole training data
+ max_split_size = cvAlign(sizeof(CvDTreeSplit) +
+ (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
+ tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
+ tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
+ CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
+ CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
+
+ nv_size = var_count*sizeof(int);
+ nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
+
+ temp_block_size = nv_size;
+
+ if( cv_n )
+ {
+ if( sample_count < cv_n*MAX(params.min_sample_count,10) )
+ CV_ERROR( CV_StsOutOfRange,
+ "The many folds in cross-validation for such a small dataset" );
+
+ cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
+ temp_block_size = MAX(temp_block_size, cv_size);
+ }
+
+ temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
+ CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
+ CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
+ if( cv_size )
+ CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
+
+ CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
+
+ max_c_count = 1;
+
+ _fdst = 0;
+ _idst = 0;
+ if (ord_var_count)
+ _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
+ if (is_buf_16u && (cat_var_count || is_classifier))
+ _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
+
+ // transform the training data to convenient representation
+ for( vi = 0; vi <= var_count; vi++ )
+ {
+ int ci;
+ const uchar* mask = 0;
+ int m_step = 0, step;
+ const int* idata = 0;
+ const float* fdata = 0;
+ int num_valid = 0;
+
+ if( vi < var_count ) // analyze i-th input variable
+ {
+ int vi0 = vidx ? vidx[vi] : vi;
+ ci = get_var_type(vi);
+ step = ds_step; m_step = ms_step;
+ if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
+ idata = _train_data->data.i + vi0*dv_step;
+ else
+ fdata = _train_data->data.fl + vi0*dv_step;
+ if( _missing_mask )
+ mask = _missing_mask->data.ptr + vi0*mv_step;
+ }
+ else // analyze _responses
+ {
+ ci = cat_var_count;
+ step = CV_IS_MAT_CONT(_responses->type) ?
+ 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
+ if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
+ idata = _responses->data.i;
+ else
+ fdata = _responses->data.fl;
+ }
+
+ if( (vi < var_count && ci>=0) ||
+ (vi == var_count && is_classifier) ) // process categorical variable or response
+ {
+ int c_count, prev_label;
+ int* c_map;
+
+ if (is_buf_16u)
+ udst = (unsigned short*)(buf->data.s + ci*sample_count);
+ else
+ idst = buf->data.i + ci*sample_count;
+
+ // copy data
+ for( i = 0; i < sample_count; i++ )
+ {
+ int val = INT_MAX, si = sidx ? sidx[i] : i;
+ if( !mask || !mask[si*m_step] )
+ {
+ if( idata )
+ val = idata[si*step];
+ else
+ {
+ float t = fdata[si*step];
+ val = cvRound(t);
+ if( val != t )
+ {
+ sprintf( err, "%d-th value of %d-th (categorical) "
+ "variable is not an integer", i, vi );
+ CV_ERROR( CV_StsBadArg, err );
+ }
+ }
+
+ if( val == INT_MAX )
+ {
+ sprintf( err, "%d-th value of %d-th (categorical) "
+ "variable is too large", i, vi );
+ CV_ERROR( CV_StsBadArg, err );
+ }
+ num_valid++;
+ }
+ if (is_buf_16u)
+ {
+ _idst[i] = val;
+ pair16u32s_ptr[i].u = udst + i;
+ pair16u32s_ptr[i].i = _idst + i;
+ }
+ else
+ {
+ idst[i] = val;
+ int_ptr[i] = idst + i;
+ }
+ }
+
+ c_count = num_valid > 0;
+
+ if (is_buf_16u)
+ {
+ icvSortPairs( pair16u32s_ptr, sample_count, 0 );
+ // count the categories
+ for( i = 1; i < num_valid; i++ )
+ if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
+ c_count ++ ;
+ }
+ else
+ {
+ icvSortIntPtr( int_ptr, sample_count, 0 );
+ // count the categories
+ for( i = 1; i < num_valid; i++ )
+ c_count += *int_ptr[i] != *int_ptr[i-1];
+ }
+
+ if( vi > 0 )
+ max_c_count = MAX( max_c_count, c_count );
+ cat_count->data.i[ci] = c_count;
+ cat_ofs->data.i[ci] = total_c_count;
+
+ // resize cat_map, if need
+ if( cat_map->cols < total_c_count + c_count )
+ {
+ tmp_map = cat_map;
+ CV_CALL( cat_map = cvCreateMat( 1,
+ MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
+ for( i = 0; i < total_c_count; i++ )
+ cat_map->data.i[i] = tmp_map->data.i[i];
+ cvReleaseMat( &tmp_map );
+ }
+
+ c_map = cat_map->data.i + total_c_count;
+ total_c_count += c_count;
+
+ c_count = -1;
+ if (is_buf_16u)
+ {
+ // compact the class indices and build the map
+ prev_label = ~*pair16u32s_ptr[0].i;
+ for( i = 0; i < num_valid; i++ )
+ {
+ int cur_label = *pair16u32s_ptr[i].i;
+ if( cur_label != prev_label )
+ c_map[++c_count] = prev_label = cur_label;
+ *pair16u32s_ptr[i].u = (unsigned short)c_count;
+ }
+ // replace labels for missing values with 65535
+ for( ; i < sample_count; i++ )
+ *pair16u32s_ptr[i].u = 65535;
+ }
+ else
+ {
+ // compact the class indices and build the map
+ prev_label = ~*int_ptr[0];
+ for( i = 0; i < num_valid; i++ )
+ {
+ int cur_label = *int_ptr[i];
+ if( cur_label != prev_label )
+ c_map[++c_count] = prev_label = cur_label;
+ *int_ptr[i] = c_count;
+ }
+ // replace labels for missing values with -1
+ for( ; i < sample_count; i++ )
+ *int_ptr[i] = -1;
+ }
+ }
+ else if( ci < 0 ) // process ordered variable
+ {
+ for( i = 0; i < sample_count; i++ )
+ {
+ float val = ord_nan;
+ int si = sidx ? sidx[i] : i;
+ if( !mask || !mask[si*m_step] )
+ {
+ if( idata )
+ val = (float)idata[si*step];
+ else
+ val = fdata[si*step];
+
+ if( fabs(val) >= ord_nan )
+ {
+ sprintf( err, "%d-th value of %d-th (ordered) "
+ "variable (=%g) is too large", i, vi, val );
+ CV_ERROR( CV_StsBadArg, err );
+ }
+ }
+ num_valid++;
+ }
+ }
+ if( vi < var_count )
+ data_root->set_num_valid(vi, num_valid);
+ }
+
+ // set sample labels
+ if (is_buf_16u)
+ udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
+ else
+ idst = buf->data.i + get_work_var_count()*sample_count;
+
+ for (i = 0; i < sample_count; i++)
+ {
+ if (udst)
+ udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
+ else
+ idst[i] = sidx ? sidx[i] : i;
+ }
+
+ if( cv_n )
+ {
+ unsigned short* udst = 0;
+ int* idst = 0;
+ CvRNG* r = &rng;
+
+ if (is_buf_16u)
+ {
+ udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
+ for( i = vi = 0; i < sample_count; i++ )
+ {
+ udst[i] = (unsigned short)vi++;
+ vi &= vi < cv_n ? -1 : 0;
+ }
+
+ for( i = 0; i < sample_count; i++ )
+ {
+ int a = cvRandInt(r) % sample_count;
+ int b = cvRandInt(r) % sample_count;
+ unsigned short unsh = (unsigned short)vi;
+ CV_SWAP( udst[a], udst[b], unsh );
+ }
+ }
+ else
+ {
+ idst = buf->data.i + (get_work_var_count()-1)*sample_count;
+ for( i = vi = 0; i < sample_count; i++ )
+ {
+ idst[i] = vi++;
+ vi &= vi < cv_n ? -1 : 0;
+ }
+
+ for( i = 0; i < sample_count; i++ )
+ {
+ int a = cvRandInt(r) % sample_count;
+ int b = cvRandInt(r) % sample_count;
+ CV_SWAP( idst[a], idst[b], vi );
+ }
+ }
+ }
+
+ if ( cat_map )
+ cat_map->cols = MAX( total_c_count, 1 );
+
+ max_split_size = cvAlign(sizeof(CvDTreeSplit) +
+ (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
+ CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
+
+ have_priors = is_classifier && params.priors;
+ if( is_classifier )
+ {
+ int m = get_num_classes();
+ double sum = 0;
+ CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
+ for( i = 0; i < m; i++ )
+ {
+ double val = have_priors ? params.priors[i] : 1.;
+ if( val <= 0 )
+ CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
+ priors->data.db[i] = val;
+ sum += val;
+ }
+
+ // normalize weights
+ if( have_priors )
+ cvScale( priors, priors, 1./sum );
+
+ CV_CALL( priors_mult = cvCloneMat( priors ));
+ CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
+ }
+
+ CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
+ CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
+
+ {
+ int maxNumThreads = 1;
+#ifdef _OPENMP
+ maxNumThreads = cv::getNumThreads();
+#endif
+ pred_float_buf.resize(maxNumThreads);
+ pred_int_buf.resize(maxNumThreads);
+ resp_float_buf.resize(maxNumThreads);
+ resp_int_buf.resize(maxNumThreads);
+ cv_lables_buf.resize(maxNumThreads);
+ sample_idx_buf.resize(maxNumThreads);
+ for( int ti = 0; ti < maxNumThreads; ti++ )
+ {
+ pred_float_buf[ti].resize(sample_count);
+ pred_int_buf[ti].resize(sample_count);
+ resp_float_buf[ti].resize(sample_count);
+ resp_int_buf[ti].resize(sample_count);
+ cv_lables_buf[ti].resize(sample_count);
+ sample_idx_buf[ti].resize(sample_count);
+ }
+ }
+
+ __END__;
+
+ if( data )
+ delete data;
+
+ if (_fdst)
+ cvFree( &_fdst );
+ if (_idst)
+ cvFree( &_idst );
+ cvFree( &int_ptr );
+ cvReleaseMat( &var_type0 );
+ cvReleaseMat( &sample_indices );
+ cvReleaseMat( &tmp_map );
+}
+
// Gathers the values of ordered variable vi (and the corresponding
// missing-value flags) for every sample of node n, reading straight from the
// original train_data matrix (ER-trees do not pre-sort ordered variables).
// On return *ord_values/*missing point into the caller-provided buffers.
// NOTE(review): values are read via train_data->data.fl, i.e. this assumes the
// training matrix is CV_32FC1 even though set_data also accepts CV_32SC1 —
// TODO confirm callers guarantee float data here.
int CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, const float** ord_values, const int** missing )
{
    // map vi to the original column index when an active-variable subset is used
    int vidx = var_idx ? var_idx->data.i[vi] : vi;
    int node_sample_count = n->sample_count;
    int* sample_indices_buf = get_sample_idx_buf();
    const int* sample_indices = 0;

    get_sample_indices(n, sample_indices_buf, &sample_indices);

    // element strides for train_data / missing_mask rows
    int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
    int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
    if( tflag == CV_ROW_SAMPLE )
    {
        for( int i = 0; i < node_sample_count; i++ )
        {
            int idx = sample_indices[i];
            missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
            ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
        }
    }
    else
        // column-sample layout: variable index selects the row
        for( int i = 0; i < node_sample_count; i++ )
        {
            int idx = sample_indices[i];
            missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
            ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
        }
    *ord_values = ord_values_buf;
    *missing = missing_buf;
    return 0; //TODO: return the number of non-missing values
}
+
+
+void CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )
+{
+ get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf, indices );
+}
+
+
+void CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
+{
+ if (have_labels)
+ get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf, labels );
+}
+
+
+int CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )
+{
+ int ci = get_var_type( vi);
+ if( !is_buf_16u )
+ *cat_values = buf->data.i + n->buf_idx*buf->cols +
+ ci*sample_count + n->offset;
+ else {
+ const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
+ ci*sample_count + n->offset);
+ for( int i = 0; i < n->sample_count; i++ )
+ cat_values_buf[i] = short_values[i];
+ *cat_values = cat_values_buf;
+ }
+
+ return 0; //TODO: return the number of non-missing values
+}
+
+void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
+ float* values, uchar* missing,
+ float* responses, bool get_class_idx )
+{
+ CvMat* subsample_idx = 0;
+ CvMat* subsample_co = 0;
+
+ CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
+
+ __BEGIN__;
+
+ int i, vi, total = sample_count, count = total, cur_ofs = 0;
+ int* sidx = 0;
+ int* co = 0;
+
+ if( _subsample_idx )
+ {
+ CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
+ sidx = subsample_idx->data.i;
+ CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
+ co = subsample_co->data.i;
+ cvZero( subsample_co );
+ count = subsample_idx->cols + subsample_idx->rows - 1;
+ for( i = 0; i < count; i++ )
+ co[sidx[i]*2]++;
+ for( i = 0; i < total; i++ )
+ {
+ int count_i = co[i*2];
+ if( count_i )
+ {
+ co[i*2+1] = cur_ofs*var_count;
+ cur_ofs += count_i;
+ }
+ }
+ }
+
+ if( missing )
+ memset( missing, 1, count*var_count );
+
+ for( vi = 0; vi < var_count; vi++ )
+ {
+ int ci = get_var_type(vi);
+ if( ci >= 0 ) // categorical
+ {
+ float* dst = values + vi;
+ uchar* m = missing ? missing + vi : 0;
+ int* src_buf = get_pred_int_buf();
+ const int* src = 0;
+ get_cat_var_data(data_root, vi, src_buf, &src);
+
+ for( i = 0; i < count; i++, dst += var_count )
+ {
+ int idx = sidx ? sidx[i] : i;
+ int val = src[idx];
+ *dst = (float)val;
+ if( m )
+ {
+ *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
+ m += var_count;
+ }
+ }
+ }
+ else // ordered
+ {
+ float* dst_buf = values + vi;
+ int* m_buf = get_pred_int_buf();
+ const float *dst = 0;
+ const int* m = 0;
+ get_ord_var_data(data_root, vi, dst_buf, m_buf, &dst, &m);
+ for (int si = 0; si < total; si++)
+ *(missing + vi + si) = m[si] == 0 ? 0 : 1;
+ }
+ }
+
+ // copy responses
+ if( responses )
+ {
+ if( is_classifier )
+ {
+ int* src_buf = get_resp_int_buf();
+ const int* src = 0;
+ get_class_labels(data_root, src_buf, &src);
+ for( i = 0; i < count; i++ )
+ {
+ int idx = sidx ? sidx[i] : i;
+ int val = get_class_idx ? src[idx] :
+ cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
+ responses[i] = (float)val;
+ }
+ }
+ else
+ {
+ float *_values_buf = get_resp_float_buf();
+ const float* _values = 0;
+ get_ord_responses(data_root, _values_buf, &_values);
+ for( i = 0; i < count; i++ )
+ {
+ int idx = sidx ? sidx[i] : i;
+ responses[i] = _values[idx];
+ }
+ }
+ }
+
+ __END__;
+
+ cvReleaseMat( &subsample_idx );
+ cvReleaseMat( &subsample_co );
+}
+
// Returns a fresh node that mirrors data_root (same statistics, same
// num_valid array). ER-trees train every tree on the full data set, so a real
// subsample index is rejected; only _subsample_idx == 0 is supported.
CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;

    CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( !_subsample_idx )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        // keep the new node's own storage pointers, copy only the contents
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
        CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
    __END__;

    return root;
}
+
+double CvForestERTree::calc_node_dir( CvDTreeNode* node )
+{
+ char* dir = (char*)data->direction->data.ptr;
+ int i, n = node->sample_count, vi = node->split->var_idx;
+ double L, R;
+
+ assert( !node->split->inversed );
+
+ if( data->get_var_type(vi) >= 0 ) // split on categorical var
+ {
+ int* labels_buf = data->get_pred_int_buf();
+ const int* labels = 0;
+ const int* subset = node->split->subset;
+ data->get_cat_var_data( node, vi, labels_buf, &labels );
+ if( !data->have_priors )
+ {
+ int sum = 0, sum_abs = 0;
+
+ for( i = 0; i < n; i++ )
+ {
+ int idx = labels[i];
+ int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
+ CV_DTREE_CAT_DIR(idx,subset) : 0;
+ sum += d; sum_abs += d & 1;
+ dir[i] = (char)d;
+ }
+
+ R = (sum_abs + sum) >> 1;
+ L = (sum_abs - sum) >> 1;
+ }
+ else
+ {
+ const double* priors = data->priors_mult->data.db;
+ double sum = 0, sum_abs = 0;
+ int *responses_buf = data->get_resp_int_buf();
+ const int* responses;
+ data->get_class_labels(node, responses_buf, &responses);
+
+ for( i = 0; i < n; i++ )
+ {
+ int idx = labels[i];
+ double w = priors[responses[i]];
+ int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
+ sum += d*w; sum_abs += (d & 1)*w;
+ dir[i] = (char)d;
+ }
+
+ R = (sum_abs + sum) * 0.5;
+ L = (sum_abs - sum) * 0.5;
+ }
+ }
+ else // split on ordered var
+ {
+ float split_val = node->split->ord.c;
+ float* val_buf = data->get_pred_float_buf();
+ const float* val = 0;
+ int* missing_buf = data->get_pred_int_buf();
+ const int* missing = 0;
+ data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing );
+
+ if( !data->have_priors )
+ {
+ L = R = 0;
+ for( i = 0; i < n; i++ )
+ {
+ if ( missing[i] )
+ dir[i] = (char)0;
+ else
+ {
+ if ( val[i] < split_val)
+ {
+ dir[i] = (char)-1;
+ L++;
+ }
+ else
+ {
+ dir[i] = (char)1;
+ R++;
+ }
+ }
+ }
+ }
+ else
+ {
+ const double* priors = data->priors_mult->data.db;
+ int* responses_buf = data->get_resp_int_buf();
+ const int* responses = 0;
+ data->get_class_labels(node, responses_buf, &responses);
+ L = R = 0;
+ for( i = 0; i < n; i++ )
+ {
+ if ( missing[i] )
+ dir[i] = (char)0;
+ else
+ {
+ double w = priors[responses[i]];
+ if ( val[i] < split_val)
+ {
+ dir[i] = (char)-1;
+ L += w;
+ }
+ else
+ {
+ dir[i] = (char)1;
+ R += w;
+ }
+ }
+ }
+ }
+ }
+
+ node->maxlr = MAX( L, R );
+ return node->split->quality/(L + R);
+}
+
+CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
+{
+ const float epsilon = FLT_EPSILON*2;
+ const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
+
+ int n = node->sample_count;
+ int m = data->get_num_classes();
+
+ float* values_buf = data->get_pred_float_buf();
+ const float* values = 0;
+ int* missing_buf = data->get_pred_int_buf();
+ const int* missing = 0;
+ data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
+ int* responses_buf = data->get_resp_int_buf();
+ const int* responses = 0;
+ data->get_class_labels( node, responses_buf, &responses );
+
+ double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
+
+ int i;
+
+ const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
+
+ bool is_find_split = false;
+
+ float pmin, pmax;
+ int smpi = 0;
+ while ( missing[smpi] && (smpi < n) )
+ smpi++;
+ assert(smpi < n);
+
+ pmin = values[smpi];
+ pmax = pmin;
+ for (; smpi < n; smpi++)
+ {
+ float ptemp = values[smpi];
+ int m = missing[smpi];
+ if (m) continue;
+ if ( ptemp < pmin)
+ pmin = ptemp;
+ if ( ptemp > pmax)
+ pmax = ptemp;
+ }
+ float fdiff = pmax-pmin;
+ if (fdiff > epsilon)
+ {
+ is_find_split = true;
+ CvRNG* rng = &data->rng;
+ split_val = pmin + cvRandReal(rng) * fdiff ;
+ if (split_val - pmin <= FLT_EPSILON)
+ split_val = pmin + split_delta;
+ if (pmax - split_val <= FLT_EPSILON)
+ split_val = pmax - split_delta;
+
+ // calculate Gini index
+ if ( !priors )
+ {
+ int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
+ int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
+ int L = 0, R = 0;
+
+ // init arrays of class instance counters on both sides of the split
+ for( i = 0; i < m; i++ )
+ {
+ lc[i] = 0;
+ rc[i] = 0;
+ }
+ for( int si = 0; si < n; si++ )
+ {
+ int r = responses[si];
+ float val = values[si];
+ int m = missing[si];
+ if (m) continue;
+ if ( val < split_val )
+ {
+ lc[r]++;
+ L++;
+ }
+ else
+ {
+ rc[r]++;
+ R++;
+ }
+ }
+ for (int i = 0; i < m; i++)
+ {
+ lbest_val += lc[i]*lc[i];
+ rbest_val += rc[i]*rc[i];
+ }
+ best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
+ }
+ else
+ {
+ double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));
+ double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));
+ double L = 0, R = 0;
+
+ // init arrays of class instance counters on both sides of the split
+ for( i = 0; i < m; i++ )
+ {
+ lc[i] = 0;
+ rc[i] = 0;
+ }
+ for( int si = 0; si < n; si++ )
+ {
+ int r = responses[si];
+ float val = values[si];
+ int m = missing[si];
+ double p = priors[si];
+ if (m) continue;
+ if ( val < split_val )
+ {
+ lc[r] += p;
+ L += p;
+ }
+ else
+ {
+ rc[r] += p;
+ R += p;
+ }
+ }
+ for (int i = 0; i < m; i++)
+ {
+ lbest_val += lc[i]*lc[i];
+ rbest_val += rc[i]*rc[i];
+ }
+ best_val = (lbest_val*R + rbest_val*L) / (L*R);
+ }
+
+ }
+
+ CvDTreeSplit* split = 0;
+ if( is_find_split )
+ {
+ split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
+ split->var_idx = vi;
+ split->ord.c = (float)split_val;
+ split->ord.split_point = -1;
+ split->inversed = 0;
+ split->quality = (float)best_val;
+ }
+ return split;
+}
+
+CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
+{
+ int ci = data->get_var_type(vi);
+ int n = node->sample_count;
+ int cm = data->get_num_classes();
+ int vm = data->cat_count->data.i[ci];
+ double best_val = init_quality;
+ CvDTreeSplit *split = 0;
+
+ if ( vm > 1 )
+ {
+ int* labels_buf = data->get_pred_int_buf();
+ const int* labels = 0;
+ data->get_cat_var_data( node, vi, labels_buf, &labels );
+
+ int* responses_buf = data->get_resp_int_buf();
+ const int* responses = 0;
+ data->get_class_labels( node, responses_buf, &responses );
+
+ const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
+
+ // create random class mask
+ int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
+ for (int i = 0; i < vm; i++)
+ {
+ valid_cidx[i] = -1;
+ }
+ for (int si = 0; si < n; si++)
+ {
+ int c = labels[si];
+ if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
+ continue;
+ valid_cidx[c]++;
+ }
+
+ int valid_ccount = 0;
+ for (int i = 0; i < vm; i++)
+ if (valid_cidx[i] >= 0)
+ {
+ valid_cidx[i] = valid_ccount;
+ valid_ccount++;
+ }
+ if (valid_ccount > 1)
+ {
+ CvRNG* rng = forest->get_rng();
+ int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
+
+ CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
+ CvMat submask;
+ memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
+ cvGetCols( var_class_mask, &submask, 0, l_cval_count );
+ cvSet( &submask, cvScalar(1) );
+ for (int i = 0; i < valid_ccount; i++)
+ {
+ uchar temp;
+ int i1 = cvRandInt( rng ) % valid_ccount;
+ int i2 = cvRandInt( rng ) % valid_ccount;
+ CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
+ }
+
+ split = _split ? _split : data->new_split_cat( 0, -1.0f );
+ split->var_idx = vi;
+ memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
+
+ // calculate Gini index
+ double lbest_val = 0, rbest_val = 0;
+ if( !priors )
+ {
+ int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
+ int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
+ int L = 0, R = 0;
+ // init arrays of class instance counters on both sides of the split
+ for(int i = 0; i < cm; i++ )
+ {
+ lc[i] = 0;
+ rc[i] = 0;
+ }
+ for( int si = 0; si < n; si++ )
+ {
+ int r = responses[si];
+ int var_class_idx = labels[si];
+ if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
+ continue;
+ int mask_class_idx = valid_cidx[var_class_idx];
+ if (var_class_mask->data.ptr[mask_class_idx])
+ {
+ lc[r]++;
+ L++;
+ split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
+ }
+ else
+ {
+ rc[r]++;
+ R++;
+ }
+ }
+ for (int i = 0; i < cm; i++)
+ {
+ lbest_val += lc[i]*lc[i];
+ rbest_val += rc[i]*rc[i];
+ }
+ best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
+ }
+ else
+ {
+ double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));
+ double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));
+ double L = 0, R = 0;
+ // init arrays of class instance counters on both sides of the split
+ for(int i = 0; i < cm; i++ )
+ {
+ lc[i] = 0;
+ rc[i] = 0;
+ }
+ for( int si = 0; si < n; si++ )
+ {
+ int r = responses[si];
+ int var_class_idx = labels[si];
+ if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
+ continue;
+ double p = priors[si];
+ int mask_class_idx = valid_cidx[var_class_idx];
+
+ if (var_class_mask->data.ptr[mask_class_idx])
+ {
+ lc[r]+=p;
+ L+=p;
+ split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
+ }
+ else
+ {
+ rc[r]+=p;
+ R+=p;
+ }
+ }
+ for (int i = 0; i < cm; i++)
+ {
+ lbest_val += lc[i]*lc[i];
+ rbest_val += rc[i]*rc[i];
+ }
+ best_val = (lbest_val*R + rbest_val*L) / (L*R);
+ }
+ split->quality = (float)best_val;
+
+ cvReleaseMat(&var_class_mask);
+ }
+ }
+
+ return split;
+}
+
+CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
+{
+    // Extremely-randomized split on an ordered variable for regression:
+    // instead of scanning all candidate thresholds, draw ONE threshold
+    // uniformly at random from (min, max) of the variable at this node and
+    // score it. Returns 0 when the variable is (nearly) constant here.
+    const float epsilon = FLT_EPSILON*2;
+    // minimal distance the random threshold keeps from the extreme values,
+    // so both sides of the split are guaranteed non-empty
+    const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
+    int n = node->sample_count;
+    float* values_buf = data->get_pred_float_buf();
+    const float* values = 0;
+    int* missing_buf = data->get_pred_int_buf();
+    const int* missing = 0;
+    data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
+    float* responses_buf = data->get_resp_float_buf();
+    const float* responses = 0;
+    data->get_ord_responses( node, responses_buf, &responses );
+
+    double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
+    int L = 0, R = 0;
+
+    bool is_find_split = false;
+    float pmin, pmax;
+    int smpi = 0;
+    // skip leading missing values; test the bound BEFORE dereferencing
+    // missing[smpi] so we never read one element past the end of the array
+    // (the original order "missing[smpi] && (smpi < n)" was an OOB read
+    // when every sample is missing)
+    while ( (smpi < n) && missing[smpi] )
+        smpi++;
+
+    assert(smpi < n);
+
+    // compute the range [pmin, pmax] over the non-missing samples
+    pmin = values[smpi];
+    pmax = pmin;
+    for (; smpi < n; smpi++)
+    {
+        float ptemp = values[smpi];
+        int m = missing[smpi];
+        if (m) continue;
+        if ( ptemp < pmin)
+            pmin = ptemp;
+        if ( ptemp > pmax)
+            pmax = ptemp;
+    }
+    float fdiff = pmax-pmin;
+    if (fdiff > epsilon)
+    {
+        is_find_split = true;
+        CvRNG* rng = &data->rng;
+        split_val = pmin + cvRandReal(rng) * fdiff ;
+        // nudge the threshold strictly inside (pmin, pmax)
+        if (split_val - pmin <= FLT_EPSILON)
+            split_val = pmin + split_delta;
+        if (pmax - split_val <= FLT_EPSILON)
+            split_val = pmax - split_delta;
+
+        // accumulate response sums and sample counts on each side
+        for (int si = 0; si < n; si++)
+        {
+            float r = responses[si];
+            float val = values[si];
+            int m = missing[si];
+            if (m) continue;
+            if (val < split_val)
+            {
+                lsum += r;
+                L++;
+            }
+            else
+            {
+                rsum += r;
+                R++;
+            }
+        }
+        // split quality lsum^2/L + rsum^2/R, written over a common denominator;
+        // L >= 1 and R >= 1 because split_val lies strictly between pmin and pmax
+        best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
+    }
+
+    CvDTreeSplit* split = 0;
+    if( is_find_split )
+    {
+        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
+        split->var_idx = vi;
+        split->ord.c = (float)split_val;
+        split->ord.split_point = -1;
+        split->inversed = 0;
+        split->quality = (float)best_val;
+    }
+    return split;
+}
+
+CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
+{
+    // Extremely-randomized split on a categorical variable for regression:
+    // the categories present at the node are partitioned into two subsets at
+    // random and that single partition is scored. Returns 0 when fewer than
+    // two distinct categories are present.
+    int ci = data->get_var_type(vi);
+    int n = node->sample_count;
+    int vm = data->cat_count->data.i[ci];
+    double best_val = init_quality;
+    CvDTreeSplit *split = 0;
+    // response sums on the left/right sides; accumulated in double to avoid
+    // float round-off and to match find_split_ord_reg
+    double lsum = 0, rsum = 0;
+
+    if ( vm > 1 )
+    {
+        int* labels_buf = data->get_pred_int_buf();
+        const int* labels = 0;
+        data->get_cat_var_data( node, vi, labels_buf, &labels );
+
+        float* responses_buf = data->get_resp_float_buf();
+        const float* responses = 0;
+        data->get_ord_responses( node, responses_buf, &responses );
+
+        // create random class mask: first map category index -> compact index
+        // over the categories actually present (65535 / negative labels mark
+        // missing values in the 16u / 32s buffer layouts)
+        int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
+        for (int i = 0; i < vm; i++)
+        {
+            valid_cidx[i] = -1;
+        }
+        for (int si = 0; si < n; si++)
+        {
+            int c = labels[si];
+            if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
+                continue;
+            valid_cidx[c]++;
+        }
+
+        int valid_ccount = 0;
+        for (int i = 0; i < vm; i++)
+            if (valid_cidx[i] >= 0)
+            {
+                valid_cidx[i] = valid_ccount;
+                valid_ccount++;
+            }
+        if (valid_ccount > 1)
+        {
+            // choose how many categories go left (1..valid_ccount-1), then
+            // shuffle the 0/1 mask so the chosen categories are random
+            CvRNG* rng = forest->get_rng();
+            int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
+
+            CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
+            CvMat submask;
+            memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
+            cvGetCols( var_class_mask, &submask, 0, l_cval_count );
+            cvSet( &submask, cvScalar(1) );
+            for (int i = 0; i < valid_ccount; i++)
+            {
+                uchar temp;
+                int i1 = cvRandInt( rng ) % valid_ccount;
+                int i2 = cvRandInt( rng ) % valid_ccount;
+                CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
+            }
+
+            split = _split ? _split : data->new_split_cat( 0, -1.0f);
+            split->var_idx = vi;
+            memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
+
+            int L = 0, R = 0;
+            for( int si = 0; si < n; si++ )
+            {
+                float r = responses[si];
+                int var_class_idx = labels[si];
+                if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
+                    continue;
+                int mask_class_idx = valid_cidx[var_class_idx];
+                if (var_class_mask->data.ptr[mask_class_idx])
+                {
+                    lsum += r;
+                    L++;
+                    // record the category in the split's left-subset bitmask
+                    split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
+                }
+                else
+                {
+                    rsum += r;
+                    R++;
+                }
+            }
+            // split quality lsum^2/L + rsum^2/R over a common denominator;
+            // L >= 1 and R >= 1 since every present category has a sample and
+            // the mask has at least one 0 and one 1
+            best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
+
+            split->quality = (float)best_val;
+
+            cvReleaseMat(&var_class_mask);
+        }
+    }
+
+    return split;
+}
+
+//void CvForestERTree::complete_node_dir( CvDTreeNode* node )
+//{
+// int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
+// int nz = n - node->get_num_valid(node->split->var_idx);
+// char* dir = (char*)data->direction->data.ptr;
+//
+// // try to complete direction using surrogate splits
+// if( nz && data->params.use_surrogates )
+// {
+// CvDTreeSplit* split = node->split->next;
+// for( ; split != 0 && nz; split = split->next )
+// {
+// int inversed_mask = split->inversed ? -1 : 0;
+// vi = split->var_idx;
+//
+// if( data->get_var_type(vi) >= 0 ) // split on categorical var
+// {
+// int* labels_buf = data->pred_int_buf;
+// const int* labels = 0;
+// data->get_cat_var_data(node, vi, labels_buf, &labels);
+// const int* subset = split->subset;
+//
+// for( i = 0; i < n; i++ )
+// {
+// int idx = labels[i];
+// if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
+//
+// {
+// int d = CV_DTREE_CAT_DIR(idx,subset);
+// dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
+// if( --nz )
+// break;
+// }
+// }
+// }
+// else // split on ordered var
+// {
+// float* values_buf = data->pred_float_buf;
+// const float* values = 0;
+// uchar* missing_buf = (uchar*)data->pred_int_buf;
+// const uchar* missing = 0;
+// data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
+// float split_val = node->split->ord.c;
+//
+// for( i = 0; i < n; i++ )
+// {
+// if( !dir[i] && !missing[i])
+// {
+// int d = values[i] <= split_val ? -1 : 1;
+// dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
+// if( --nz )
+// break;
+// }
+// }
+// }
+// }
+// }
+//
+// // find the default direction for the rest
+// if( nz )
+// {
+// for( i = nr = 0; i < n; i++ )
+// nr += dir[i] > 0;
+// nl = n - nr - nz;
+// d0 = nl > nr ? -1 : nr > nl;
+// }
+//
+// // make sure that every sample is directed either to the left or to the right
+// for( i = 0; i < n; i++ )
+// {
+// int d = dir[i];
+// if( !d )
+// {
+// d = d0;
+// if( !d )
+// d = d1, d1 = -d1;
+// }
+// d = d > 0;
+// dir[i] = (char)d; // remap (-1,1) to (0,1)
+// }
+//}
+
+// Distributes the node's samples between two newly created children according
+// to the directions computed by complete_node_dir() (0 = left, 1 = right) and
+// relocates the per-variable buffer contents into the children's buffer areas.
+void CvForestERTree::split_node_data( CvDTreeNode* node )
+{
+ int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
+ char* dir = (char*)data->direction->data.ptr;
+ CvDTreeNode *left = 0, *right = 0;
+ int new_buf_idx = data->get_child_buf_idx( node );
+ CvMat* buf = data->buf;
+ int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));
+
+ // make sure every sample has a definite direction before partitioning
+ complete_node_dir(node);
+
+ // count samples going left (nl) and right (nr); dir[i] is 0 or 1
+ for( i = 0; i < n; i++ )
+ {
+ int d = dir[i];
+ nr += d;
+ nl += d^1;
+ }
+
+ bool split_input_data;
+ // children share the parent's buffer row: left at node->offset, right after it
+ node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
+ node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
+
+ // sample indices need relocating only if the children may be split further
+ split_input_data = node->depth + 1 < data->params.max_depth &&
+ (node->left->sample_count > data->params.min_sample_count ||
+ node->right->sample_count > data->params.min_sample_count);
+
+ // split ordered vars
+ // (only the valid-value counters are updated; values are fetched on demand)
+ for( vi = 0; vi < data->var_count; vi++ )
+ {
+ int ci = data->get_var_type(vi);
+ if (ci >= 0) continue;
+
+ int n1 = node->get_num_valid(vi), nr1 = 0;
+
+ float* values_buf = data->get_pred_float_buf();
+ const float* values = 0;
+ int* missing_buf = data->get_pred_int_buf();
+ const int* missing = 0;
+ data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
+
+ // bitwise & is intentional: both operands are 0/1
+ for( i = 0; i < n; i++ )
+ nr1 += (!missing[i] & dir[i]);
+ left->set_num_valid(vi, n1 - nr1);
+ right->set_num_valid(vi, nr1);
+ }
+ // split categorical vars, responses and cv_labels using new_idx relocation table
+ for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
+ {
+ int ci = data->get_var_type(vi);
+ if (ci < 0) continue;
+
+ int n1 = node->get_num_valid(vi), nr1 = 0;
+
+ int *src_lbls_buf = data->get_pred_int_buf();
+ const int* src_lbls = 0;
+ data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);
+
+ // snapshot the labels first: the destination area overlaps the source
+ for(i = 0; i < n; i++)
+ temp_buf[i] = src_lbls[i];
+
+ if (data->is_buf_16u)
+ {
+ // 16-bit buffer layout: 65535 marks a missing value
+ unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
+ ci*scount + left->offset);
+ unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
+ ci*scount + right->offset);
+
+ for( i = 0; i < n; i++ )
+ {
+ int d = dir[i];
+ int idx = temp_buf[i];
+ if (d)
+ {
+ *rdst = (unsigned short)idx;
+ rdst++;
+ nr1 += (idx != 65535);
+ }
+ else
+ {
+ *ldst = (unsigned short)idx;
+ ldst++;
+ }
+ }
+
+ // only predictor variables carry valid-value counters
+ if( vi < data->var_count )
+ {
+ left->set_num_valid(vi, n1 - nr1);
+ right->set_num_valid(vi, nr1);
+ }
+ }
+ else
+ {
+ // 32-bit buffer layout: negative values mark missing
+ int *ldst = buf->data.i + left->buf_idx*buf->cols +
+ ci*scount + left->offset;
+ int *rdst = buf->data.i + right->buf_idx*buf->cols +
+ ci*scount + right->offset;
+
+ for( i = 0; i < n; i++ )
+ {
+ int d = dir[i];
+ int idx = temp_buf[i];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ nr1 += (idx >= 0);
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+
+ }
+
+ if( vi < data->var_count )
+ {
+ left->set_num_valid(vi, n1 - nr1);
+ right->set_num_valid(vi, nr1);
+ }
+ }
+ }
+
+
+ // split sample indices
+ int *sample_idx_src_buf = data->get_sample_idx_buf();
+ const int* sample_idx_src = 0;
+ if (split_input_data)
+ {
+ data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);
+
+ // same overlap caveat as above: copy out before writing back
+ for(i = 0; i < n; i++)
+ temp_buf[i] = sample_idx_src[i];
+
+ // sample indices live right after the work variables in the buffer row
+ int pos = data->get_work_var_count();
+
+ if (data->is_buf_16u)
+ {
+ unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
+ pos*scount + left->offset);
+ unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +
+ pos*scount + right->offset);
+
+ for (i = 0; i < n; i++)
+ {
+ int d = dir[i];
+ unsigned short idx = (unsigned short)temp_buf[i];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+ }
+ }
+ else
+ {
+ int* ldst = buf->data.i + left->buf_idx*buf->cols +
+ pos*scount + left->offset;
+ int* rdst = buf->data.i + right->buf_idx*buf->cols +
+ pos*scount + right->offset;
+ for (i = 0; i < n; i++)
+ {
+ int d = dir[i];
+ int idx = temp_buf[i];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+ }
+ }
+ }
+
+ // deallocate the parent node data that is not needed anymore
+ data->free_node_data(node);
+}
+
+// Empty default constructor: all state is set up in train().
+CvERTrees::CvERTrees()
+{
+}
+
+// Empty destructor: cleanup is presumably handled by the base class
+// (CvRTrees) destructor, which owns the trees and training data.
+CvERTrees::~CvERTrees()
+{
+}
+
+bool CvERTrees::train( const CvMat* _train_data, int _tflag,
+                        const CvMat* _responses, const CvMat* _var_idx,
+                        const CvMat* _sample_idx, const CvMat* _var_type,
+                        const CvMat* _missing_mask, CvRTParams params )
+{
+    // Trains the Extremely Randomized Trees model on raw CvMat data:
+    // sets up the shared training data, clips/derives nactive_vars,
+    // initializes the active-variable mask and grows the forest.
+    // Returns true on success, false on failure.
+    bool result = false;
+
+    CV_FUNCNAME("CvERTrees::train");
+    // the file's convention (see the other train overloads) is
+    // "__BEGIN__;" / "__END__;" with trailing semicolons
+    __BEGIN__;
+    int var_count = 0;
+
+    clear();
+
+    // ERT never prunes, hence truncate_pruned_tree == false
+    CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
+        params.regression_accuracy, params.use_surrogates, params.max_categories,
+        params.cv_folds, params.use_1se_rule, false, params.priors );
+
+    data = new CvERTreeTrainData();
+    CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
+        _sample_idx, _var_type, _missing_mask, tree_params, true));
+
+    // nactive_vars: 0 -> sqrt(var_count) default, clipped to [1, var_count]
+    var_count = data->var_count;
+    if( params.nactive_vars > var_count )
+        params.nactive_vars = var_count;
+    else if( params.nactive_vars == 0 )
+        params.nactive_vars = (int)sqrt((double)var_count);
+    else if( params.nactive_vars < 0 )
+        CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
+
+    // Create mask of active variables at the tree nodes
+    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
+    if( params.calc_var_importance )
+    {
+        CV_CALL(var_importance = cvCreateMat( 1, var_count, CV_32FC1 ));
+        cvZero(var_importance);
+    }
+    { // initialize active variables mask: first nactive_vars entries 1, rest 0
+        CvMat submask1, submask2;
+        cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
+        cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
+        cvSet( &submask1, cvScalar(1) );
+        cvZero( &submask2 );
+    }
+
+    CV_CALL(result = grow_forest( params.term_crit ));
+
+    // NOTE(review): this overwrites grow_forest's return value; on a macro
+    // error CV_CALL jumps to __END__ and result stays false — confirm intended
+    result = true;
+
+    __END__;
+    return result;
+
+}
+
+// Convenience overload: trains from a CvMLData container by delegating to
+// CvRTrees::train, which dispatches back through this object's virtual
+// overrides. Returns true on success.
+bool CvERTrees::train( CvMLData* data, CvRTParams params)
+{
+ bool result = false;
+
+ CV_FUNCNAME( "CvERTrees::train" );
+
+ __BEGIN__;
+
+ CV_CALL( result = CvRTrees::train( data, params) );
+
+ __END__;
+
+ return result;
+}
+
+// Grows the forest tree by tree until term_crit is met (max tree count
+// and/or the error dropping below term_crit.epsilon), optionally tracking
+// the prediction error and per-variable importance along the way.
+// NOTE(review): sample_idx_for_tree is filled with 0..nsamples-1 but never
+// passed to tree->train, and the "oob" loops below visit every sample —
+// each ERT tree appears to train on the full sample set (no bagging), so
+// the "OOB" statistics actually cover all training samples. Confirm this
+// is intended for extremely randomized trees.
+bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
+{
+ bool result = false;
+
+ CvMat* sample_idx_for_tree = 0;
+
+ CV_FUNCNAME("CvERTrees::grow_forest");
+ __BEGIN__;
+
+ const int max_ntrees = term_crit.max_iter;
+ const double max_oob_err = term_crit.epsilon;
+
+ const int dims = data->var_count;
+ float maximal_response = 0;
+
+ CvMat* oob_sample_votes = 0;
+ CvMat* oob_responses = 0;
+
+ float* oob_samples_perm_ptr= 0;
+
+ float* samples_ptr = 0;
+ uchar* missing_ptr = 0;
+ float* true_resp_ptr = 0;
+ // error tracking / variable importance need the dense sample copies below
+ bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
+
+ // oob_predictions_sum[i] = sum of predicted values for the i-th sample
+ // oob_num_of_predictions[i] = number of summands
+ // (number of predictions for the i-th sample)
+ // initialize these variable to avoid warning C4701
+ CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
+ CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
+
+ nsamples = data->sample_count;
+ nclasses = data->get_num_classes();
+
+ if ( is_oob_or_vimportance )
+ {
+ if( data->is_classifier )
+ {
+ // per-sample vote histogram over the classes
+ CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
+ cvZero(oob_sample_votes);
+ }
+ else
+ {
+ // oob_responses[0,i] = oob_predictions_sum[i]
+ // = sum of predicted values for the i-th sample
+ // oob_responses[1,i] = oob_num_of_predictions[i]
+ // = number of summands (number of predictions for the i-th sample)
+ CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
+ cvZero(oob_responses);
+ cvGetRow( oob_responses, &oob_predictions_sum, 0 );
+ cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
+ }
+
+ // dense copies of the training set used for prediction and permutation
+ CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
+ CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
+ CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
+ CV_CALL(true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples ));
+
+ CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
+ {
+ // largest response magnitude, used to normalize regression residuals
+ double minval, maxval;
+ CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
+ cvMinMaxLoc( &responses, &minval, &maxval );
+ maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
+ }
+ }
+
+ trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
+ memset( trees, 0, sizeof(trees[0])*max_ntrees );
+
+ CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
+
+ for (int i = 0; i < nsamples; i++)
+ sample_idx_for_tree->data.i[i] = i;
+ ntrees = 0;
+ // main loop: add one tree per iteration until max_ntrees or early stop
+ while( ntrees < max_ntrees )
+ {
+ int i, oob_samples_count = 0;
+ double ncorrect_responses = 0; // used for estimation of variable importance
+ CvForestTree* tree = 0;
+
+ trees[ntrees] = new CvForestERTree();
+ tree = (CvForestERTree*)trees[ntrees];
+ CV_CALL(tree->train( data, 0, this ));
+
+ if ( is_oob_or_vimportance )
+ {
+ CvMat sample, missing;
+ // form array of OOB samples indices and get these samples
+ sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
+ missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
+
+ oob_error = 0;
+ for( i = 0; i < nsamples; i++,
+ sample.data.fl += dims, missing.data.ptr += dims )
+ {
+ CvDTreeNode* predicted_node = 0;
+
+ // predict oob samples
+ // (predicted_node is always 0 here so predict() always runs —
+ // presumably a leftover from a sampled-OOB version of this loop)
+ if( !predicted_node )
+ CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
+
+ if( !data->is_classifier ) //regression
+ {
+ double avg_resp, resp = predicted_node->value;
+ oob_predictions_sum.data.fl[i] += (float)resp;
+ oob_num_of_predictions.data.fl[i] += 1;
+
+ // compute oob error
+ avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
+ avg_resp -= true_resp_ptr[i];
+ oob_error += avg_resp*avg_resp;
+ resp = (resp - true_resp_ptr[i])/maximal_response;
+ ncorrect_responses += exp( -resp*resp );
+ }
+ else //classification
+ {
+ double prdct_resp;
+ CvPoint max_loc;
+ CvMat votes;
+
+ cvGetRow(oob_sample_votes, &votes, i);
+ votes.data.i[predicted_node->class_idx]++;
+
+ // compute oob error: majority vote across trees so far vs truth
+ cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
+
+ prdct_resp = data->cat_map->data.i[max_loc.x];
+ oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
+
+ ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
+ }
+ oob_samples_count++;
+ }
+ if( oob_samples_count > 0 )
+ oob_error /= (double)oob_samples_count;
+
+ // estimate variable importance: importance of variable m is the drop
+ // in accuracy when the m-th column is randomly permuted
+ if( var_importance && oob_samples_count > 0 )
+ {
+ int m;
+
+ memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
+ for( m = 0; m < dims; m++ )
+ {
+ double ncorrect_responses_permuted = 0;
+ // randomly permute values of the m-th variable in the oob samples
+ float* mth_var_ptr = oob_samples_perm_ptr + m;
+
+ for( i = 0; i < nsamples; i++ )
+ {
+ int i1, i2;
+ float temp;
+
+ i1 = cvRandInt( &rng ) % nsamples;
+ i2 = cvRandInt( &rng ) % nsamples;
+ CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
+
+ // turn values of (m-1)-th variable, that were permuted
+ // at the previous iteration, untouched
+ // NOTE(review): "m > 1" looks like an off-by-one — column 0
+ // is never restored after the m == 1 pass; confirm against
+ // the CvRTrees::grow_forest implementation
+ if( m > 1 )
+ oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
+ }
+
+ // predict "permuted" cases and calculate the number of votes for the
+ // correct class in the variable-m-permuted oob data
+ sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
+ missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
+ for( i = 0; i < nsamples; i++,
+ sample.data.fl += dims, missing.data.ptr += dims )
+ {
+ double predct_resp, true_resp;
+
+ predct_resp = tree->predict(&sample, &missing, true)->value;
+ true_resp = true_resp_ptr[i];
+ if( data->is_classifier )
+ ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
+ else
+ {
+ true_resp = (true_resp - predct_resp)/maximal_response;
+ ncorrect_responses_permuted += exp( -true_resp*true_resp );
+ }
+ }
+ var_importance->data.fl[m] += (float)(ncorrect_responses
+ - ncorrect_responses_permuted);
+ }
+ }
+ }
+ ntrees++;
+ // early stop once the error drops below epsilon (unless the criterion
+ // is iteration-count only)
+ if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
+ break;
+ }
+ if( var_importance )
+ {
+ // clamp negatives to zero and L1-normalize the importance vector
+ for ( int vi = 0; vi < var_importance->cols; vi++ )
+ var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
+ var_importance->data.fl[vi] : 0;
+ cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
+ }
+
+ result = true;
+
+ cvFree( &oob_samples_perm_ptr );
+ cvFree( &samples_ptr );
+ cvFree( &missing_ptr );
+ cvFree( &true_resp_ptr );
+
+ cvReleaseMat( &sample_idx_for_tree );
+
+ cvReleaseMat( &oob_sample_votes );
+ cvReleaseMat( &oob_responses );
+
+ __END__;
+
+ return result;
+}
+
+using namespace cv;
+
+bool CvERTrees::train( const Mat& _train_data, int _tflag,
+                      const Mat& _responses, const Mat& _var_idx,
+                      const Mat& _sample_idx, const Mat& _var_type,
+                      const Mat& _missing_mask, CvRTParams params )
+{
+    // C++ (cv::Mat) wrapper: build legacy CvMat headers over the same data
+    // (no copy) and forward to the CvMat-based train() overload.
+    CvMat train_mat = _train_data;
+    CvMat resp_mat  = _responses;
+    CvMat vidx_mat  = _var_idx;
+    CvMat sidx_mat  = _sample_idx;
+    CvMat vtype_mat = _var_type;
+    CvMat miss_mat  = _missing_mask;
+
+    // empty matrices (null data pointer) are passed down as null pointers
+    CvMat* vidx  = vidx_mat.data.ptr  ? &vidx_mat  : 0;
+    CvMat* sidx  = sidx_mat.data.ptr  ? &sidx_mat  : 0;
+    CvMat* vtype = vtype_mat.data.ptr ? &vtype_mat : 0;
+    CvMat* miss  = miss_mat.data.ptr  ? &miss_mat  : 0;
+
+    return train( &train_mat, _tflag, &resp_mat, vidx, sidx, vtype, miss, params );
+}
+
+// End of file.
+