1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
44 // Disable the deprecation warning (C4996) which appears in Visual Studio 8.0
46 #pragma warning( disable : 4996 )
54 #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
58 #else // SKIP_INCLUDES
// Calling-convention macros: on Windows targets they map to the MSVC keywords.
60 #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
61 #define CV_CDECL __cdecl
62 #define CV_STDCALL __stdcall
// CV_EXTERN_C gives C linkage; CV_DEFAULT(val) has two alternative definitions:
// one that expands to a default argument ("= val") and one that expands to nothing.
70 #define CV_EXTERN_C extern "C"
71 #define CV_DEFAULT(val) = val
74 #define CV_DEFAULT(val)
// CV_EXTERN_C_FUNCPTR(x) declares a typedef, either wrapped in extern "C" or plain.
78 #ifndef CV_EXTERN_C_FUNCPTR
80 #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
82 #define CV_EXTERN_C_FUNCPTR(x) typedef x
// CV_INLINE: "inline" for C++, "__inline" for non-GNU Windows compilers,
// and "static" in the remaining alternative.
87 #if defined __cplusplus
88 #define CV_INLINE inline
89 #elif (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && !defined __GNUC__
90 #define CV_INLINE __inline
92 #define CV_INLINE static
94 #endif /* CV_INLINE */
// CV_EXPORTS: symbols are exported from the DLL when building on Windows with
// CVAPI_EXPORTS defined.
96 #if (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && defined CVAPI_EXPORTS
97 #define CV_EXPORTS __declspec(dllexport)
// CVAPI(rettype) combines C linkage, export attribute, return type and cdecl convention.
103 #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
106 #endif // SKIP_INCLUDES
111 // Apple defines a check() macro somewhere in the debug headers
112 // that interferes with a method definition in this header
117 /****************************************************************************************\
118 * Main struct definitions *
119 \****************************************************************************************/
/* natural logarithm of 2*pi (ln(2*pi) ~= 1.8379) */
122 #define CV_LOG2PI (1.8378770664093454835606594728112)
124 /* columns of <trainData> matrix are training samples */
125 #define CV_COL_SAMPLE 0
127 /* rows of <trainData> matrix are training samples */
128 #define CV_ROW_SAMPLE 1
/* non-zero when the CV_ROW_SAMPLE bit is set in <flags>, zero otherwise */
130 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
146 /* A structure, representing the lattice range of statmodel parameters.
147 It is used for optimizing statmodel parameters by cross-validation method.
148 The lattice is logarithmic, so <step> must be greater than 1. */
149 typedef struct CvParamLattice
/* Fills a CvParamLattice from a value range and a logarithmic step.
   The two bounds are stored in ascending order whatever order they are passed in,
   and the step is raised to at least 1.
   NOTE(review): the struct comment asks for a step greater than 1, yet a step
   clamped to exactly 1 is accepted here - confirm this is intended. */
157 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
161 pl.min_val = MIN( min_val, max_val );
162 pl.max_val = MAX( min_val, max_val );
163 pl.step = MAX( log_step, 1. );
/* Takes no arguments; works on a lattice whose three fields are all initialized
   to zero (presumably this is the value returned - TODO confirm). */
167 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
169 CvParamLattice pl = {0,0,0};
/* Variable kinds. CV_VAR_NUMERICAL and CV_VAR_ORDERED are aliases (both are 0). */
175 #define CV_VAR_NUMERICAL 0
176 #define CV_VAR_ORDERED 0
177 #define CV_VAR_CATEGORICAL 1
/* Type-name strings, one per model family. */
179 #define CV_TYPE_NAME_ML_SVM "opencv-ml-svm"
180 #define CV_TYPE_NAME_ML_KNN "opencv-ml-knn"
181 #define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian"
182 #define CV_TYPE_NAME_ML_EM "opencv-ml-em"
183 #define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree"
184 #define CV_TYPE_NAME_ML_TREE "opencv-ml-tree"
185 #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
186 #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
187 #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
/* Values for the <type> argument of the calc_error() methods. */
189 #define CV_TRAIN_ERROR 0
190 #define CV_TEST_ERROR 1
/* Common base class of the model classes declared in this header.
   Declares the shared clear / save / load / write / read interface. */
192 class CV_EXPORTS CvStatModel
196 virtual ~CvStatModel();
// Re-declared by the derived model classes.
198 virtual void clear();
// File-name based persistence; <name> is optional (defaults to 0).
200 virtual void save( const char* filename, const char* name=0 ) const;
201 virtual void load( const char* filename, const char* name=0 );
// Persistence through a caller-supplied CvFileStorage (and file node for reading).
203 virtual void write( CvFileStorage* storage, const char* name ) const;
204 virtual void read( CvFileStorage* storage, CvFileNode* node );
// presumably the name used when <name> is omitted in save()/load() - TODO confirm
207 const char* default_model_name;
210 /****************************************************************************************\
211 * Normal Bayes Classifier *
212 \****************************************************************************************/
214 /* The structure, representing the grid range of statmodel parameters.
215 It is used for optimizing statmodel accuracy by varying model parameters,
216 the accuracy estimate being computed by cross-validation.
217 The grid is logarithmic, so <step> must be greater than 1. */
// Parameter grid described by min_val, max_val and step.
221 struct CV_EXPORTS CvParamGrid
// Parameter identifiers; the numeric values coincide with the
// { C, GAMMA, P, NU, COEF, DEGREE } enum declared in CvSVM.
224 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
// All three fields are reset to zero here.
228 min_val = max_val = step = 0;
// Construction from explicit bounds and a logarithmic step.
231 CvParamGrid( double _min_val, double _max_val, double log_step )
// Disabled constructor, kept commented out.
237 //CvParamGrid( int param_id );
/* Normal Bayes classifier model derived from CvStatModel.
   Exposes two parallel interfaces: one taking CvMat pointers and one taking
   cv::Mat references. */
