Update to 2.0.0 tree from current Fremantle build
[opencv] / samples / c / tree_engine.cpp
diff --git a/samples/c/tree_engine.cpp b/samples/c/tree_engine.cpp
new file mode 100644 (file)
index 0000000..fc93b1a
--- /dev/null
@@ -0,0 +1,79 @@
+#include "ml.h"
+#include <stdio.h>
+/*
+The sample demonstrates how to use different decision trees.
+*/
+void print_result(float train_err, float test_err, const CvMat* var_imp)
+{
+    printf( "train error    %f\n", train_err );
+    printf( "test error    %f\n\n", test_err );
+       
+    if (var_imp)
+    {
+        bool is_flt = false;
+        if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1)
+            is_flt = true;
+        printf( "variable impotance\n" );
+        for( int i = 0; i < var_imp->cols; i++)
+        {
+            printf( "%d     %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] );
+        }
+    }
+    printf("\n");
+}
+
+int main()
+{
+    const int train_sample_count = 300;
+
+//#define LEPIOTA
+#ifdef LEPIOTA
+    const char* filename = "../../../OpenCV/samples/c/agaricus-lepiota.data";
+#else
+    const char* filename = "../../../OpenCV/samples/c/waveform.data";
+#endif
+
+    CvDTree dtree;
+    CvBoost boost;
+    CvRTrees rtrees;
+    CvERTrees ertrees;
+
+    CvMLData data;
+
+    CvTrainTestSplit spl( train_sample_count );
+    
+    if ( data.read_csv( filename ) == 0)
+    {
+
+#ifdef LEPIOTA
+        data.set_response_idx( 0 );     
+#else
+        data.set_response_idx( 21 );     
+        data.change_var_type( 21, CV_VAR_CATEGORICAL );
+#endif
+
+        data.set_train_test_split( &spl );
+        
+        printf("======DTREE=====\n");
+        dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
+        print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
+
+#ifdef LEPIOTA
+        printf("======BOOST=====\n");
+        boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
+        print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data ), 0 );
+#endif
+
+        printf("======RTREES=====\n");
+        rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
+        print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
+
+        printf("======ERTREES=====\n");
+        ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
+        print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
+    }
+    else
+        printf("File can not be read");
+
+    return 0;
+}
\ No newline at end of file