Update to 2.0.0 tree from current Fremantle build
[opencv] / tests / ml / src / mltests.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////\r
2 //\r
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
4 //\r
5 //  By downloading, copying, installing or using the software you agree to this license.\r
6 //  If you do not agree to this license, do not download, install,\r
7 //  copy or use the software.\r
8 //\r
9 //\r
10 //                        Intel License Agreement\r
11 //                For Open Source Computer Vision Library\r
12 //\r
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.\r
14 // Third party copyrights are property of their respective owners.\r
15 //\r
16 // Redistribution and use in source and binary forms, with or without modification,\r
17 // are permitted provided that the following conditions are met:\r
18 //\r
19 //   * Redistribution's of source code must retain the above copyright notice,\r
20 //     this list of conditions and the following disclaimer.\r
21 //\r
22 //   * Redistribution's in binary form must reproduce the above copyright notice,\r
23 //     this list of conditions and the following disclaimer in the documentation\r
24 //     and/or other materials provided with the distribution.\r
25 //\r
26 //   * The name of Intel Corporation may not be used to endorse or promote products\r
27 //     derived from this software without specific prior written permission.\r
28 //\r
29 // This software is provided by the copyright holders and contributors "as is" and\r
30 // any express or implied warranties, including, but not limited to, the implied\r
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.\r
32 // In no event shall the Intel Corporation or contributors be liable for any direct,\r
33 // indirect, incidental, special, exemplary, or consequential damages\r
34 // (including, but not limited to, procurement of substitute goods or services;\r
35 // loss of use, data, or profits; or business interruption) however caused\r
36 // and on any theory of liability, whether in contract, strict liability,\r
37 // or tort (including negligence or otherwise) arising in any way out of\r
38 // the use of this software, even if advised of the possibility of such damage.\r
39 //\r
40 //M*/\r
41 \r
42 #include "mltest.h"\r
43 \r
44 // auxiliary functions\r
45 // 1. nbayes\r
46 void nbayes_check_data( CvMLData* _data )
47 {
48     if( _data->get_missing() )
49         CV_Error( CV_StsBadArg, "missing values are not supported" );
50     const CvMat* var_types = _data->get_var_types();
51     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
52     if( ( fabs( cvNorm( var_types, 0, CV_L1 ) - \r
53         (var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) ||\r
54         !is_classifier )\r
55         CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
56 }
57 bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
58 {
59     nbayes_check_data( _data );\r
60     const CvMat* values = _data->get_values();\r
61     const CvMat* responses = _data->get_responses();\r
62     const CvMat* train_sidx = _data->get_train_sample_idx();\r
63     const CvMat* var_idx = _data->get_var_idx();\r
64     return nbayes->train( values, responses, var_idx, train_sidx );\r
65 }
66 float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector<float> *resp )\r
67 {\r
68     float err = 0;\r
69     nbayes_check_data( _data );\r
70     const CvMat* values = _data->get_values();
71     const CvMat* response = _data->get_responses();
72     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
73     int* sidx = sample_idx ? sample_idx->data.i : 0;\r
74     int r_step = CV_IS_MAT_CONT(response->type) ?\r
75         1 : response->step / CV_ELEM_SIZE(response->type);\r
76     int sample_count = sample_idx ? sample_idx->cols : 0;\r
77     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;\r
78     float* pred_resp = 0;\r
79     if( resp && (sample_count > 0) )\r
80     {\r
81         resp->resize( sample_count );\r
82         pred_resp = &((*resp)[0]);\r
83     }\r
84 \r
85     for( int i = 0; i < sample_count; i++ )\r
86     {\r
87         CvMat sample;\r
88         int si = sidx ? sidx[i] : i;\r
89         cvGetRow( values, &sample, si ); \r
90         float r = (float)nbayes->predict( &sample, 0 );\r
91         if( pred_resp )\r
92             pred_resp[i] = r;\r
93         int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
94         err += d;\r
95     }\r
96     err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
97     return err;\r
98 }\r
99 \r
100 // 2. knearest\r
101 void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
102 {
103     const CvMat* values = _data->get_values();
104     const CvMat* var_idx = _data->get_var_idx();
105     if( var_idx->cols + var_idx->rows != values->cols )\r
106         CV_Error( CV_StsBadArg, "var_idx is not supported" );\r
107     if( _data->get_missing() )\r
108         CV_Error( CV_StsBadArg, "missing values are not supported" );\r
109     int resp_idx = _data->get_response_idx();\r
110     if( resp_idx == 0)\r
111         cvGetCols( values, _predictors, 1, values->cols );\r
112     else if( resp_idx == values->cols - 1 )\r
113         cvGetCols( values, _predictors, 0, values->cols - 1 );\r
114     else\r
115         CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" );\r
116 }
117 bool knearest_train( CvKNearest* knearest, CvMLData* _data )
118 {
119     const CvMat* responses = _data->get_responses();\r
120     const CvMat* train_sidx = _data->get_train_sample_idx();\r
121     bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;\r
122     CvMat predictors;\r
123     knearest_check_data_and_get_predictors( _data, &predictors );\r
124     return knearest->train( &predictors, responses, train_sidx, is_regression );
125 }
126 float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
127 {
128     float err = 0;\r
129     const CvMat* response = _data->get_responses();\r
130     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
131     int* sidx = sample_idx ? sample_idx->data.i : 0;\r
132     int r_step = CV_IS_MAT_CONT(response->type) ?\r
133         1 : response->step / CV_ELEM_SIZE(response->type);\r
134     bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;\r
135     CvMat predictors;\r
136     knearest_check_data_and_get_predictors( _data, &predictors );\r
137     int sample_count = sample_idx ? sample_idx->cols : 0;\r
138     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;\r
139     float* pred_resp = 0;\r
140     if( resp && (sample_count > 0) )\r
141     {\r
142         resp->resize( sample_count );\r
143         pred_resp = &((*resp)[0]);\r
144     }\r
145     if ( !is_regression )\r
146     {\r
147         for( int i = 0; i < sample_count; i++ )\r
148         {\r
149             CvMat sample;\r
150             int si = sidx ? sidx[i] : i;\r
151             cvGetRow( &predictors, &sample, si ); \r
152             float r = knearest->find_nearest( &sample, k );\r
153             if( pred_resp )\r
154                 pred_resp[i] = r;\r
155             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
156             err += d;\r
157         }\r
158         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
159     }\r
160     else\r
161     {\r
162         for( int i = 0; i < sample_count; i++ )\r
163         {\r
164             CvMat sample;\r
165             int si = sidx ? sidx[i] : i;\r
166             cvGetRow( &predictors, &sample, si ); \r
167             float r = knearest->find_nearest( &sample, k );\r
168             if( pred_resp )\r
169                 pred_resp[i] = r;\r
170             float d = r - response->data.fl[si*r_step];\r
171             err += d*d;\r
172         }\r
173         err = sample_count ? err / (float)sample_count : -FLT_MAX;    \r
174     }\r
175     return err;
176 }\r
177 \r
178 // 3. svm\r
179 int str_to_svm_type(string& str)\r
180 {\r
181     if( !str.compare("C_SVC") )\r
182         return CvSVM::C_SVC;\r
183     if( !str.compare("NU_SVC") )\r
184         return CvSVM::NU_SVC;\r
185     if( !str.compare("ONE_CLASS") )\r
186         return CvSVM::ONE_CLASS;\r
187     if( !str.compare("EPS_SVR") )\r
188         return CvSVM::EPS_SVR;\r
189     if( !str.compare("NU_SVR") )\r
190         return CvSVM::NU_SVR;\r
191     CV_Error( CV_StsBadArg, "incorrect svm type string" );\r
192     return -1;\r
193 }\r
194 int str_to_svm_kernel_type( string& str )\r
195 {\r
196     if( !str.compare("LINEAR") )\r
197         return CvSVM::LINEAR;\r
198     if( !str.compare("POLY") )\r
199         return CvSVM::POLY;\r
200     if( !str.compare("RBF") )\r
201         return CvSVM::RBF;\r
202     if( !str.compare("SIGMOID") )\r
203         return CvSVM::SIGMOID;\r
204     CV_Error( CV_StsBadArg, "incorrect svm type string" );\r
205     return -1;\r
206 }\r
207 void svm_check_data( CvMLData* _data )
208 {
209     if( _data->get_missing() )
210         CV_Error( CV_StsBadArg, "missing values are not supported" );
211     const CvMat* var_types = _data->get_var_types();
212     for( int i = 0; i < var_types->cols-1; i++ )
213         if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
214         {
215             char msg[50];
216             sprintf( msg, "incorrect type of %d-predictor", i );
217             CV_Error( CV_StsBadArg, msg );
218         }
219 }
220 bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
221 {
222     svm_check_data(_data);
223     const CvMat* _train_data = _data->get_values();
224     const CvMat* _responses = _data->get_responses();
225     const CvMat* _var_idx = _data->get_var_idx();
226     const CvMat* _sample_idx = _data->get_train_sample_idx();
227     return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
228 }
229 bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
230                     int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
231                     CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
232                     CvParamGrid degree_grid )
233 {
234     svm_check_data(_data);
235     const CvMat* _train_data = _data->get_values();
236     const CvMat* _responses = _data->get_responses();
237     const CvMat* _var_idx = _data->get_var_idx();
238     const CvMat* _sample_idx = _data->get_train_sample_idx();
239     return svm->train_auto( _train_data, _responses, _var_idx, 
240         _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
241 }
242 float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
243 {
244     svm_check_data(_data);
245     float err = 0;\r
246     const CvMat* values = _data->get_values();\r
247     const CvMat* response = _data->get_responses();\r
248     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
249     const CvMat* var_types = _data->get_var_types();\r
250     int* sidx = sample_idx ? sample_idx->data.i : 0;\r
251     int r_step = CV_IS_MAT_CONT(response->type) ?\r
252         1 : response->step / CV_ELEM_SIZE(response->type);\r
253     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;\r
254     int sample_count = sample_idx ? sample_idx->cols : 0;\r
255     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;\r
256     float* pred_resp = 0;\r
257     if( resp && (sample_count > 0) )\r
258     {\r
259         resp->resize( sample_count );\r
260         pred_resp = &((*resp)[0]);\r
261     }\r
262     if ( is_classifier )\r
263     {\r
264         for( int i = 0; i < sample_count; i++ )\r
265         {\r
266             CvMat sample;\r
267             int si = sidx ? sidx[i] : i;\r
268             cvGetRow( values, &sample, si ); \r
269             float r = svm->predict( &sample );\r
270             if( pred_resp )\r
271                 pred_resp[i] = r;\r
272             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
273             err += d;\r
274         }\r
275         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
276     }\r
277     else\r
278     {\r
279         for( int i = 0; i < sample_count; i++ )\r
280         {\r
281             CvMat sample;\r
282             int si = sidx ? sidx[i] : i;\r
283             cvGetRow( values, &sample, si );\r
284             float r = svm->predict( &sample );\r
285             if( pred_resp )\r
286                 pred_resp[i] = r;\r
287             float d = r - response->data.fl[si*r_step];\r
288             err += d*d;\r
289         }\r
290         err = sample_count ? err / (float)sample_count : -FLT_MAX;    \r
291     }\r
292     return err;
293 }\r
294 \r
295 // 4. em\r
296 // 5. ann\r
297 int str_to_ann_train_method( string& str )\r
298 {\r
299     if( !str.compare("BACKPROP") )\r
300         return CvANN_MLP_TrainParams::BACKPROP;\r
301     if( !str.compare("RPROP") )\r
302         return CvANN_MLP_TrainParams::RPROP;\r
303     CV_Error( CV_StsBadArg, "incorrect ann train method string" );\r
304     return -1;\r
305 }\r
306 void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
307 {
308     const CvMat* values = _data->get_values();
309     const CvMat* var_idx = _data->get_var_idx();
310     if( var_idx->cols + var_idx->rows != values->cols )\r
311         CV_Error( CV_StsBadArg, "var_idx is not supported" );\r
312     if( _data->get_missing() )\r
313         CV_Error( CV_StsBadArg, "missing values are not supported" );\r
314     int resp_idx = _data->get_response_idx();\r
315     if( resp_idx == 0)\r
316         cvGetCols( values, _inputs, 1, values->cols );\r
317     else if( resp_idx == values->cols - 1 )\r
318         cvGetCols( values, _inputs, 0, values->cols - 1 );\r
319     else\r
320         CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" );\r
321 }
322 void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
323 {
324     const CvMat* train_sidx = _data->get_train_sample_idx();
325     int* train_sidx_ptr = train_sidx->data.i;\r
326     const CvMat* responses = _data->get_responses();\r
327     float* responses_ptr = responses->data.fl;\r
328     int r_step = CV_IS_MAT_CONT(responses->type) ?\r
329         1 : responses->step / CV_ELEM_SIZE(responses->type);\r
330     int cls_count = 0;\r
331     // construct cls_map\r
332     cls_map.clear();\r
333     for( int si = 0; si < train_sidx->cols; si++ )\r
334     {\r
335         int sidx = train_sidx_ptr[si];\r
336         int r = cvRound(responses_ptr[sidx*r_step]);\r
337         CV_DbgAssert( fabs(responses_ptr[sidx*r_step]-r) < FLT_EPSILON );\r
338         int cls_map_size = (int)cls_map.size();\r
339         cls_map[r];\r
340         if ( (int)cls_map.size() > cls_map_size )\r
341             cls_map[r] = cls_count++;\r
342     }\r
343     new_responses.create( responses->rows, cls_count, CV_32F );\r
344     new_responses.setTo( 0 );\r
345     for( int si = 0; si < train_sidx->cols; si++ )\r
346     {\r
347         int sidx = train_sidx_ptr[si];\r
348         int r = cvRound(responses_ptr[sidx*r_step]);\r
349         int cidx = cls_map[r];\r
350         new_responses.ptr<float>(sidx)[cidx] = 1;\r
351     }
352 }
353 int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
354 {
355     const CvMat* train_sidx = _data->get_train_sample_idx();\r
356     CvMat predictors;\r
357     ann_check_data_and_get_predictors( _data, &predictors );\r
358     CvMat _new_responses = CvMat( new_responses );\r
359     return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
360 }
361 float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
362 {
363     float err = 0;\r
364     const CvMat* responses = _data->get_responses();\r
365     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
366     int* sidx = sample_idx ? sample_idx->data.i : 0;\r
367     int r_step = CV_IS_MAT_CONT(responses->type) ?\r
368         1 : responses->step / CV_ELEM_SIZE(responses->type);\r
369     CvMat predictors;\r
370     ann_check_data_and_get_predictors( _data, &predictors );\r
371     int sample_count = sample_idx ? sample_idx->cols : 0;\r
372     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;\r
373     float* pred_resp = 0;\r
374     vector<float> innresp;\r
375     if( sample_count > 0 )\r
376     {\r
377         if( resp_labels )\r
378         {\r
379             resp_labels->resize( sample_count );\r
380             pred_resp = &((*resp_labels)[0]);\r
381         }\r
382         else\r
383         {\r
384             innresp.resize( sample_count );\r
385             pred_resp = &(innresp[0]);\r
386         }\r
387     }\r
388     int cls_count = (int)cls_map.size();\r
389     Mat output( 1, cls_count, CV_32FC1 );\r
390     CvMat _output = CvMat(output);\r
391     map<int, int>::iterator b_it = cls_map.begin();\r
392     for( int i = 0; i < sample_count; i++ )\r
393     {\r
394         CvMat sample;\r
395         int si = sidx ? sidx[i] : i;\r
396         cvGetRow( &predictors, &sample, si ); \r
397         ann->predict( &sample, &_output );\r
398         CvPoint best_cls = {0,0};\r
399         cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
400         int r = cvRound(responses->data.fl[si*r_step]);
401         CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
402         r = cls_map[r];
403         int d = best_cls.x == r ? 0 : 1;
404         err += d;\r
405         pred_resp[i] = (float)best_cls.x;\r
406     }\r
407     err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
408     return err;
409 }\r
410 \r
411 // 6. dtree\r
412 // 7. boost\r
413 int str_to_boost_type( string& str )\r
414 {\r
415     if ( !str.compare("DISCRETE") )\r
416         return CvBoost::DISCRETE;\r
417     if ( !str.compare("REAL") )\r
418         return CvBoost::REAL;    \r
419     if ( !str.compare("LOGIT") )\r
420         return CvBoost::LOGIT;\r
421     if ( !str.compare("GENTLE") )\r
422         return CvBoost::GENTLE;\r
423     CV_Error( CV_StsBadArg, "incorrect boost type string" );\r
424     return -1;\r
425 }\r
426 \r
427 // 8. rtrees\r
428 // 9. ertrees\r
429 \r
430 // ---------------------------------- MLBaseTest ---------------------------------------------------\r
431 \r
432 CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) :
433 CvTest( _testName, _testFuncs )
434 {
435     modelName = _modelName;
436     nbayes = 0;\r
437     knearest = 0;\r
438     svm = 0;\r
439     em = 0;\r
440     ann = 0;\r
441     dtree = 0;\r
442     boost = 0;\r
443     rtrees = 0;\r
444     ertrees = 0;
445     if( !modelName.compare(CV_NBAYES) )
446         nbayes = new CvNormalBayesClassifier;
447     else if( !modelName.compare(CV_KNEAREST) )
448         knearest = new CvKNearest;
449     else if( !modelName.compare(CV_SVM) )
450         svm = new CvSVM;
451     else if( !modelName.compare(CV_EM) )
452         em = new CvEM;
453     else if( !modelName.compare(CV_ANN) )
454         ann = new CvANN_MLP;
455     else if( !modelName.compare(CV_DTREE) )
456         dtree = new CvDTree;
457     else if( !modelName.compare(CV_BOOST) )
458         boost = new CvBoost;
459     else if( !modelName.compare(CV_RTREES) )
460         rtrees = new CvRTrees;
461     else if( !modelName.compare(CV_ERTREES) )
462         ertrees = new CvERTrees;
463 }
464 \r
465 int CV_MLBaseTest::init( CvTS* system )\r
466 {\r
467     clear();
468     ts = system;
469
470     string filename = ts->get_data_path();
471     filename += get_validation_filename();
472     validationFS.open( filename, FileStorage::READ );
473     return read_params( *validationFS );\r
474 }\r
475 \r
476 CV_MLBaseTest::~CV_MLBaseTest()\r
477 {\r
478     if( validationFS.isOpened() )\r
479         validationFS.release();\r
480     if( nbayes )\r
481         delete nbayes;\r
482     if( knearest ) \r
483         delete knearest;\r
484     if( svm )\r
485         delete svm;\r
486     if( em )\r
487         delete em;\r
488     if( ann )\r
489         delete ann;\r
490     if( dtree )\r
491         delete dtree;\r
492     if( boost )\r
493         delete boost;\r
494     if( rtrees )\r
495         delete rtrees;\r
496     if( ertrees )\r
497         delete ertrees;\r
498 }\r
499 \r
500 int CV_MLBaseTest::read_params( CvFileStorage* _fs )\r
501 {\r
502     if( !_fs )\r
503         test_case_count = -1;\r
504     else\r
505     {\r
506         CvFileNode* fn = cvGetRootFileNode( _fs, 0 );\r
507         fn = (CvFileNode*)cvGetSeqElem( fn->data.seq, 0 );\r
508         fn = cvGetFileNodeByName( _fs, fn, "run_params" );\r
509         CvSeq* dataSetNamesSeq = cvGetFileNodeByName( _fs, fn, modelName.c_str() )->data.seq;\r
510         test_case_count = dataSetNamesSeq ? dataSetNamesSeq->total : -1;\r
511         if( test_case_count > 0 )\r
512         {\r
513             dataSetNames.resize( test_case_count );\r
514             vector<string>::iterator it = dataSetNames.begin();\r
515             for( int i = 0; i < test_case_count; i++, it++ )\r
516                 *it = ((CvFileNode*)cvGetSeqElem( dataSetNamesSeq, i ))->data.str.ptr;\r
517         }\r
518     }\r
519     return CvTS::OK;;\r
520 }\r
521 \r
522 void CV_MLBaseTest::run( int start_from )
523 {
524     int code = CvTS::OK;
525     start_from = 0;
526     for (int i = 0; i < test_case_count; i++)
527     {
528         int temp_code = run_test_case( i );
529         if (temp_code == CvTS::OK)
530             temp_code = validate_test_results( i );
531         if (temp_code != CvTS::OK)
532             code = temp_code;
533     }
534     if ( test_case_count <= 0)
535     {\r
536         ts->printf( CvTS::LOG, "validation file is not determined or not correct" );
537         code = CvTS::FAIL_INVALID_TEST_DATA;\r
538     }
539     ts->set_failed_test_info( code );
540 }
541
542 int CV_MLBaseTest::prepare_test_case( int test_case_idx )
543 {
544     int trainSampleCount, respIdx;
545     string varTypes;
546     clear();
547 \r
548     string dataPath = ts->get_data_path();
549     if ( dataPath.empty() )\r
550     {\r
551         ts->printf( CvTS::LOG, "data path is empty" );
552         return CvTS::FAIL_INVALID_TEST_DATA;\r
553     }
554
555     string dataName = dataSetNames[test_case_idx],
556         filename = dataPath + dataName + ".data";
557     if ( data.read_csv( filename.c_str() ) != 0)
558     {\r
559         char msg[100];\r
560         sprintf( msg, "file %s can not be read", filename.c_str() );\r
561         ts->printf( CvTS::LOG, msg );
562         return CvTS::FAIL_INVALID_TEST_DATA;\r
563     }
564
565     FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
566     CV_DbgAssert( !dataParamsNode.empty() );
567
568     CV_DbgAssert( !dataParamsNode["LS"].empty() );
569     dataParamsNode["LS"] >> trainSampleCount;
570     CvTrainTestSplit spl( trainSampleCount );
571     data.set_train_test_split( &spl );
572
573     CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
574     dataParamsNode["resp_idx"] >> respIdx;
575     data.set_response_idx( respIdx );
576
577     CV_DbgAssert( !dataParamsNode["types"].empty() );
578     dataParamsNode["types"] >> varTypes;
579     data.set_var_types( varTypes.c_str() );
580
581     return CvTS::OK;
582 }
583
584 string& CV_MLBaseTest::get_validation_filename()
585 {
586     return validationFN;
587 }
588
589 int CV_MLBaseTest::train( int testCaseIdx )\r
590 {\r
591     bool is_trained = false;\r
592     FileNode modelParamsNode = 
593         validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
594 \r
595     if( !modelName.compare(CV_NBAYES) )
596         is_trained = nbayes_train( nbayes, &data );
597     else if( !modelName.compare(CV_KNEAREST) )
598     {
599         assert( 0 );
600         //is_trained = knearest->train( &data );
601     }
602     else if( !modelName.compare(CV_SVM) )
603     {
604         string svm_type_str, kernel_type_str;
605         modelParamsNode["svm_type"] >> svm_type_str;
606         modelParamsNode["kernel_type"] >> kernel_type_str;
607         CvSVMParams params;
608         params.svm_type = str_to_svm_type( svm_type_str );
609         params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
610         modelParamsNode["degree"] >> params.degree;
611         modelParamsNode["gamma"] >> params.gamma;
612         modelParamsNode["coef0"] >> params.coef0;
613         modelParamsNode["C"] >> params.C;
614         modelParamsNode["nu"] >> params.nu;
615         modelParamsNode["p"] >> params.p;
616         is_trained = svm_train( svm, &data, params );
617     }
618     else if( !modelName.compare(CV_EM) )
619     {
620         assert( 0 );
621     }
622     else if( !modelName.compare(CV_ANN) )
623     {
624         string train_method_str;
625         double param1, param2;
626         modelParamsNode["train_method"] >> train_method_str;
627         modelParamsNode["param1"] >> param1;
628         modelParamsNode["param2"] >> param2;
629         Mat new_responses;
630         ann_get_new_responses( &data, new_responses, cls_map );
631         int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
632         CvMat layer_sizes =
633             cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
634         ann->create( &layer_sizes );
635         is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
636             str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
637     }
638     else if( !modelName.compare(CV_DTREE) )
639     {
640         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;\r
641         float REG_ACCURACY = 0;\r
642         bool USE_SURROGATE, IS_PRUNED;
643         modelParamsNode["max_depth"] >> MAX_DEPTH;
644         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
645         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
646         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
647         modelParamsNode["cv_folds"] >> CV_FOLDS;
648         modelParamsNode["is_pruned"] >> IS_PRUNED;
649         is_trained = dtree->train( &data, \r
650             CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,\r
651             MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0;\r
652     }
653     else if( !modelName.compare(CV_BOOST) )
654     {
655         int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;\r
656         float WEIGHT_TRIM_RATE;\r
657         bool USE_SURROGATE;
658         string typeStr;
659         modelParamsNode["type"] >> typeStr;
660         BOOST_TYPE = str_to_boost_type( typeStr );
661         modelParamsNode["weak_count"] >> WEAK_COUNT;\r
662         modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;\r
663         modelParamsNode["max_depth"] >> MAX_DEPTH;\r
664         modelParamsNode["use_surrogate"] >> USE_SURROGATE;\r
665         is_trained = boost->train( &data,\r
666             CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
667     }
668     else if( !modelName.compare(CV_RTREES) )
669     {
670         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;\r
671         float REG_ACCURACY = 0, OOB_EPS = 0.0;\r
672         bool USE_SURROGATE, IS_PRUNED;
673         modelParamsNode["max_depth"] >> MAX_DEPTH;
674         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
675         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
676         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
677         modelParamsNode["cv_folds"] >> CV_FOLDS;
678         modelParamsNode["is_pruned"] >> IS_PRUNED;
679         modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
680         modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
681         is_trained = rtrees->train( &data, CvRTParams(  MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,\r
682             USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance\r
683             NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
684     }
685     else if( !modelName.compare(CV_ERTREES) )
686     {\r
687         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;\r
688         float REG_ACCURACY = 0, OOB_EPS = 0.0;\r
689         bool USE_SURROGATE, IS_PRUNED;\r
690         modelParamsNode["max_depth"] >> MAX_DEPTH;
691         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
692         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
693         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
694         modelParamsNode["cv_folds"] >> CV_FOLDS;
695         modelParamsNode["is_pruned"] >> IS_PRUNED;
696         modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
697         modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
698         is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,\r
699             USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance\r
700             NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;\r
701     }\r
702 \r
703     if( !is_trained )\r
704     {\r
705         ts->printf( CvTS::LOG, "in test case %d model training was failed", testCaseIdx );\r
706         return CvTS::FAIL_INVALID_OUTPUT;\r
707     }\r
708     return CvTS::OK;\r
709 }\r
710 \r
711 float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector<float> *resp )\r
712 {\r
713     float err = 0;\r
714     if( !modelName.compare(CV_NBAYES) )
715         err = nbayes_calc_error( nbayes, &data, type, resp );
716     else if( !modelName.compare(CV_KNEAREST) )
717     {
718         assert( 0 );
719         testCaseIdx = 0;
720         /*int k = 2;
721         validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
722         err = knearest->calc_error( &data, k, type, resp );*/
723     }
724     else if( !modelName.compare(CV_SVM) )
725         err = svm_calc_error( svm, &data, type, resp );
726     else if( !modelName.compare(CV_EM) )
727         assert( 0 );
728     else if( !modelName.compare(CV_ANN) )
729         err = ann_calc_error( ann, &data, cls_map, type, resp );
730     else if( !modelName.compare(CV_DTREE) )
731         err = dtree->calc_error( &data, type, resp );
732     else if( !modelName.compare(CV_BOOST) )
733         err = boost->calc_error( &data, type, resp );
734     else if( !modelName.compare(CV_RTREES) )
735         err = rtrees->calc_error( &data, type, resp );
736     else if( !modelName.compare(CV_ERTREES) )
737         err = ertrees->calc_error( &data, type, resp );\r
738     return err;\r
739 }\r
740 \r
741 void CV_MLBaseTest::save( const char* filename )\r
742 {\r
743     if( !modelName.compare(CV_NBAYES) )
744         nbayes->save( filename );
745     else if( !modelName.compare(CV_KNEAREST) )
746         knearest->save( filename );
747     else if( !modelName.compare(CV_SVM) )
748         svm->save( filename );
749     else if( !modelName.compare(CV_EM) )
750         em->save( filename );
751     else if( !modelName.compare(CV_ANN) )
752         ann->save( filename );
753     else if( !modelName.compare(CV_DTREE) )
754         dtree->save( filename );
755     else if( !modelName.compare(CV_BOOST) )
756         boost->save( filename );
757     else if( !modelName.compare(CV_RTREES) )
758         rtrees->save( filename );
759     else if( !modelName.compare(CV_ERTREES) )
760         ertrees->save( filename );\r
761 }\r
762 \r
763 void CV_MLBaseTest::load( const char* filename )\r
764 {\r
765     if( !modelName.compare(CV_NBAYES) )
766         nbayes->load( filename );
767     else if( !modelName.compare(CV_KNEAREST) )
768         knearest->load( filename );
769     else if( !modelName.compare(CV_SVM) )
770         svm->load( filename );
771     else if( !modelName.compare(CV_EM) )
772         em->load( filename );
773     else if( !modelName.compare(CV_ANN) )
774         ann->load( filename );
775     else if( !modelName.compare(CV_DTREE) )
776         dtree->load( filename );
777     else if( !modelName.compare(CV_BOOST) )
778         boost->load( filename );
779     else if( !modelName.compare(CV_RTREES) )
780         rtrees->load( filename );
781     else if( !modelName.compare(CV_ERTREES) )
782         ertrees->load( filename );\r
783 }\r
784 \r
785 /* End of file. */\r