X-Git-Url: http://vcs.maemo.org/git/?a=blobdiff_plain;f=samples%2Fc%2Ftree_engine.cpp;fp=samples%2Fc%2Ftree_engine.cpp;h=fc93b1a663c253ed919f0933be654247c11fe7f3;hb=e4c14cdbdf2fe805e79cd96ded236f57e7b89060;hp=0000000000000000000000000000000000000000;hpb=454138ff8a20f6edb9b65a910101403d8b520643;p=opencv diff --git a/samples/c/tree_engine.cpp b/samples/c/tree_engine.cpp new file mode 100644 index 0000000..fc93b1a --- /dev/null +++ b/samples/c/tree_engine.cpp @@ -0,0 +1,79 @@ +#include "ml.h" +#include +/* +The sample demonstrates how to use different decision trees. +*/ +void print_result(float train_err, float test_err, const CvMat* var_imp) +{ + printf( "train error %f\n", train_err ); + printf( "test error %f\n\n", test_err ); + + if (var_imp) + { + bool is_flt = false; + if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1) + is_flt = true; + printf( "variable impotance\n" ); + for( int i = 0; i < var_imp->cols; i++) + { + printf( "%d %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] ); + } + } + printf("\n"); +} + +int main() +{ + const int train_sample_count = 300; + +//#define LEPIOTA +#ifdef LEPIOTA + const char* filename = "../../../OpenCV/samples/c/agaricus-lepiota.data"; +#else + const char* filename = "../../../OpenCV/samples/c/waveform.data"; +#endif + + CvDTree dtree; + CvBoost boost; + CvRTrees rtrees; + CvERTrees ertrees; + + CvMLData data; + + CvTrainTestSplit spl( train_sample_count ); + + if ( data.read_csv( filename ) == 0) + { + +#ifdef LEPIOTA + data.set_response_idx( 0 ); +#else + data.set_response_idx( 21 ); + data.change_var_type( 21, CV_VAR_CATEGORICAL ); +#endif + + data.set_train_test_split( &spl ); + + printf("======DTREE=====\n"); + dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 )); + print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() ); + +#ifdef LEPIOTA + printf("======BOOST=====\n"); + boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0)); + print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data ), 0 ); +#endif + + printf("======RTREES=====\n"); + rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER )); + print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() ); + + printf("======ERTREES=====\n"); + ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER )); + print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() ); + } + else + printf("File can not be read"); + + return 0; +} \ No newline at end of file