Update to 2.0.0 tree from current Fremantle build
[opencv] / samples / c / tree_engine.cpp
1 #include "ml.h"
2 #include <stdio.h>
3 /*
4 The sample demonstrates how to use different decision trees.
5 */
6 void print_result(float train_err, float test_err, const CvMat* var_imp)
7 {
8     printf( "train error    %f\n", train_err );
9     printf( "test error    %f\n\n", test_err );
10        
11     if (var_imp)
12     {
13         bool is_flt = false;
14         if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1)
15             is_flt = true;
16         printf( "variable impotance\n" );
17         for( int i = 0; i < var_imp->cols; i++)
18         {
19             printf( "%d     %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] );
20         }
21     }
22     printf("\n");
23 }
24
25 int main()
26 {
27     const int train_sample_count = 300;
28
29 //#define LEPIOTA
30 #ifdef LEPIOTA
31     const char* filename = "../../../OpenCV/samples/c/agaricus-lepiota.data";
32 #else
33     const char* filename = "../../../OpenCV/samples/c/waveform.data";
34 #endif
35
36     CvDTree dtree;
37     CvBoost boost;
38     CvRTrees rtrees;
39     CvERTrees ertrees;
40
41     CvMLData data;
42
43     CvTrainTestSplit spl( train_sample_count );
44     
45     if ( data.read_csv( filename ) == 0)
46     {
47
48 #ifdef LEPIOTA
49         data.set_response_idx( 0 );     
50 #else
51         data.set_response_idx( 21 );     
52         data.change_var_type( 21, CV_VAR_CATEGORICAL );
53 #endif
54
55         data.set_train_test_split( &spl );
56         
57         printf("======DTREE=====\n");
58         dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
59         print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
60
61 #ifdef LEPIOTA
62         printf("======BOOST=====\n");
63         boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
64         print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data ), 0 );
65 #endif
66
67         printf("======RTREES=====\n");
68         rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
69         print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
70
71         printf("======ERTREES=====\n");
72         ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
73         print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
74     }
75     else
76         printf("File can not be read");
77
78     return 0;
79 }