1 /*M///////////////////////////////////////////////////////////////////////////////////////
\r
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
\r
5 // By downloading, copying, installing or using the software you agree to this license.
\r
6 // If you do not agree to this license, do not download, install,
\r
7 // copy or use the software.
\r
10 // Intel License Agreement
\r
11 // For Open Source Computer Vision Library
\r
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
\r
14 // Third party copyrights are property of their respective owners.
\r
16 // Redistribution and use in source and binary forms, with or without modification,
\r
17 // are permitted provided that the following conditions are met:
\r
19 // * Redistribution's of source code must retain the above copyright notice,
\r
20 // this list of conditions and the following disclaimer.
\r
22 // * Redistribution's in binary form must reproduce the above copyright notice,
\r
23 // this list of conditions and the following disclaimer in the documentation
\r
24 // and/or other materials provided with the distribution.
\r
26 // * The name of Intel Corporation may not be used to endorse or promote products
\r
27 // derived from this software without specific prior written permission.
\r
29 // This software is provided by the copyright holders and contributors "as is" and
\r
30 // any express or implied warranties, including, but not limited to, the implied
\r
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
\r
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
\r
33 // indirect, incidental, special, exemplary, or consequential damages
\r
34 // (including, but not limited to, procurement of substitute goods or services;
\r
35 // loss of use, data, or profits; or business interruption) however caused
\r
36 // and on any theory of liability, whether in contract, strict liability,
\r
37 // or tort (including negligence or otherwise) arising in any way out of
\r
38 // the use of this software, even if advised of the possibility of such damage.
\r
44 // auxiliary functions
\r
46 void nbayes_check_data( CvMLData* _data )
48 if( _data->get_missing() )
49 CV_Error( CV_StsBadArg, "missing values are not supported" );
50 const CvMat* var_types = _data->get_var_types();
51 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
52 if( ( fabs( cvNorm( var_types, 0, CV_L1 ) -
\r
53 (var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) ||
\r
55 CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
57 bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
59 nbayes_check_data( _data );
\r
60 const CvMat* values = _data->get_values();
\r
61 const CvMat* responses = _data->get_responses();
\r
62 const CvMat* train_sidx = _data->get_train_sample_idx();
\r
63 const CvMat* var_idx = _data->get_var_idx();
\r
64 return nbayes->train( values, responses, var_idx, train_sidx );
\r
66 float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector<float> *resp )
\r
69 nbayes_check_data( _data );
\r
70 const CvMat* values = _data->get_values();
71 const CvMat* response = _data->get_responses();
72 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
\r
73 int* sidx = sample_idx ? sample_idx->data.i : 0;
\r
74 int r_step = CV_IS_MAT_CONT(response->type) ?
\r
75 1 : response->step / CV_ELEM_SIZE(response->type);
\r
76 int sample_count = sample_idx ? sample_idx->cols : 0;
\r
77 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
\r
78 float* pred_resp = 0;
\r
79 if( resp && (sample_count > 0) )
\r
81 resp->resize( sample_count );
\r
82 pred_resp = &((*resp)[0]);
\r
85 for( int i = 0; i < sample_count; i++ )
\r
88 int si = sidx ? sidx[i] : i;
\r
89 cvGetRow( values, &sample, si );
\r
90 float r = (float)nbayes->predict( &sample, 0 );
\r
93 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
\r
96 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
\r
101 void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
103 const CvMat* values = _data->get_values();
104 const CvMat* var_idx = _data->get_var_idx();
105 if( var_idx->cols + var_idx->rows != values->cols )
\r
106 CV_Error( CV_StsBadArg, "var_idx is not supported" );
\r
107 if( _data->get_missing() )
\r
108 CV_Error( CV_StsBadArg, "missing values are not supported" );
\r
109 int resp_idx = _data->get_response_idx();
\r
111 cvGetCols( values, _predictors, 1, values->cols );
\r
112 else if( resp_idx == values->cols - 1 )
\r
113 cvGetCols( values, _predictors, 0, values->cols - 1 );
\r
115 CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" );
\r
117 bool knearest_train( CvKNearest* knearest, CvMLData* _data )
119 const CvMat* responses = _data->get_responses();
\r
120 const CvMat* train_sidx = _data->get_train_sample_idx();
\r
121 bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
\r
123 knearest_check_data_and_get_predictors( _data, &predictors );
\r
124 return knearest->train( &predictors, responses, train_sidx, is_regression );
126 float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
129 const CvMat* response = _data->get_responses();
\r
130 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
\r
131 int* sidx = sample_idx ? sample_idx->data.i : 0;
\r
132 int r_step = CV_IS_MAT_CONT(response->type) ?
\r
133 1 : response->step / CV_ELEM_SIZE(response->type);
\r
134 bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
\r
136 knearest_check_data_and_get_predictors( _data, &predictors );
\r
137 int sample_count = sample_idx ? sample_idx->cols : 0;
\r
138 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;
\r
139 float* pred_resp = 0;
\r
140 if( resp && (sample_count > 0) )
\r
142 resp->resize( sample_count );
\r
143 pred_resp = &((*resp)[0]);
\r
145 if ( !is_regression )
\r
147 for( int i = 0; i < sample_count; i++ )
\r
150 int si = sidx ? sidx[i] : i;
\r
151 cvGetRow( &predictors, &sample, si );
\r
152 float r = knearest->find_nearest( &sample, k );
\r
155 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
\r
158 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
\r
162 for( int i = 0; i < sample_count; i++ )
\r
165 int si = sidx ? sidx[i] : i;
\r
166 cvGetRow( &predictors, &sample, si );
\r
167 float r = knearest->find_nearest( &sample, k );
\r
170 float d = r - response->data.fl[si*r_step];
\r
173 err = sample_count ? err / (float)sample_count : -FLT_MAX;
\r
179 int str_to_svm_type(string& str)
\r
181 if( !str.compare("C_SVC") )
\r
182 return CvSVM::C_SVC;
\r
183 if( !str.compare("NU_SVC") )
\r
184 return CvSVM::NU_SVC;
\r
185 if( !str.compare("ONE_CLASS") )
\r
186 return CvSVM::ONE_CLASS;
\r
187 if( !str.compare("EPS_SVR") )
\r
188 return CvSVM::EPS_SVR;
\r
189 if( !str.compare("NU_SVR") )
\r
190 return CvSVM::NU_SVR;
\r
191 CV_Error( CV_StsBadArg, "incorrect svm type string" );
\r
194 int str_to_svm_kernel_type( string& str )
\r
196 if( !str.compare("LINEAR") )
\r
197 return CvSVM::LINEAR;
\r
198 if( !str.compare("POLY") )
\r
199 return CvSVM::POLY;
\r
200 if( !str.compare("RBF") )
\r
202 if( !str.compare("SIGMOID") )
\r
203 return CvSVM::SIGMOID;
\r
204 CV_Error( CV_StsBadArg, "incorrect svm type string" );
\r
207 void svm_check_data( CvMLData* _data )
209 if( _data->get_missing() )
210 CV_Error( CV_StsBadArg, "missing values are not supported" );
211 const CvMat* var_types = _data->get_var_types();
212 for( int i = 0; i < var_types->cols-1; i++ )
213 if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
216 sprintf( msg, "incorrect type of %d-predictor", i );
217 CV_Error( CV_StsBadArg, msg );
220 bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
222 svm_check_data(_data);
223 const CvMat* _train_data = _data->get_values();
224 const CvMat* _responses = _data->get_responses();
225 const CvMat* _var_idx = _data->get_var_idx();
226 const CvMat* _sample_idx = _data->get_train_sample_idx();
227 return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
229 bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
230 int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
231 CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
232 CvParamGrid degree_grid )
234 svm_check_data(_data);
235 const CvMat* _train_data = _data->get_values();
236 const CvMat* _responses = _data->get_responses();
237 const CvMat* _var_idx = _data->get_var_idx();
238 const CvMat* _sample_idx = _data->get_train_sample_idx();
239 return svm->train_auto( _train_data, _responses, _var_idx,
240 _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
242 float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
244 svm_check_data(_data);
246 const CvMat* values = _data->get_values();
\r
247 const CvMat* response = _data->get_responses();
\r
248 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
\r
249 const CvMat* var_types = _data->get_var_types();
\r
250 int* sidx = sample_idx ? sample_idx->data.i : 0;
\r
251 int r_step = CV_IS_MAT_CONT(response->type) ?
\r
252 1 : response->step / CV_ELEM_SIZE(response->type);
\r
253 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
\r
254 int sample_count = sample_idx ? sample_idx->cols : 0;
\r
255 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
\r
256 float* pred_resp = 0;
\r
257 if( resp && (sample_count > 0) )
\r
259 resp->resize( sample_count );
\r
260 pred_resp = &((*resp)[0]);
\r
262 if ( is_classifier )
\r
264 for( int i = 0; i < sample_count; i++ )
\r
267 int si = sidx ? sidx[i] : i;
\r
268 cvGetRow( values, &sample, si );
\r
269 float r = svm->predict( &sample );
\r
272 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
\r
275 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
\r
279 for( int i = 0; i < sample_count; i++ )
\r
282 int si = sidx ? sidx[i] : i;
\r
283 cvGetRow( values, &sample, si );
\r
284 float r = svm->predict( &sample );
\r
287 float d = r - response->data.fl[si*r_step];
\r
290 err = sample_count ? err / (float)sample_count : -FLT_MAX;
\r
297 int str_to_ann_train_method( string& str )
\r
299 if( !str.compare("BACKPROP") )
\r
300 return CvANN_MLP_TrainParams::BACKPROP;
\r
301 if( !str.compare("RPROP") )
\r
302 return CvANN_MLP_TrainParams::RPROP;
\r
303 CV_Error( CV_StsBadArg, "incorrect ann train method string" );
\r
306 void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
308 const CvMat* values = _data->get_values();
309 const CvMat* var_idx = _data->get_var_idx();
310 if( var_idx->cols + var_idx->rows != values->cols )
\r
311 CV_Error( CV_StsBadArg, "var_idx is not supported" );
\r
312 if( _data->get_missing() )
\r
313 CV_Error( CV_StsBadArg, "missing values are not supported" );
\r
314 int resp_idx = _data->get_response_idx();
\r
316 cvGetCols( values, _inputs, 1, values->cols );
\r
317 else if( resp_idx == values->cols - 1 )
\r
318 cvGetCols( values, _inputs, 0, values->cols - 1 );
\r
320 CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" );
\r
322 void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
324 const CvMat* train_sidx = _data->get_train_sample_idx();
325 int* train_sidx_ptr = train_sidx->data.i;
\r
326 const CvMat* responses = _data->get_responses();
\r
327 float* responses_ptr = responses->data.fl;
\r
328 int r_step = CV_IS_MAT_CONT(responses->type) ?
\r
329 1 : responses->step / CV_ELEM_SIZE(responses->type);
\r
331 // construct cls_map
\r
333 for( int si = 0; si < train_sidx->cols; si++ )
\r
335 int sidx = train_sidx_ptr[si];
\r
336 int r = cvRound(responses_ptr[sidx*r_step]);
\r
337 CV_DbgAssert( fabs(responses_ptr[sidx*r_step]-r) < FLT_EPSILON );
\r
338 int cls_map_size = (int)cls_map.size();
\r
340 if ( (int)cls_map.size() > cls_map_size )
\r
341 cls_map[r] = cls_count++;
\r
343 new_responses.create( responses->rows, cls_count, CV_32F );
\r
344 new_responses.setTo( 0 );
\r
345 for( int si = 0; si < train_sidx->cols; si++ )
\r
347 int sidx = train_sidx_ptr[si];
\r
348 int r = cvRound(responses_ptr[sidx*r_step]);
\r
349 int cidx = cls_map[r];
\r
350 new_responses.ptr<float>(sidx)[cidx] = 1;
\r
353 int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
355 const CvMat* train_sidx = _data->get_train_sample_idx();
\r
357 ann_check_data_and_get_predictors( _data, &predictors );
\r
358 CvMat _new_responses = CvMat( new_responses );
\r
359 return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
361 float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
364 const CvMat* responses = _data->get_responses();
\r
365 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
\r
366 int* sidx = sample_idx ? sample_idx->data.i : 0;
\r
367 int r_step = CV_IS_MAT_CONT(responses->type) ?
\r
368 1 : responses->step / CV_ELEM_SIZE(responses->type);
\r
370 ann_check_data_and_get_predictors( _data, &predictors );
\r
371 int sample_count = sample_idx ? sample_idx->cols : 0;
\r
372 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;
\r
373 float* pred_resp = 0;
\r
374 vector<float> innresp;
\r
375 if( sample_count > 0 )
\r
379 resp_labels->resize( sample_count );
\r
380 pred_resp = &((*resp_labels)[0]);
\r
384 innresp.resize( sample_count );
\r
385 pred_resp = &(innresp[0]);
\r
388 int cls_count = (int)cls_map.size();
\r
389 Mat output( 1, cls_count, CV_32FC1 );
\r
390 CvMat _output = CvMat(output);
\r
391 map<int, int>::iterator b_it = cls_map.begin();
\r
392 for( int i = 0; i < sample_count; i++ )
\r
395 int si = sidx ? sidx[i] : i;
\r
396 cvGetRow( &predictors, &sample, si );
\r
397 ann->predict( &sample, &_output );
\r
398 CvPoint best_cls = {0,0};
\r
399 cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
400 int r = cvRound(responses->data.fl[si*r_step]);
401 CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
403 int d = best_cls.x == r ? 0 : 1;
405 pred_resp[i] = (float)best_cls.x;
\r
407 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
\r
413 int str_to_boost_type( string& str )
\r
415 if ( !str.compare("DISCRETE") )
\r
416 return CvBoost::DISCRETE;
\r
417 if ( !str.compare("REAL") )
\r
418 return CvBoost::REAL;
\r
419 if ( !str.compare("LOGIT") )
\r
420 return CvBoost::LOGIT;
\r
421 if ( !str.compare("GENTLE") )
\r
422 return CvBoost::GENTLE;
\r
423 CV_Error( CV_StsBadArg, "incorrect boost type string" );
\r
430 // ---------------------------------- MLBaseTest ---------------------------------------------------
\r
432 CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) :
433 CvTest( _testName, _testFuncs )
435 modelName = _modelName;
445 if( !modelName.compare(CV_NBAYES) )
446 nbayes = new CvNormalBayesClassifier;
447 else if( !modelName.compare(CV_KNEAREST) )
448 knearest = new CvKNearest;
449 else if( !modelName.compare(CV_SVM) )
451 else if( !modelName.compare(CV_EM) )
453 else if( !modelName.compare(CV_ANN) )
455 else if( !modelName.compare(CV_DTREE) )
457 else if( !modelName.compare(CV_BOOST) )
459 else if( !modelName.compare(CV_RTREES) )
460 rtrees = new CvRTrees;
461 else if( !modelName.compare(CV_ERTREES) )
462 ertrees = new CvERTrees;
465 int CV_MLBaseTest::init( CvTS* system )
\r
470 string filename = ts->get_data_path();
471 filename += get_validation_filename();
472 validationFS.open( filename, FileStorage::READ );
473 return read_params( *validationFS );
\r
476 CV_MLBaseTest::~CV_MLBaseTest()
\r
478 if( validationFS.isOpened() )
\r
479 validationFS.release();
\r
500 int CV_MLBaseTest::read_params( CvFileStorage* _fs )
\r
503 test_case_count = -1;
\r
506 CvFileNode* fn = cvGetRootFileNode( _fs, 0 );
\r
507 fn = (CvFileNode*)cvGetSeqElem( fn->data.seq, 0 );
\r
508 fn = cvGetFileNodeByName( _fs, fn, "run_params" );
\r
509 CvSeq* dataSetNamesSeq = cvGetFileNodeByName( _fs, fn, modelName.c_str() )->data.seq;
\r
510 test_case_count = dataSetNamesSeq ? dataSetNamesSeq->total : -1;
\r
511 if( test_case_count > 0 )
\r
513 dataSetNames.resize( test_case_count );
\r
514 vector<string>::iterator it = dataSetNames.begin();
\r
515 for( int i = 0; i < test_case_count; i++, it++ )
\r
516 *it = ((CvFileNode*)cvGetSeqElem( dataSetNamesSeq, i ))->data.str.ptr;
\r
522 void CV_MLBaseTest::run( int start_from )
526 for (int i = 0; i < test_case_count; i++)
528 int temp_code = run_test_case( i );
529 if (temp_code == CvTS::OK)
530 temp_code = validate_test_results( i );
531 if (temp_code != CvTS::OK)
534 if ( test_case_count <= 0)
536 ts->printf( CvTS::LOG, "validation file is not determined or not correct" );
537 code = CvTS::FAIL_INVALID_TEST_DATA;
\r
539 ts->set_failed_test_info( code );
542 int CV_MLBaseTest::prepare_test_case( int test_case_idx )
544 int trainSampleCount, respIdx;
548 string dataPath = ts->get_data_path();
549 if ( dataPath.empty() )
\r
551 ts->printf( CvTS::LOG, "data path is empty" );
552 return CvTS::FAIL_INVALID_TEST_DATA;
\r
555 string dataName = dataSetNames[test_case_idx],
556 filename = dataPath + dataName + ".data";
557 if ( data.read_csv( filename.c_str() ) != 0)
560 sprintf( msg, "file %s can not be read", filename.c_str() );
\r
561 ts->printf( CvTS::LOG, msg );
562 return CvTS::FAIL_INVALID_TEST_DATA;
\r
565 FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
566 CV_DbgAssert( !dataParamsNode.empty() );
568 CV_DbgAssert( !dataParamsNode["LS"].empty() );
569 dataParamsNode["LS"] >> trainSampleCount;
570 CvTrainTestSplit spl( trainSampleCount );
571 data.set_train_test_split( &spl );
573 CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
574 dataParamsNode["resp_idx"] >> respIdx;
575 data.set_response_idx( respIdx );
577 CV_DbgAssert( !dataParamsNode["types"].empty() );
578 dataParamsNode["types"] >> varTypes;
579 data.set_var_types( varTypes.c_str() );
// Returns, by reference, the name of the validation file; init() appends it
// to the test-system data path to locate the file to open.
// NOTE(review): only the signature of this method is visible here - its body
// is missing from this view and must be restored from the full source.
584 string& CV_MLBaseTest::get_validation_filename()
589 int CV_MLBaseTest::train( int testCaseIdx )
\r
591 bool is_trained = false;
\r
592 FileNode modelParamsNode =
593 validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
595 if( !modelName.compare(CV_NBAYES) )
596 is_trained = nbayes_train( nbayes, &data );
597 else if( !modelName.compare(CV_KNEAREST) )
600 //is_trained = knearest->train( &data );
602 else if( !modelName.compare(CV_SVM) )
604 string svm_type_str, kernel_type_str;
605 modelParamsNode["svm_type"] >> svm_type_str;
606 modelParamsNode["kernel_type"] >> kernel_type_str;
608 params.svm_type = str_to_svm_type( svm_type_str );
609 params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
610 modelParamsNode["degree"] >> params.degree;
611 modelParamsNode["gamma"] >> params.gamma;
612 modelParamsNode["coef0"] >> params.coef0;
613 modelParamsNode["C"] >> params.C;
614 modelParamsNode["nu"] >> params.nu;
615 modelParamsNode["p"] >> params.p;
616 is_trained = svm_train( svm, &data, params );
618 else if( !modelName.compare(CV_EM) )
622 else if( !modelName.compare(CV_ANN) )
624 string train_method_str;
625 double param1, param2;
626 modelParamsNode["train_method"] >> train_method_str;
627 modelParamsNode["param1"] >> param1;
628 modelParamsNode["param2"] >> param2;
630 ann_get_new_responses( &data, new_responses, cls_map );
631 int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
633 cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
634 ann->create( &layer_sizes );
635 is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
636 str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
638 else if( !modelName.compare(CV_DTREE) )
640 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
\r
641 float REG_ACCURACY = 0;
\r
642 bool USE_SURROGATE, IS_PRUNED;
643 modelParamsNode["max_depth"] >> MAX_DEPTH;
644 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
645 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
646 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
647 modelParamsNode["cv_folds"] >> CV_FOLDS;
648 modelParamsNode["is_pruned"] >> IS_PRUNED;
649 is_trained = dtree->train( &data,
\r
650 CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,
\r
651 MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0;
\r
653 else if( !modelName.compare(CV_BOOST) )
655 int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
\r
656 float WEIGHT_TRIM_RATE;
\r
659 modelParamsNode["type"] >> typeStr;
660 BOOST_TYPE = str_to_boost_type( typeStr );
661 modelParamsNode["weak_count"] >> WEAK_COUNT;
\r
662 modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
\r
663 modelParamsNode["max_depth"] >> MAX_DEPTH;
\r
664 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
\r
665 is_trained = boost->train( &data,
\r
666 CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
668 else if( !modelName.compare(CV_RTREES) )
670 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
\r
671 float REG_ACCURACY = 0, OOB_EPS = 0.0;
\r
672 bool USE_SURROGATE, IS_PRUNED;
673 modelParamsNode["max_depth"] >> MAX_DEPTH;
674 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
675 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
676 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
677 modelParamsNode["cv_folds"] >> CV_FOLDS;
678 modelParamsNode["is_pruned"] >> IS_PRUNED;
679 modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
680 modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
681 is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
\r
682 USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance
\r
683 NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
685 else if( !modelName.compare(CV_ERTREES) )
687 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
\r
688 float REG_ACCURACY = 0, OOB_EPS = 0.0;
\r
689 bool USE_SURROGATE, IS_PRUNED;
\r
690 modelParamsNode["max_depth"] >> MAX_DEPTH;
691 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
692 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
693 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
694 modelParamsNode["cv_folds"] >> CV_FOLDS;
695 modelParamsNode["is_pruned"] >> IS_PRUNED;
696 modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
697 modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
698 is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
\r
699 USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance
\r
700 NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
\r
705 ts->printf( CvTS::LOG, "in test case %d model training was failed", testCaseIdx );
\r
706 return CvTS::FAIL_INVALID_OUTPUT;
\r
711 float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector<float> *resp )
\r
714 if( !modelName.compare(CV_NBAYES) )
715 err = nbayes_calc_error( nbayes, &data, type, resp );
716 else if( !modelName.compare(CV_KNEAREST) )
721 validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
722 err = knearest->calc_error( &data, k, type, resp );*/
724 else if( !modelName.compare(CV_SVM) )
725 err = svm_calc_error( svm, &data, type, resp );
726 else if( !modelName.compare(CV_EM) )
728 else if( !modelName.compare(CV_ANN) )
729 err = ann_calc_error( ann, &data, cls_map, type, resp );
730 else if( !modelName.compare(CV_DTREE) )
731 err = dtree->calc_error( &data, type, resp );
732 else if( !modelName.compare(CV_BOOST) )
733 err = boost->calc_error( &data, type, resp );
734 else if( !modelName.compare(CV_RTREES) )
735 err = rtrees->calc_error( &data, type, resp );
736 else if( !modelName.compare(CV_ERTREES) )
737 err = ertrees->calc_error( &data, type, resp );
\r
741 void CV_MLBaseTest::save( const char* filename )
\r
743 if( !modelName.compare(CV_NBAYES) )
744 nbayes->save( filename );
745 else if( !modelName.compare(CV_KNEAREST) )
746 knearest->save( filename );
747 else if( !modelName.compare(CV_SVM) )
748 svm->save( filename );
749 else if( !modelName.compare(CV_EM) )
750 em->save( filename );
751 else if( !modelName.compare(CV_ANN) )
752 ann->save( filename );
753 else if( !modelName.compare(CV_DTREE) )
754 dtree->save( filename );
755 else if( !modelName.compare(CV_BOOST) )
756 boost->save( filename );
757 else if( !modelName.compare(CV_RTREES) )
758 rtrees->save( filename );
759 else if( !modelName.compare(CV_ERTREES) )
760 ertrees->save( filename );
\r
763 void CV_MLBaseTest::load( const char* filename )
\r
765 if( !modelName.compare(CV_NBAYES) )
766 nbayes->load( filename );
767 else if( !modelName.compare(CV_KNEAREST) )
768 knearest->load( filename );
769 else if( !modelName.compare(CV_SVM) )
770 svm->load( filename );
771 else if( !modelName.compare(CV_EM) )
772 em->load( filename );
773 else if( !modelName.compare(CV_ANN) )
774 ann->load( filename );
775 else if( !modelName.compare(CV_DTREE) )
776 dtree->load( filename );
777 else if( !modelName.compare(CV_BOOST) )
778 boost->load( filename );
779 else if( !modelName.compare(CV_RTREES) )
780 rtrees->load( filename );
781 else if( !modelName.compare(CV_ERTREES) )
782 ertrees->load( filename );
\r