245 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
248 CvNormalBayesClassifier();
249 virtual ~CvNormalBayesClassifier();
// CvMat interface; _var_idx and _sample_idx are optional (default 0).
251 CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
252 const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
// <update> defaults to false.
254 virtual bool train( const CvMat* _train_data, const CvMat* _responses,
255 const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
// <results> is an optional output (default 0).
257 virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
258 virtual void clear();
// cv::Mat interface mirroring the declarations above; an empty cv::Mat()
// replaces the null-pointer defaults.
261 CvNormalBayesClassifier( const cv::Mat& _train_data, const cv::Mat& _responses,
262 const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat() );
263 virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
264 const cv::Mat& _var_idx = cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
266 virtual float predict( const cv::Mat& _samples, cv::Mat* results=0 ) const;
// Storage-based persistence, re-declared from CvStatModel.
269 virtual void write( CvFileStorage* storage, const char* name ) const;
270 virtual void read( CvFileStorage* storage, CvFileNode* node );
// Model state: variable counts and arrays of matrices
// (presumably one matrix per class - TODO confirm).
273 int var_count, var_all;
280 CvMat** inv_eigen_values;
281 CvMat** cov_rotate_mats;
286 /****************************************************************************************\
287 * K-Nearest Neighbour Classifier *
288 \****************************************************************************************/
290 // k Nearest Neighbors
/* K-nearest-neighbours model derived from CvStatModel.
   Declares a CvMat-based and a cv::Mat-based interface with matching defaults
   (classification by default, max_k = 32). */
291 class CV_EXPORTS CvKNearest : public CvStatModel
296 virtual ~CvKNearest();
// CvMat interface.
298 CvKNearest( const CvMat* _train_data, const CvMat* _responses,
299 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
// <_update_base> defaults to false.
301 virtual bool train( const CvMat* _train_data, const CvMat* _responses,
302 const CvMat* _sample_idx=0, bool is_regression=false,
303 int _max_k=32, bool _update_base=false );
// All output arguments after <k> are optional (default 0).
305 virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
306 const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
// cv::Mat interface mirroring the declarations above.
309 CvKNearest( const cv::Mat& _train_data, const cv::Mat& _responses,
310 const cv::Mat& _sample_idx=cv::Mat(), bool _is_regression=false, int max_k=32 );
312 virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
313 const cv::Mat& _sample_idx=cv::Mat(), bool is_regression=false,
314 int _max_k=32, bool _update_base=false );
316 virtual float find_nearest( const cv::Mat& _samples, int k, cv::Mat* results=0,
317 const float** neighbors=0,
318 cv::Mat* neighbor_responses=0,
319 cv::Mat* dist=0 ) const;
// Reset and read-only accessors.
322 virtual void clear();
323 int get_max_k() const;
324 int get_var_count() const;
325 int get_sample_count() const;
326 bool is_regression() const;
// Helper routines operating on the sample range [start, end)
// (presumably used by find_nearest - TODO confirm).
330 virtual float write_results( int k, int k1, int start, int end,
331 const float* neighbor_responses, const float* dist, CvMat* _results,
332 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
334 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
335 float* neighbor_responses, const float** neighbors, float* dist ) const;
338 int max_k, var_count;
344 /****************************************************************************************\
345 * Support Vector Machines *
346 \****************************************************************************************/
348 // SVM training parameters
// Parameter bundle for SVM training: SVM type, kernel type, kernel
// coefficients, regularization constants, class weights and stop criteria.
349 struct CV_EXPORTS CvSVMParams
// Constructor setting every parameter explicitly.
352 CvSVMParams( int _svm_type, int _kernel_type,
353 double _degree, double _gamma, double _coef0,
354 double _C, double _nu, double _p,
355 CvMat* _class_weights, CvTermCriteria _term_crit );
// Kernel coefficients.
359 double degree; // for poly
360 double gamma; // for poly/rbf/sigmoid
361 double coef0; // for poly/sigmoid
// Formulation-specific constants.
363 double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
364 double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
365 double p; // for CV_SVM_EPS_SVR
366 CvMat* class_weights; // for CV_SVM_C_SVC
367 CvTermCriteria term_crit; // termination criteria
// SVM kernel object. The concrete kernel is selected through a
// pointer-to-member of type Calc; calc_linear / calc_rbf / calc_poly /
// calc_sigmoid all share that signature.
371 struct CV_EXPORTS CvSVMKernel
// Kernel routine signature: takes vec_count vectors of vec_size elements
// (vecs), one further vector (another) and an output buffer (results).
373 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
374 const float* another, float* results );
// Construction / (re)initialization from parameters and a kernel routine.
376 CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
377 virtual bool create( const CvSVMParams* _params, Calc _calc_func );
378 virtual ~CvSVMKernel();
380 virtual void clear();
// Generic entry point (presumably dispatches to the stored Calc routine - TODO confirm).
381 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
// Pointer to the parameters; they are const as seen from the kernel.
383 const CvSVMParams* params;
// Shared helper for the non-RBF kernels, parameterized by alpha and beta.
386 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
387 const float* another, float* results,
388 double alpha, double beta );
// Kernel routines matching the Calc signature.
390 virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
391 const float* another, float* results );
392 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
393 const float* another, float* results );
394 virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
395 const float* another, float* results );
396 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
397 const float* another, float* results );
// Node of a doubly linked list (prev/next links); CvSVMSolver keeps one
// instance as <lru_list> and a pointer to such nodes as <rows>.
401 struct CvSVMKernelRow
403 CvSVMKernelRow* prev;
404 CvSVMKernelRow* next;
// Solver result record, passed by reference to the CvSVMSolver::solve_* methods.
409 struct CvSVMSolutionInfo
// Two upper bounds (presumably for the positive / negative class - TODO confirm).
413 double upper_bound_p;
414 double upper_bound_n;
415 double r; // for Solver_NU
// Solver for the SVM training problems. The problem-specific steps
// (working-set selection, kernel-row retrieval, rho computation) are plugged
// in through pointers to member functions.
418 class CV_EXPORTS CvSVMSolver
// Signatures of the pluggable steps.
421 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
422 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
423 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
// Construction / (re)initialization; both take the same argument list.
427 CvSVMSolver( int count, int var_count, const float** samples, schar* y,
428 int alpha_count, double* alpha, double Cp, double Cn,
429 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
430 SelectWorkingSet select_working_set, CalcRho calc_rho );
431 virtual bool create( int count, int var_count, const float** samples, schar* y,
432 int alpha_count, double* alpha, double Cp, double Cn,
433 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
434 SelectWorkingSet select_working_set, CalcRho calc_rho );
435 virtual ~CvSVMSolver();
437 virtual void clear();
// Generic solve step; fills <si>.
438 virtual bool solve_generic( CvSVMSolutionInfo& si );
// One entry point per SVM formulation. The classification variants take
// schar labels, the regression variants take float responses, and
// solve_one_class takes no labels at all.
440 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
441 double Cp, double Cn, CvMemStorage* storage,
442 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
443 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
444 CvMemStorage* storage, CvSVMKernel* kernel,
445 double* alpha, CvSVMSolutionInfo& si );
446 virtual bool solve_one_class( int count, int var_count, const float** samples,
447 CvMemStorage* storage, CvSVMKernel* kernel,
448 double* alpha, CvSVMSolutionInfo& si );
450 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
451 CvMemStorage* storage, CvSVMKernel* kernel,
452 double* alpha, CvSVMSolutionInfo& si );
454 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
455 CvMemStorage* storage, CvSVMKernel* kernel,
456 double* alpha, CvSVMSolutionInfo& si );
// Kernel-row access.
458 virtual float* get_row_base( int i, bool* _existed );
459 virtual float* get_row( int i, float* dst );
// Solver state.
465 const float** samples;
466 const CvSVMParams* params;
467 CvMemStorage* storage;
// List head and row array (the name suggests a least-recently-used row
// cache - TODO confirm).
468 CvSVMKernelRow lru_list;
469 CvSVMKernelRow* rows;
476 // -1 - lower bound, 0 - free, 1 - upper bound
484 double C[2]; // C[0] == Cn, C[1] == Cp
// Currently selected pluggable steps.
487 SelectWorkingSet select_working_set_func;
488 CalcRho calc_rho_func;
// Implementations matching the SelectWorkingSet and CalcRho signatures.
491 virtual bool select_working_set( int& i, int& j );
492 virtual bool select_working_set_nu_svm( int& i, int& j );
493 virtual void calc_rho( double& rho, double& r );
494 virtual void calc_rho_nu_svm( double& rho, double& r );
// Implementations matching the GetRow signature.
496 virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
497 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
498 virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
502 struct CvSVMDecisionFunc
// Support vector machine model derived from CvStatModel. Declares a
// CvMat-based and a cv::Mat-based interface for construction, training,
// automatic parameter selection (train_auto) and prediction.
512 class CV_EXPORTS CvSVM : public CvStatModel
// SVM formulations.
516 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
// Kernel types.
519 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
// Parameter identifiers accepted by get_default_grid().
522 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
// CvMat interface. Parameters default to a default-constructed CvSVMParams.
527 CvSVM( const CvMat* _train_data, const CvMat* _responses,
528 const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
529 CvSVMParams _params=CvSVMParams() );
531 virtual bool train( const CvMat* _train_data, const CvMat* _responses,
532 const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
533 CvSVMParams _params=CvSVMParams() );
// Unlike train(), the index and parameter arguments have no defaults here;
// each grid defaults to get_default_grid() for its parameter id.
535 virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
536 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
538 CvParamGrid C_grid = get_default_grid(CvSVM::C),
539 CvParamGrid gamma_grid = get_default_grid(CvSVM::GAMMA),
540 CvParamGrid p_grid = get_default_grid(CvSVM::P),
541 CvParamGrid nu_grid = get_default_grid(CvSVM::NU),
542 CvParamGrid coef_grid = get_default_grid(CvSVM::COEF),
543 CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
// <returnDFVal> defaults to false (presumably selects returning the
// decision-function value - TODO confirm).
545 virtual float predict( const CvMat* _sample, bool returnDFVal=false ) const;
// cv::Mat interface mirroring the declarations above.
548 CvSVM( const cv::Mat& _train_data, const cv::Mat& _responses,
549 const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
550 CvSVMParams _params=CvSVMParams() );
552 virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
553 const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
554 CvSVMParams _params=CvSVMParams() );
556 virtual bool train_auto( const cv::Mat& _train_data, const cv::Mat& _responses,
557 const cv::Mat& _var_idx, const cv::Mat& _sample_idx, CvSVMParams _params,
559 CvParamGrid C_grid = get_default_grid(CvSVM::C),
560 CvParamGrid gamma_grid = get_default_grid(CvSVM::GAMMA),
561 CvParamGrid p_grid = get_default_grid(CvSVM::P),
562 CvParamGrid nu_grid = get_default_grid(CvSVM::NU),
563 CvParamGrid coef_grid = get_default_grid(CvSVM::COEF),
564 CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
565 virtual float predict( const cv::Mat& _sample, bool returnDFVal=false ) const;
// Accessors. get_params() returns a copy of the stored parameters.
568 virtual int get_support_vector_count() const;
569 virtual const float* get_support_vector(int i) const;
570 virtual CvSVMParams get_params() const { return params; };
571 virtual void clear();
// Default grid for one of the parameter ids of the { C, GAMMA, ... } enum.
573 static CvParamGrid get_default_grid( int param_id );
// Storage-based persistence, re-declared from CvStatModel.
575 virtual void write( CvFileStorage* storage, const char* name ) const;
576 virtual void read( CvFileStorage* storage, CvFileNode* node );
// Column count of var_idx when it is set, var_all otherwise.
577 int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
// Training internals.
581 virtual bool set_params( const CvSVMParams& _params );
582 virtual bool train1( int sample_count, int var_count, const float** samples,
583 const void* _responses, double Cp, double Cn,
584 CvMemStorage* _storage, double* alpha, double& rho );
585 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
586 const CvMat* _responses, CvMemStorage* _storage, double* alpha );
587 virtual void create_kernel();
588 virtual void create_solver();
// Parameter persistence helpers.
590 virtual void write_params( CvFileStorage* fs ) const;
591 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
// Model state.
599 CvMat* class_weights;
600 CvSVMDecisionFunc* decision_func;
601 CvMemStorage* storage;
607 /****************************************************************************************\
608 * Expectation - Maximization *
609 \****************************************************************************************/
// Parameter bundle for CvEM training.
611 struct CV_EXPORTS CvEMParams
// Defaults: 10 clusters, diagonal covariance matrices, automatic start step,
// no initial probs/weights/means/covs, and termination after 100 iterations
// or at FLT_EPSILON accuracy.
613 CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
614 start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
616 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
// Explicit constructor; every argument after <_nclusters> repeats the
// default used by the constructor above.
619 CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
620 int _start_step=0/*CvEM::START_AUTO_STEP*/,
621 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
622 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
623 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
624 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
631 const CvMat* weights;
634 CvTermCriteria term_crit;
// Expectation-Maximization model derived from CvStatModel. Declares a
// CvMat-based and a cv::Mat-based interface.
638 class CV_EXPORTS CvEM : public CvStatModel
641 // Type of covariance matrices
642 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
// Start step of the algorithm; START_AUTO_STEP (0) is the CvEMParams default.
645 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
// CvMat interface; <labels> is an optional output (default 0).
648 CvEM( const CvMat* samples, const CvMat* sample_idx=0,
649 CvEMParams params=CvEMParams(), CvMat* labels=0 );
650 //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
654 virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
655 CvEMParams params=CvEMParams(), CvMat* labels=0 );
// <probs> has no default here, unlike the optional outputs above.
657 virtual float predict( const CvMat* sample, CvMat* probs ) const;
// cv::Mat interface mirroring the declarations above.
660 CvEM( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
661 CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
663 virtual bool train( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
664 CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
666 virtual float predict( const cv::Mat& sample, cv::Mat* probs ) const;
669 virtual void clear();
// Read-only accessors.
671 int get_nclusters() const;
672 const CvMat* get_means() const;
673 const CvMat** get_covs() const;
674 const CvMat* get_weights() const;
675 const CvMat* get_probs() const;
// Returns the stored log_likelihood member.
677 inline double get_log_likelihood () const { return log_likelihood; };
// Disabled accessors, kept commented out.
679 // inline const CvMat * get_log_weight_div_det () const { return log_weight_div_det; };
680 // inline const CvMat * get_inv_eigen_values () const { return inv_eigen_values; };
681 // inline const CvMat ** get_cov_rotate_mats () const { return cov_rotate_mats; };
// Training internals operating on CvVectors data.
685 virtual void set_params( const CvEMParams& params,
686 const CvVectors& train_data );
687 virtual void init_em( const CvVectors& train_data );
688 virtual double run_em( const CvVectors& train_data );
689 virtual void init_auto( const CvVectors& samples );
690 virtual void kmeans( const CvVectors& train_data, int nclusters,
691 CvMat* labels, CvTermCriteria criteria,
692 const CvMat* means );
// Model state.
694 double log_likelihood;
701 CvMat* log_weight_div_det;
702 CvMat* inv_eigen_values;
703 CvMat** cov_rotate_mats;
706 /****************************************************************************************\
708 \****************************************************************************************/
/* Treats <subset> as a bit set packed into 32-bit words and tests bit <idx>:
   evaluates to -1 when the bit is set and to +1 when it is clear.
   NOTE(review): <subset> is not parenthesized in the expansion, so pass a
   plain identifier rather than an expression. */
716 #define CV_DTREE_CAT_DIR(idx,subset) \
717 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
758 // global pruning data
761 double node_risk, tree_risk, tree_error;
763 // cross-validation pruning data
765 double* cv_node_risk;
766 double* cv_node_error;
// Returns num_valid[vi] when the num_valid array is present; falls back to
// sample_count when it is null.
768 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
// Stores n in num_valid[vi]; silently does nothing when num_valid is null.
769 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
// Parameter bundle for decision-tree training; also the base of CvRTParams.
773 struct CV_EXPORTS CvDTreeParams
777 int min_sample_count;
781 bool truncate_pruned_tree;
782 float regression_accuracy;
// Defaults: max_categories 10, unlimited depth (INT_MAX), min_sample_count 10,
// 10 cross-validation folds, surrogates on, 1SE rule on, pruned tree
// truncated, regression accuracy 0.01, no priors.
785 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
786 cv_folds(10), use_surrogates(true), use_1se_rule(true),
787 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
// Explicit constructor. Note that the argument order (depth first) differs
// from the order of the member initializers below.
790 CvDTreeParams( int _max_depth, int _min_sample_count,
791 float _regression_accuracy, bool _use_surrogates,
792 int _max_categories, int _cv_folds,
793 bool _use_1se_rule, bool _truncate_pruned_tree,
794 const float* _priors ) :
795 max_categories(_max_categories), max_depth(_max_depth),
796 min_sample_count(_min_sample_count), cv_folds (_cv_folds),
797 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
798 truncate_pruned_tree(_truncate_pruned_tree),
799 regression_accuracy(_regression_accuracy),
// Training-data container used by the tree-based models; CvDTree and CvRTrees
// each hold a pointer to one, and CvERTreeTrainData derives from it.
805 struct CV_EXPORTS CvDTreeTrainData
// Constructor and set_data() take the same arguments; set_data() adds a
// trailing <_update_data> flag (default false).
808 CvDTreeTrainData( const CvMat* _train_data, int _tflag,
809 const CvMat* _responses, const CvMat* _var_idx=0,
810 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
811 const CvMat* _missing_mask=0,
812 const CvDTreeParams& _params=CvDTreeParams(),
813 bool _shared=false, bool _add_labels=false );
814 virtual ~CvDTreeTrainData();
816 virtual void set_data( const CvMat* _train_data, int _tflag,
817 const CvMat* _responses, const CvMat* _var_idx=0,
818 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
819 const CvMat* _missing_mask=0,
820 const CvDTreeParams& _params=CvDTreeParams(),
821 bool _shared=false, bool _add_labels=false,
822 bool _update_data=false );
823 virtual void do_responses_copy();
// Extraction into caller-supplied buffers; <get_class_idx> defaults to false.
825 virtual void get_vectors( const CvMat* _subsample_idx,
826 float* values, uchar* missing, float* responses, bool get_class_idx=false );
828 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
// Parameter persistence helpers.
830 virtual void write_params( CvFileStorage* fs ) const;
831 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
833 // release all the data
834 virtual void clear();
836 int get_num_classes() const;
837 int get_var_type(int vi) const;
// Returns the work_var_count member.
838 int get_work_var_count() const {return work_var_count;}
// Per-node data accessors: each takes a caller-supplied buffer (*_buf) and an
// output pointer through which the resulting data is exposed.
840 virtual void get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values );
841 virtual void get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
842 virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
843 virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** labels );
844 virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
845 virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf,
846 const float** ord_values, const int** indices );
847 virtual int get_child_buf_idx( CvDTreeNode* n );
849 ////////////////////////////////////
// Node and split creation / release.
851 virtual bool set_params( const CvDTreeParams& params );
852 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
853 int storage_idx, int offset );
855 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
856 int split_point, int inversed, float quality );
857 virtual CvDTreeSplit* new_split_cat( int vi, float quality );
858 virtual void free_node_data( CvDTreeNode* node );
859 virtual void free_train_data();
860 virtual void free_node( CvDTreeNode* node );
862 // inner arrays for getting predictors and responses
863 float* get_pred_float_buf();
864 int* get_pred_int_buf();
865 float* get_resp_float_buf();
866 int* get_resp_int_buf();
// NOTE: the "lables" spelling is part of the public names and is kept as is.
867 int* get_cv_lables_buf();
868 int* get_sample_idx_buf();
// Backing storage for the get_*_buf() accessors above
// (presumably one inner vector per thread - TODO confirm).
870 vector<vector<float> > pred_float_buf;
871 vector<vector<int> > pred_int_buf;
872 vector<vector<float> > resp_float_buf;
873 vector<vector<int> > resp_int_buf;
874 vector<vector<int> > cv_lables_buf;
875 vector<vector<int> > sample_idx_buf;
// Counters and flags.
877 int sample_count, var_all, var_count, max_c_count;
878 int ord_var_count, cat_var_count, work_var_count;
879 bool have_labels, have_priors;
// Input data; both pointers are const as seen from this struct.
883 const CvMat* train_data;
884 const CvMat* responses;
885 CvMat* responses_copy; // used in Boosting
887 int buf_count, buf_size;
901 CvMat* var_type; // i-th element =
903 // k>=0 - categorical, see k-th element of cat_* arrays
907 CvDTreeParams params;
909 CvMemStorage* tree_storage;
910 CvMemStorage* temp_storage;
912 CvDTreeNode* data_root;
// Decision tree model derived from CvStatModel; base class of CvForestTree.
923 class CV_EXPORTS CvDTree : public CvStatModel
// Training from raw CvMat data; every argument after <_responses> is optional.
929 virtual bool train( const CvMat* _train_data, int _tflag,
930 const CvMat* _responses, const CvMat* _var_idx=0,
931 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
932 const CvMat* _missing_mask=0,
933 CvDTreeParams params=CvDTreeParams() );
// Training from a CvMLData object.
935 virtual bool train( CvMLData* _data, CvDTreeParams _params=CvDTreeParams() );
// <resp> is an optional output vector (default 0).
937 virtual float calc_error( CvMLData* _data, int type , vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
// Training on already prepared train data and a subsample index.
939 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
// Prediction returns a tree node rather than a scalar value.
941 virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
942 bool preprocessed_input=false ) const;
// cv::Mat overloads mirroring the CvMat train()/predict() above.
945 virtual bool train( const cv::Mat& _train_data, int _tflag,
946 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
947 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
948 const cv::Mat& _missing_mask=cv::Mat(),
949 CvDTreeParams params=CvDTreeParams() );
951 virtual CvDTreeNode* predict( const cv::Mat& _sample, const cv::Mat& _missing_data_mask=cv::Mat(),
952 bool preprocessed_input=false ) const;
955 virtual const CvMat* get_var_importance();
956 virtual void clear();
// Storage-based persistence, re-declared from CvStatModel.
958 virtual void read( CvFileStorage* fs, CvFileNode* node );
959 virtual void write( CvFileStorage* fs, const char* name ) const;
961 // special read & write methods for trees in the tree ensembles
962 virtual void read( CvFileStorage* fs, CvFileNode* node,
963 CvDTreeTrainData* data );
964 virtual void write( CvFileStorage* fs ) const;
// Accessors.
966 const CvDTreeNode* get_root() const;
967 int get_pruned_tree_idx() const;
968 CvDTreeTrainData* get_data();
// Training internals.
972 virtual bool do_train( const CvMat* _subsample_idx );
974 virtual void try_split_node( CvDTreeNode* n );
975 virtual void split_node_data( CvDTreeNode* n );
976 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
// Split search per variable kind (ord / cat) and task (class / reg);
// <init_quality> and <_split> are optional.
977 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
978 float init_quality = 0, CvDTreeSplit* _split = 0 );
979 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
980 float init_quality = 0, CvDTreeSplit* _split = 0 );
981 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
982 float init_quality = 0, CvDTreeSplit* _split = 0 );
983 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
984 float init_quality = 0, CvDTreeSplit* _split = 0 );
985 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
986 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
987 virtual double calc_node_dir( CvDTreeNode* node );
988 virtual void complete_node_dir( CvDTreeNode* node );
989 virtual void cluster_categories( const int* vectors, int vector_count,
990 int var_count, int* sums, int k, int* cluster_labels );
992 virtual void calc_node_value( CvDTreeNode* node );
// Pruning.
994 virtual void prune_cv();
995 virtual double update_tree_rnc( int T, int fold );
996 virtual int cut_tree( int T, int fold, double min_alpha );
997 virtual void free_prune_data(bool cut_tree);
998 virtual void free_tree();
// Node / split serialization helpers.
1000 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
1001 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
1002 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
1003 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
1004 virtual void write_tree_nodes( CvFileStorage* fs ) const;
1005 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
// Model state.
1008 CvMat* var_importance;
1009 CvDTreeTrainData* data;
1012 int pruned_tree_idx;
1016 /****************************************************************************************\
1017 * Random Trees Classifier *
1018 \****************************************************************************************/
// Single tree of a random forest; derived from CvDTree. Its train() and
// read() variants take the owning CvRTrees forest as an extra argument.
1022 class CV_EXPORTS CvForestTree: public CvDTree
1026 virtual ~CvForestTree();
// Forest-aware training entry point.
1028 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
// Variable count taken from the train data; 0 when no data is attached.
1030 virtual int get_var_count() const {return data ? data->var_count : 0;}
// Forest-aware read.
1031 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
1033 /* dummy methods to avoid warnings: BEGIN */
1034 virtual bool train( const CvMat* _train_data, int _tflag,
1035 const CvMat* _responses, const CvMat* _var_idx=0,
1036 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1037 const CvMat* _missing_mask=0,
1038 CvDTreeParams params=CvDTreeParams() );
1040 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
1041 virtual void read( CvFileStorage* fs, CvFileNode* node );
1042 virtual void read( CvFileStorage* fs, CvFileNode* node,
1043 CvDTreeTrainData* data );
1044 /* dummy methods to avoid warnings: END */
// Re-declares the CvDTree split search.
1047 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
// Random-forest parameters: the CvDTreeParams tree settings plus
// forest-level options.
1052 struct CV_EXPORTS CvRTParams : public CvDTreeParams
1054 //Parameters for the forest
1055 bool calc_var_importance; // true <=> RF processes variable importance
1057 CvTermCriteria term_crit;
// Defaults, following the CvDTreeParams argument order: max_depth 5,
// min_sample_count 10, regression_accuracy 0, no surrogates,
// max_categories 10, cv_folds 0, no 1SE rule, no truncation, no priors.
// Variable importance is off, nactive_vars is 0, and the termination
// criteria are 50 iterations or 0.1 accuracy.
1059 CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
1060 calc_var_importance(false), nactive_vars(0)
1062 term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
// Explicit constructor. cv_folds, use_1se_rule and truncate_pruned_tree are
// always passed to the base as 0 / false / false; the last three arguments
// are combined into term_crit.
1065 CvRTParams( int _max_depth, int _min_sample_count,
1066 float _regression_accuracy, bool _use_surrogates,
1067 int _max_categories, const float* _priors, bool _calc_var_importance,
1068 int _nactive_vars, int max_num_of_trees_in_the_forest,
1069 float forest_accuracy, int termcrit_type ) :
1070 CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
1071 _use_surrogates, _max_categories, 0,
1072 false, false, _priors ),
1073 calc_var_importance(_calc_var_importance),
1074 nactive_vars(_nactive_vars)
1076 term_crit = cvTermCriteria(termcrit_type,
1077 max_num_of_trees_in_the_forest, forest_accuracy);
// Random trees (forest) model derived from CvStatModel; holds an array of
// CvForestTree pointers.
1082 class CV_EXPORTS CvRTrees : public CvStatModel
1086 virtual ~CvRTrees();
// Training from raw CvMat data; every argument after <_responses> is optional.
1087 virtual bool train( const CvMat* _train_data, int _tflag,
1088 const CvMat* _responses, const CvMat* _var_idx=0,
1089 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1090 const CvMat* _missing_mask=0,
1091 CvRTParams params=CvRTParams() );
// Training from a CvMLData object.
1093 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
// <missing> is optional (default 0).
1094 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
1095 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
// cv::Mat overloads mirroring the CvMat declarations above.
1098 virtual bool train( const cv::Mat& _train_data, int _tflag,
1099 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1100 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1101 const cv::Mat& _missing_mask=cv::Mat(),
1102 CvRTParams params=CvRTParams() );
1103 virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1104 virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1107 virtual void clear();
1109 virtual const CvMat* get_var_importance();
// Proximity of two samples, each with an optional missing-value mask.
1110 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
1111 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
// <resp> is an optional output vector (default 0).
1113 virtual float calc_error( CvMLData* _data, int type , vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1115 virtual float get_train_error();
// Storage-based persistence, re-declared from CvStatModel.
1117 virtual void read( CvFileStorage* fs, CvFileNode* node );
1118 virtual void write( CvFileStorage* fs, const char* name ) const;
1120 CvMat* get_active_var_mask();
// Forest accessors.
1123 int get_tree_count() const;
1124 CvForestTree* get_tree(int i) const;
// Training internal driven by the termination criteria.
1128 virtual bool grow_forest( const CvTermCriteria term_crit );
1130 // array of the trees of the forest
1131 CvForestTree** trees;
1132 CvDTreeTrainData* data;
1136 CvMat* var_importance;
1140 CvMat* active_var_mask;
1143 /****************************************************************************************\
1144 * Extremely randomized trees Classifier *
1145 \****************************************************************************************/
// Training-data container for the extremely randomized trees classifier, derived from
// CvDTreeTrainData. The virtual methods presumably override base-class versions - the
// base declaration is not visible here.
1146 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
// Loads the training set; every argument after _responses has a default.
1148 virtual void set_data( const CvMat* _train_data, int _tflag,
1149 const CvMat* _responses, const CvMat* _var_idx=0,
1150 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1151 const CvMat* _missing_mask=0,
1152 const CvDTreeParams& _params=CvDTreeParams(),
1153 bool _shared=false, bool _add_labels=false,
1154 bool _update_data=false );
// Per-node accessors: each takes a caller-supplied buffer (*_buf) and an output
// pointer through which the data pointer is returned.
1155 virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
1156 const float** ord_values, const int** missing );
1157 virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices );
1158 virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
1159 virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
1160 virtual void get_vectors( const CvMat* _subsample_idx,
1161 float* values, uchar* missing, float* responses, bool get_class_idx=false );
1162 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
// Pointer to a missing-value mask; ownership cannot be determined from this declaration.
1163 const CvMat* missing_mask;
// Single tree of an extremely-randomized-trees forest; specializes CvForestTree.
// Declares split search for the four variable/response combinations
// (ordered or categorical variable, classification or regression).
1166 class CV_EXPORTS CvForestERTree : public CvForestTree
1169 virtual double calc_node_dir( CvDTreeNode* node );
1170 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1171 float init_quality = 0, CvDTreeSplit* _split = 0 );
1172 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1173 float init_quality = 0, CvDTreeSplit* _split = 0 );
1174 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1175 float init_quality = 0, CvDTreeSplit* _split = 0 );
1176 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1177 float init_quality = 0, CvDTreeSplit* _split = 0 );
// complete_node_dir is intentionally left commented out in the original.
1178 //virtual void complete_node_dir( CvDTreeNode* node );
1179 virtual void split_node_data( CvDTreeNode* n );
// Extremely randomized trees model; derives from CvRTrees and redeclares the three
// train() overloads and grow_forest() with the same signatures as the base class.
1182 class CV_EXPORTS CvERTrees : public CvRTrees
1186 virtual ~CvERTrees();
// CvMat-based training.
1187 virtual bool train( const CvMat* _train_data, int _tflag,
1188 const CvMat* _responses, const CvMat* _var_idx=0,
1189 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1190 const CvMat* _missing_mask=0,
1191 CvRTParams params=CvRTParams());
// cv::Mat-based training.
1193 virtual bool train( const cv::Mat& _train_data, int _tflag,
1194 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1195 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1196 const cv::Mat& _missing_mask=cv::Mat(),
1197 CvRTParams params=CvRTParams());
// Training from a CvMLData container.
1199 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1201 virtual bool grow_forest( const CvTermCriteria term_crit );
1205 /****************************************************************************************\
1206 * Boosted tree classifier *
1207 \****************************************************************************************/
// Parameters for the boosted tree classifier; extends CvDTreeParams.
// NOTE(review): only weight_trim_rate is visible among the members; the constructor's
// boost_type and weak_count arguments suggest matching members on omitted lines.
1209 struct CV_EXPORTS CvBoostParams : public CvDTreeParams
1214 double weight_trim_rate;
// Constructor taking the boosting-specific settings plus a subset of tree settings.
1217 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1218 int max_depth, bool use_surrogates, const float* priors );
// Weak learner used by CvBoost: a CvDTree that is trained and read together with a
// pointer to its owning ensemble.
1224 class CV_EXPORTS CvBoostTree: public CvDTree
1228 virtual ~CvBoostTree();
// Ensemble-aware training entry point.
1230 virtual bool train( CvDTreeTrainData* _train_data,
1231 const CvMat* subsample_idx, CvBoost* ensemble );
1233 virtual void scale( double s );
// Ensemble-aware deserialization.
1234 virtual void read( CvFileStorage* fs, CvFileNode* node,
1235 CvBoost* ensemble, CvDTreeTrainData* _data );
1236 virtual void clear();
// The overloads between BEGIN and END repeat CvDTree-style signatures; per the
// original comment they exist only to silence compiler warnings.
1238 /* dummy methods to avoid warnings: BEGIN */
1239 virtual bool train( const CvMat* _train_data, int _tflag,
1240 const CvMat* _responses, const CvMat* _var_idx=0,
1241 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1242 const CvMat* _missing_mask=0,
1243 CvDTreeParams params=CvDTreeParams() );
1244 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
1246 virtual void read( CvFileStorage* fs, CvFileNode* node );
1247 virtual void read( CvFileStorage* fs, CvFileNode* node,
1248 CvDTreeTrainData* data );
1249 /* dummy methods to avoid warnings: END */
// Tree-building hooks: node splitting, surrogate splits, and split search for the
// four variable/response combinations.
1253 virtual void try_split_node( CvDTreeNode* n );
1254 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
1255 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
1256 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1257 float init_quality = 0, CvDTreeSplit* _split = 0 );
1258 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1259 float init_quality = 0, CvDTreeSplit* _split = 0 );
1260 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1261 float init_quality = 0, CvDTreeSplit* _split = 0 );
1262 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1263 float init_quality = 0, CvDTreeSplit* _split = 0 );
1264 virtual void calc_node_value( CvDTreeNode* n );
1265 virtual double calc_node_dir( CvDTreeNode* n );
// Boosted tree classifier built on CvStatModel. Offers CvMat, CvMLData and cv::Mat
// training entry points and keeps its training data and parameters as members.
// NOTE(review): access specifiers, the default constructor/destructor and several
// members are not visible in this listing.
1271 class CV_EXPORTS CvBoost : public CvStatModel
// Boosting variants (presumably the values accepted as boost_type - confirm).
1275 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1277 // Splitting criteria
// Note that the value 2 is not used in this enum.
1278 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
// Constructor with the same leading arguments as the CvMat train() overload below.
1283 CvBoost( const CvMat* _train_data, int _tflag,
1284 const CvMat* _responses, const CvMat* _var_idx=0,
1285 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1286 const CvMat* _missing_mask=0,
1287 CvBoostParams params=CvBoostParams() );
// CvMat-based training; `update` defaults to false.
1289 virtual bool train( const CvMat* _train_data, int _tflag,
1290 const CvMat* _responses, const CvMat* _var_idx=0,
1291 const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1292 const CvMat* _missing_mask=0,
1293 CvBoostParams params=CvBoostParams(),
1294 bool update=false );
// Training from a CvMLData container.
1296 virtual bool train( CvMLData* data,
1297 CvBoostParams params=CvBoostParams(),
1298 bool update=false );
// Prediction for one sample; by default uses the whole weak-predictor sequence
// (slice=CV_WHOLE_SEQ). weak_responses is an optional output matrix.
1300 virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
1301 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1302 bool raw_mode=false, bool return_sum=false ) const;
// cv::Mat counterparts of the constructor, train() and predict() above.
1305 CvBoost( const cv::Mat& _train_data, int _tflag,
1306 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1307 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1308 const cv::Mat& _missing_mask=cv::Mat(),
1309 CvBoostParams params=CvBoostParams() );
1311 virtual bool train( const cv::Mat& _train_data, int _tflag,
1312 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1313 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1314 const cv::Mat& _missing_mask=cv::Mat(),
1315 CvBoostParams params=CvBoostParams(),
1316 bool update=false );
1318 virtual float predict( const cv::Mat& _sample, const cv::Mat& _missing=cv::Mat(),
1319 cv::Mat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1320 bool raw_mode=false, bool return_sum=false ) const;
// Same signature as CvRTrees::calc_error.
1323 virtual float calc_error( CvMLData* _data, int type , vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1325 virtual void prune( CvSlice slice );
1327 virtual void clear();
// Serialization through CvFileStorage.
1329 virtual void write( CvFileStorage* storage, const char* name ) const;
1330 virtual void read( CvFileStorage* storage, CvFileNode* node );
1331 virtual const CvMat* get_active_vars(bool absolute_idx=true);
// Accessors for internal state.
1333 CvSeq* get_weak_predictors();
1335 CvMat* get_weights();
1336 CvMat* get_subtree_weights();
1337 CvMat* get_weak_response();
1338 const CvBoostParams& get_params() const;
1339 const CvDTreeTrainData* get_data() const;
// Internal training helpers.
1343 virtual bool set_params( const CvBoostParams& _params );
1344 virtual void update_weights( CvBoostTree* tree );
1345 virtual void trim_weights();
1346 virtual void write_params( CvFileStorage* fs ) const;
1347 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
// State: training data, parameters and working matrices.
1349 CvDTreeTrainData* data;
1350 CvBoostParams params;
1354 CvMat* active_vars_abs;
1355 bool have_active_cat_vars;
1357 CvMat* orig_response;
1358 CvMat* sum_response;
1360 CvMat* subsample_mask;
1362 CvMat* subtree_weights;
1363 bool have_subsample;
1367 /****************************************************************************************\
1368 * Artificial Neural Networks (ANN) *
1369 \****************************************************************************************/
1371 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
// Training parameters for CvANN_MLP: termination criteria plus one parameter group
// per training method (backpropagation and RPROP).
1373 struct CV_EXPORTS CvANN_MLP_TrainParams
1375 CvANN_MLP_TrainParams();
// param1/param2 are method-specific; their mapping to the fields below is defined in
// the implementation, which is not visible here.
1376 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1377 double param1, double param2=0 );
1378 ~CvANN_MLP_TrainParams();
// Training methods selectable through train_method.
1380 enum { BACKPROP=0, RPROP=1 };
1382 CvTermCriteria term_crit;
1385 // backpropagation parameters
1386 double bp_dw_scale, bp_moment_scale;
// RPROP parameters (rp_ prefix).
1389 double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
// Multi-layer perceptron model built on CvStatModel.
// NOTE(review): this listing omits some original lines (braces, access specifiers,
// the last line of each train() declaration and several members such as layer_sizes
// and weights, which the inline accessors below read).
1393 class CV_EXPORTS CvANN_MLP : public CvStatModel
// Constructs a network from a layer-size matrix; the activation defaults to SIGMOID_SYM.
1397 CvANN_MLP( const CvMat* _layer_sizes,
1398 int _activ_func=SIGMOID_SYM,
1399 double _f_param1=0, double _f_param2=0 );
1401 virtual ~CvANN_MLP();
// Same arguments as the constructor above.
1403 virtual void create( const CvMat* _layer_sizes,
1404 int _activ_func=SIGMOID_SYM,
1405 double _f_param1=0, double _f_param2=0 );
// Training on input/output matrices. The declaration continues on a line that is not
// visible in this listing.
1407 virtual int train( const CvMat* _inputs, const CvMat* _outputs,
1408 const CvMat* _sample_weights, const CvMat* _sample_idx=0,
1409 CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
// Writes network outputs for _inputs into _outputs.
1411 virtual float predict( const CvMat* _inputs, CvMat* _outputs ) const;
// cv::Mat counterparts of the constructor, create(), train() and predict().
1414 CvANN_MLP( const cv::Mat& _layer_sizes,
1415 int _activ_func=SIGMOID_SYM,
1416 double _f_param1=0, double _f_param2=0 );
1418 virtual void create( const cv::Mat& _layer_sizes,
1419 int _activ_func=SIGMOID_SYM,
1420 double _f_param1=0, double _f_param2=0 );
1422 virtual int train( const cv::Mat& _inputs, const cv::Mat& _outputs,
1423 const cv::Mat& _sample_weights, const cv::Mat& _sample_idx=cv::Mat(),
1424 CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
1427 virtual float predict( const cv::Mat& _inputs, cv::Mat& _outputs ) const;
1430 virtual void clear();
1432 // possible activation functions
1433 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1435 // available training flags
// Values are distinct powers of two, so they can be OR-ed together.
1436 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
// Serialization through CvFileStorage.
1438 virtual void read( CvFileStorage* fs, CvFileNode* node );
1439 virtual void write( CvFileStorage* storage, const char* name ) const;
// Number of layers = column count of layer_sizes, or 0 when no network exists.
1441 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1442 const CvMat* get_layer_sizes() { return layer_sizes; }
// Returns weights[layer], or 0 when the network is not created or layer is out of
// range. The unsigned casts also reject negative indices.
// NOTE(review): the bound is `<=`, so layer == layer_sizes->cols is accepted; this is
// valid only if weights holds more than cols entries - confirm against the allocation.
1443 double* get_weights(int layer)
1445 return layer_sizes && weights &&
1446 (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
// Validates and converts training inputs into CvVectors and sample weights.
1451 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1452 const CvMat* _sample_weights, const CvMat* _sample_idx,
1453 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1455 // sequential random backpropagation
1456 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
// RPROP training routine (counterpart of train_backprop).
1459 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
// Activation function and its derivative.
1461 virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1462 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1463 virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1464 double _f_param1=0, double _f_param2=0 );
1465 virtual void init_weights();
// Input/output scaling helpers.
1466 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1467 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1468 virtual void calc_input_scale( const CvVectors* vecs, int flags );
1469 virtual void calc_output_scale( const CvVectors* vecs, int flags );
1471 virtual void write_params( CvFileStorage* fs ) const;
1472 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
// State visible in this listing.
1476 CvMat* sample_weights;
1478 double f_param1, f_param2;
1479 double min_val, max_val, min_val1, max_val1;
1481 int max_count, max_buf_sz;
1482 CvANN_MLP_TrainParams params;
1487 /****************************************************************************************\
1488 * Convolutional Neural Network *
1489 \****************************************************************************************/
// Forward typedefs so the layer and network structs can refer to each other.
1490 typedef struct CvCNNLayer CvCNNLayer;
1491 typedef struct CvCNNetwork CvCNNetwork;
// Learning-rate decrease schedules (values for learn_rate_decrease_type).
1493 #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY 1
1494 #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV 2
1495 #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV 3
// Gradient estimation modes (presumably values for grad_estim_type - confirm).
1497 #define CV_CNN_GRAD_ESTIM_RANDOM 0
1498 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG 1
// Layer signature and per-type tags; the ICV_IS_CNN_* macros compare the type tags
// against the bits of `flags` outside CV_MAGIC_MASK.
1500 #define ICV_CNN_LAYER 0x55550000
1501 #define ICV_CNN_CONVOLUTION_LAYER 0x00001111
1502 #define ICV_CNN_SUBSAMPLING_LAYER 0x00002222
1503 #define ICV_CNN_FULLCONNECT_LAYER 0x00003333
// Layer type predicates. ICV_IS_CNN_LAYER checks that the pointer is non-NULL and
// tests `flags & CV_MAGIC_MASK` (the value it is compared with is on a line not visible
// in this listing). The three specific predicates additionally require
// `flags & ~CV_MAGIC_MASK` to equal the corresponding ICV_CNN_*_LAYER tag.
// Each macro evaluates its argument more than once; avoid arguments with side effects.
1505 #define ICV_IS_CNN_LAYER( layer ) \
1506 ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\
1509 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer ) \
1510 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \
1511 & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )
1513 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) \
1514 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \
1515 & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )
1517 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer ) \
1518 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \
1519 & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )
// Function-pointer types stored in layers and networks.
// Forward pass: takes an input matrix and an output matrix for one layer.
1521 typedef void (CV_CDECL *CvCNNLayerForward)
1522 ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
// Backward pass: t is an int step argument (presumably the iteration number - confirm);
// X is the layer input, dE_dY and dE_dX are the gradient matrices.
1524 typedef void (CV_CDECL *CvCNNLayerBackward)
1525 ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
// Releases a layer through a pointer-to-pointer.
1527 typedef void (CV_CDECL *CvCNNLayerRelease)
1528 (CvCNNLayer** layer);
// Adds a layer to a network.
1530 typedef void (CV_CDECL *CvCNNetworkAddLayer)
1531 (CvCNNetwork* network, CvCNNLayer* layer);
// Releases a network through a pointer-to-pointer.
1533 typedef void (CV_CDECL *CvCNNetworkRelease)
1534 (CvCNNetwork** network);
// Common leading fields shared by every CNN layer struct: plane counts and sizes,
// learning-rate settings, the forward/backward/release callbacks and the links to the
// neighbouring layers. The expansion ends without a semicolon, so users write
// `CV_CNN_LAYER_FIELDS();`.
// NOTE(review): several field lines (e.g. those following the "Indicator of the
// layer's type" and "Trainable weights" comments) are not visible in this listing.
1536 #define CV_CNN_LAYER_FIELDS() \
1537 /* Indicator of the layer's type */ \
1540 /* Number of input images */ \
1541 int n_input_planes; \
1542 /* Height of each input image */ \
1544 /* Width of each input image */ \
1547 /* Number of output images */ \
1548 int n_output_planes; \
1549 /* Height of each output image */ \
1550 int output_height; \
1551 /* Width of each output image */ \
1554 /* Learning rate at the first iteration */ \
1555 float init_learn_rate; \
1556 /* Dynamics of learning rate decreasing */ \
1557 int learn_rate_decrease_type; \
1558 /* Trainable weights of the layer (including bias) */ \
1559 /* i-th row is a set of weights of the i-th output plane */ \
1562 CvCNNLayerForward forward; \
1563 CvCNNLayerBackward backward; \
1564 CvCNNLayerRelease release; \
1565 /* Pointers to the previous and next layers in the network */ \
1566 CvCNNLayer* prev_layer; \
1567 CvCNNLayer* next_layer
// Generic layer: contains only the common fields from CV_CNN_LAYER_FIELDS. The
// ICV_IS_CNN_* macros cast layer pointers to this type to read `flags`.
1569 typedef struct CvCNNLayer
1571 CV_CNN_LAYER_FIELDS();
// Convolution layer: common layer fields followed by convolution-specific data.
// NOTE(review): the field following the "Kernel size" comment is not visible in this
// listing.
1574 typedef struct CvCNNConvolutionLayer
1576 CV_CNN_LAYER_FIELDS();
1577 // Kernel size (height and width) for convolution.
1579 // connections matrix, (i,j)-th element is 1 iff there is a connection between
1580 // i-th plane of the current layer and j-th plane of the previous layer;
1581 // (i,j)-th element is equal to 0 otherwise
1582 CvMat *connect_mask;
1583 // value of the learning rate for updating weights at the first iteration
1584 }CvCNNConvolutionLayer;
// Sub-sampling layer: common layer fields followed by sub-sampling-specific data.
// NOTE(review): the field declarations that follow each comment below are not visible
// in this listing; only the comments survive.
1586 typedef struct CvCNNSubSamplingLayer
1588 CV_CNN_LAYER_FIELDS();
1589 // ratio between the heights (or widths - ratios are supposed to be equal)
1590 // of the input and output planes
1592 // amplitude of sigmoid activation function
1594 // scale parameter of sigmoid activation function
1596 // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
1597 // - is the vector used in computing of the activation function in backward
1599 // (x1+x2+x3+x4), where x1,...x4 are some elements of X
1600 // - is the vector used in computing of the activation function in backward
1602 }CvCNNSubSamplingLayer;
1604 // Structure of the last layer.
// NOTE(review): the field declarations that follow each comment below are not visible
// in this listing; only the comments survive.
1605 typedef struct CvCNNFullConnectLayer
1607 CV_CNN_LAYER_FIELDS();
1608 // amplitude of sigmoid activation function
1610 // scale parameter of sigmoid activation function
1612 // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
1613 // activation function and its derivative by the formulae
1614 // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
1615 // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
1617 }CvCNNFullConnectLayer;
// Network object: exposes callbacks for adding a layer and releasing the network.
// NOTE(review): other members (on lines 1620-1622 of the original) are not visible here.
1619 typedef struct CvCNNetwork
1623 CvCNNetworkAddLayer add_layer;
1624 CvCNNetworkRelease release;
// Stat-model wrapper around a CNN: common CvStatModel fields plus the network pointer.
1627 typedef struct CvCNNStatModel
1629 CV_STAT_MODEL_FIELDS();
1630 CvCNNetwork* network;
// NOTE(review): "cls_labeles" below looks like a typo for the label member, which is
// not visible in this listing - confirm the member name before correcting.
1631 // etalons are allocated as rows, the i-th etalon has label cls_labeles[i]
// Training parameters for the CNN stat model: common parameter fields, the network to
// train and the gradient estimation type.
// NOTE(review): the members following the "termination criteria" comment are not
// visible in this listing.
1637 typedef struct CvCNNStatModelParams
1639 CV_STAT_MODEL_PARAM_FIELDS();
1640 // network must be created by the functions cvCreateCNNetwork and <add_layer>
1641 CvCNNetwork* network;
1643 // termination criteria
// Presumably one of the CV_CNN_GRAD_ESTIM_* constants - confirm.
1646 int grad_estim_type;
1647 }CvCNNStatModelParams;
// Creates a convolution layer. K is the kernel-size argument; connect_mask and weights
// are optional (default 0). Returns the layer as a generic CvCNNLayer pointer.
1649 CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
1650 int n_input_planes, int input_height, int input_width,
1651 int n_output_planes, int K,
1652 float init_learn_rate, int learn_rate_decrease_type,
1653 CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
// Creates a sub-sampling layer. a and s presumably are the sigmoid amplitude and scale
// described in CvCNNSubSamplingLayer - confirm. weights is optional (default 0).
1655 CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
1656 int n_input_planes, int input_height, int input_width,
1657 int sub_samp_scale, float a, float s,
1658 float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
// Creates a fully connected layer with n_inputs inputs and n_outputs outputs.
// NOTE(review): the sixth parameter is named learning_type here but
// learn_rate_decrease_type in the sibling creators - presumably the same setting.
1660 CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
1661 int n_inputs, int n_outputs, float a, float s,
1662 float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
// Creates a network starting from first_layer.
1664 CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
// Trains a CNN classifier and returns it as a CvStatModel. Three of the CvMat
// parameters are left unnamed in the declaration; only sample_idx is named among the
// optional arguments.
1666 CVAPI(CvStatModel*) cvTrainCNNClassifier(
1667 const CvMat* train_data, int tflag,
1668 const CvMat* responses,
1669 const CvStatModelParams* params,
1670 const CvMat* CV_DEFAULT(0),
1671 const CvMat* sample_idx CV_DEFAULT(0),
1672 const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
1674 /****************************************************************************************\
1675 * Estimate classifiers algorithms *
1676 \****************************************************************************************/
// Callback types used by estimation models (see the cross-validation fields macro).
// Returns a matrix owned by the estimation model.
1677 typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
1678 ( const CvStatModel* estimateModel );
// Advances the estimation; returns an int status.
1680 typedef int (CV_CDECL *CvStatModelEstimateNextStep)
1681 ( CvStatModel* estimateModel );
// Checks a trained model against features and responses. One parameter between
// features and responses is on a line not visible in this listing.
1683 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
1684 ( CvStatModel* estimateModel,
1685 const CvStatModel* model,
1686 const CvMat* features,
1688 const CvMat* responses );
// Variant of the check callback that takes only the model.
1690 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
1691 ( CvStatModel* estimateModel,
1692 const CvStatModel* model );
// Returns the current result; `correlation` is an output pointer.
1694 typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
1695 ( const CvStatModel* estimateModel,
1696 float* correlation );
// Resets the estimation model.
1698 typedef void (CV_CDECL *CvStatModelEstimateReset)
1699 ( CvStatModel* estimateModel );
1701 //-------------------------------- Cross-validation --------------------------------------
// Fields of the cross-validation parameter struct: the common stat-model parameter
// fields plus is_regression. Further fields are on lines not visible in this listing.
1702 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS() \
1703 CV_STAT_MODEL_PARAM_FIELDS(); \
1705 int is_regression; \
// Parameter struct for cross-validation; its members come entirely from the macro.
1708 typedef struct CvCrossValidationParams
1710 CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
1711 } CvCrossValidationParams;
// Fields of the cross-validation model: the estimation callbacks, the is_regression
// flag, sample index storage and result accumulators. Some field lines are not visible
// in this listing.
1713 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS() \
1714 CvStatModelEstimateGetMat getTrainIdxMat; \
1715 CvStatModelEstimateGetMat getCheckIdxMat; \
1716 CvStatModelEstimateNextStep nextStep; \
1717 CvStatModelEstimateCheckClassifier check; \
1718 CvStatModelEstimateGetCurrentResult getResult; \
1719 CvStatModelEstimateReset reset; \
1720 int is_regression; \
1723 int* sampleIdxAll; \
1725 int max_fold_size; \
1728 CvMat* sampleIdxTrain; \
1729 CvMat* sampleIdxEval; \
1730 CvMat* predict_results; \
1731 int correct_results; \
1734 double sum_correct; \
1735 double sum_predict; \
// Cross-validation estimation model: common stat-model fields followed by the
// cross-validation fields, so it can be handled through a CvStatModel pointer.
1740 typedef struct CvCrossValidationModel
1742 CV_STAT_MODEL_FIELDS();
1743 CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
1744 } CvCrossValidationModel;
// Creates a cross-validation estimation model. estimateParams and sampleIdx are
// optional (default 0).
// NOTE(review): the return-type line and the first parameter are not visible in this
// listing.
1747 cvCreateCrossValidationEstimateModel
1749 const CvStatModelParams* estimateParams CV_DEFAULT(0),
1750 const CvMat* sampleIdx CV_DEFAULT(0) );
// Runs cross-validation over trueData/trueClasses using the caller-supplied
// createClassifier factory. All trailing arguments are optional (default 0);
// pCrValModel is an output pointer for the estimation model.
// NOTE(review): the return type and several parameter lines are not visible in this
// listing.
1753 cvCrossValidation( const CvMat* trueData,
1755 const CvMat* trueClasses,
1756 CvStatModel* (*createClassifier)( const CvMat*,
1759 const CvStatModelParams*,
1764 const CvStatModelParams* estimateParams CV_DEFAULT(0),
1765 const CvStatModelParams* trainParams CV_DEFAULT(0),
1766 const CvMat* compIdx CV_DEFAULT(0),
1767 const CvMat* sampleIdx CV_DEFAULT(0),
1768 CvStatModel** pCrValModel CV_DEFAULT(0),
1769 const CvMat* typeMask CV_DEFAULT(0),
1770 const CvMat* missedMeasurementMask CV_DEFAULT(0) );
1773 /****************************************************************************************\
1774 * Auxiliary functions declarations *
1775 \****************************************************************************************/
1777 /* Generates <sample> from a multivariate normal distribution, where <mean> is the
1778 mean row vector and <cov> is the symmetric covariance matrix; <rng> is optional */
1779 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1780 CvRNG* rng CV_DEFAULT(0) );
1782 /* Generates sample from gaussian mixture distribution */
// NOTE(review): the parameters between means[] and sampClasses are not visible in this
// listing. sampClasses is optional (default 0).
1783 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1788 CvMat* sampClasses CV_DEFAULT(0) );
// Test-set type constant (presumably the value passed as `type` below - confirm).
1790 #define CV_TS_CONCENTRIC_SPHERES 0
1792 /* creates test set */
// Variadic: extra arguments follow num_classes. Parameters between samples and
// num_classes are not visible in this listing.
1793 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
1797 int num_classes, ... );
1802 /****************************************************************************************\
1804 \****************************************************************************************/
// NOTE(review): if this file is a header, this using-directive leaks std into every
// includer. It cannot be removed in isolation: declarations in this file use
// unqualified vector, map and string (e.g. map<string, int> in CvMLData).
1809 using namespace std;
// Mode constant (presumably for train_sample_part_mode in CvTrainTestSplit - confirm).
1812 #define CV_PORTION 1
// Describes how a data set is split into train and test parts. The training part can
// be given either as a sample count (int) or as a portion (float).
// NOTE(review): the members of the train_sample_part aggregate and other fields are
// not visible in this listing.
1814 struct CV_EXPORTS CvTrainTestSplit
// Split by absolute number of training samples; _mix defaults to true.
1818 CvTrainTestSplit( int _train_sample_count, bool _mix = true);
// Split by portion of training samples; _mix defaults to true.
1819 CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
1825 } train_sample_part;
// Selects how train_sample_part is interpreted (presumably CV_PORTION or a count mode - confirm).
1826 int train_sample_part_mode;
1833 int class_part_mode;
// Data-set container: reads a CSV file and exposes values, responses, the missing
// mask, variable types and a train/test split.
// NOTE(review): this listing omits some original lines (braces, access specifiers,
// constructor and several members such as values, missing, var_types, response_idx,
// delimiter and miss_ch, which the inline accessors below return).
1838 class CV_EXPORTS CvMLData
1842 virtual ~CvMLData();
// Return codes of read_csv (the line describing the success code is not visible here):
1846 // 1 - file can not be opened or is not correct
1847 int read_csv(const char* filename);
// Inline accessors return the internal pointers directly; no copy is made.
1849 const CvMat* get_values(){ return values; };
1851 const CvMat* get_responses();
1853 const CvMat* get_missing(){ return missing; };
1855 void set_response_idx( int idx ); // idx < 0 to set all vars as predictors
1856 int get_response_idx() { return response_idx; }
// Train/test split access and control.
1858 const CvMat* get_train_sample_idx() { return train_sample_idx; };
1859 const CvMat* get_test_sample_idx() { return test_sample_idx; };
1860 void mix_train_and_test_idx();
1861 void set_train_test_split( const CvTrainTestSplit * spl);
1863 const CvMat* get_var_idx();
// NOTE(review): the method name is misspelled ("chahge"); it is kept because renaming
// would break existing callers.
1864 void chahge_var_idx( int vi, bool state );
1866 const CvMat* get_var_types();
// No null or range check: var_types must be set and var_idx must be a valid index.
1867 int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
1868 // the following 2 methods allow changing variable types;
1869 // use them to assign the CV_VAR_CATEGORICAL type to a categorical variable
1870 // with numerical labels; in other cases variable types are determined automatically
1871 void set_var_types( const char* str ); // str examples:
1872 // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
1873 // "cat", "ord" (all vars are categorical/ordered)
1874 void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
// CSV field delimiter.
1876 void set_delimiter( char ch );
1877 char get_delimiter() { return delimiter; };
// Character that marks a missing value.
1879 void set_miss_ch( char ch );
1880 char get_miss_ch() { return miss_ch; };
1883 virtual void clear();
// Converts one token to a float element and reports its type through `type`.
1885 void str_to_flt_elem( const char* token, float& flt_elem, int& type);
1886 void free_train_test_idx();
1890 //char flt_separator;
1895 CvMat* var_idx_mask;
// Output matrices handed to callers; the trailing comments distinguish a header-only
// matrix from full matrices.
1897 CvMat* response_out; // header
1898 CvMat* var_idx_out; // mat
1899 CvMat* var_types_out; // mat
1903 int train_sample_count;
1906 int total_class_count;
// Maps a class-name string to an int (presumably the numeric class label - confirm).
1907 map<string, int> *class_map;
1909 CvMat* train_sample_idx;
1910 CvMat* test_sample_idx;
1911 int* sample_idx; // data of train_sample_idx and test_sample_idx
// Aliases that drop the "Cv" prefix from the ML types (presumably inside namespace cv;
// the enclosing scope is not visible in this listing).
1920 typedef CvStatModel StatModel;
1921 typedef CvParamGrid ParamGrid;
1922 typedef CvNormalBayesClassifier NormalBayesClassifier;
1923 typedef CvKNearest KNearest;
1924 typedef CvSVMParams SVMParams;
1925 typedef CvSVMKernel SVMKernel;
1926 typedef CvSVMSolver SVMSolver;
1928 typedef CvEMParams EMParams;
1929 typedef CvEM ExpectationMaximization;
1930 typedef CvDTreeParams DTreeParams;
1931 typedef CvMLData TrainData;
1932 typedef CvDTree DecisionTree;
1933 typedef CvForestTree ForestTree;
1934 typedef CvRTParams RandomTreeParams;
1935 typedef CvRTrees RandomTrees;
// NOTE(review): the alias is spelled "ERTreeTRainData" (capital R); it is kept as is
// because renaming would break existing users.
1936 typedef CvERTreeTrainData ERTreeTRainData;
1937 typedef CvForestERTree ERTree;
1938 typedef CvERTrees ERTrees;
1939 typedef CvBoostParams BoostParams;
1940 typedef CvBoostTree BoostTree;
1941 typedef CvBoost Boost;
1942 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
1943 typedef CvANN_MLP NeuralNet_MLP;