1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
43 /****************************************************************************************\
47 The code has been derived from libsvm library (version 2.6)
48 (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
50 Here is the original copyright:
51 ------------------------------------------------------------------------------------------
52 Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
55 Redistribution and use in source and binary forms, with or without
56 modification, are permitted provided that the following conditions
59 1. Redistributions of source code must retain the above copyright
60 notice, this list of conditions and the following disclaimer.
62 2. Redistributions in binary form must reproduce the above copyright
63 notice, this list of conditions and the following disclaimer in the
64 documentation and/or other materials provided with the distribution.
66 3. Neither name of copyright holders nor the names of its contributors
67 may be used to endorse or promote products derived from this software
68 without specific prior written permission.
71 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
72 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
73 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
74 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
75 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
76 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
77 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
78 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
79 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
80 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
81 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82 \****************************************************************************************/
86 #define CV_SVM_MIN_CACHE_SIZE (40 << 20) /* 40Mb */
// Lower bound for the kernel-row cache budget used by CvSVMSolver.
92 #pragma warning( disable: 4514 ) /* unreferenced inline functions */
// Qfloat is the element type of cached kernel rows; QFLOAT_TYPE is the
// matching CvMat depth used when a row is wrapped as a 1xN matrix.
// NOTE(review): the original #if/#else selecting float/CV_32F versus
// double/CV_64F is truncated in this extract — only fragments of both
// branches remain; confirm against the complete file.
97 #define QFLOAT_TYPE CV_32F
99 typedef double Qfloat;
100 #define QFLOAT_TYPE CV_64F
104 bool CvParamGrid::check() const
108 CV_FUNCNAME( "CvParamGrid::check" );
111 if( min_val > max_val )
112 CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
113 if( min_val < DBL_EPSILON )
114 CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
115 if( step < 1. + FLT_EPSILON )
116 CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
125 CvParamGrid CvSVM::get_default_grid( int param_id )
// Returns a default logarithmic search grid for one SVM hyper-parameter
// (used by train_auto). Each branch fills grid.min_val/max_val/step; the
// min_val/max_val assignments are missing from this extract (only the
// step lines survive) — confirm the bounds against the complete file.
128 if( param_id == CvSVM::C )
132 grid.step = 5; // total iterations = 5
134 else if( param_id == CvSVM::GAMMA )
138 grid.step = 15; // total iterations = 4
140 else if( param_id == CvSVM::P )
144 grid.step = 7; // total iterations = 4
146 else if( param_id == CvSVM::NU )
150 grid.step = 3; // total iterations = 3
152 else if( param_id == CvSVM::COEF )
156 grid.step = 14; // total iterations = 3
158 else if( param_id == CvSVM::DEGREE )
162 grid.step = 7; // total iterations = 3
// Unknown param_id: report through cvError (no exception machinery here).
165 cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
166 "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
170 // SVM training parameters
171 CvSVMParams::CvSVMParams() :
172 svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
173 gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
175 term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
179 CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
180 double _degree, double _gamma, double _coef0,
181 double _Con, double _nu, double _p,
182 CvMat* _class_weights, CvTermCriteria _term_crit ) :
183 svm_type(_svm_type), kernel_type(_kernel_type),
184 degree(_degree), gamma(_gamma), coef0(_coef0),
185 C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
190 /////////////////////////////////////// SVM kernel ///////////////////////////////////////
// Trivial kernel lifecycle members; their short bodies are missing from
// this extract. The parameterized constructor forwards to create().
192 CvSVMKernel::CvSVMKernel()
198 void CvSVMKernel::clear()
205 CvSVMKernel::~CvSVMKernel()
210 CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
213 create( _params, _calc_func );
217 bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
221 calc_func = _calc_func;
224 calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
225 params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
226 params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
227 &CvSVMKernel::calc_linear;
233 void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
234 const float* another, Qfloat* results,
235 double alpha, double beta )
238 for( j = 0; j < vcount; j++ )
240 const float* sample = vecs[j];
242 for( k = 0; k <= var_count - 4; k += 4 )
243 s += sample[k]*another[k] + sample[k+1]*another[k+1] +
244 sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
245 for( ; k < var_count; k++ )
246 s += sample[k]*another[k];
247 results[j] = (Qfloat)(s*alpha + beta);
252 void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
253 const float* another, Qfloat* results )
255 calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
259 void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
260 const float* another, Qfloat* results )
262 CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
263 calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
264 cvPow( &R, &R, params->degree );
268 void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
269 const float* another, Qfloat* results )
272 calc_non_rbf_base( vcount, var_count, vecs, another, results,
273 -2*params->gamma, -2*params->coef0 );
274 // TODO: speedup this
275 for( j = 0; j < vcount; j++ )
277 Qfloat t = results[j];
278 double e = exp(-fabs(t));
280 results[j] = (Qfloat)((1. - e)/(1. + e));
282 results[j] = (Qfloat)((e - 1.)/(e + 1.));
287 void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
288 const float* another, Qfloat* results )
290 CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
291 double gamma = -params->gamma;
294 for( j = 0; j < vcount; j++ )
296 const float* sample = vecs[j];
299 for( k = 0; k <= var_count - 4; k += 4 )
301 double t0 = sample[k] - another[k];
302 double t1 = sample[k+1] - another[k+1];
306 t0 = sample[k+2] - another[k+2];
307 t1 = sample[k+3] - another[k+3];
312 for( ; k < var_count; k++ )
314 double t0 = sample[k] - another[k];
317 results[j] = (Qfloat)(s*gamma);
324 void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
325 const float* another, Qfloat* results )
327 const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
329 (this->*calc_func)( vcount, var_count, vecs, another, results );
330 for( j = 0; j < vcount; j++ )
332 if( results[j] > max_val )
333 results[j] = max_val;
338 // Generalized SMO+SVMlight algorithm
341 // min [0.5(\alpha^T Q \alpha) + b^T \alpha]
343 // y^T \alpha = \delta
345 // 0 <= alpha_i <= Cp for y_i = 1
346 // 0 <= alpha_i <= Cn for y_i = -1
350 // Q, b, y, Cp, Cn, and an initial feasible point \alpha
351 // l is the size of vectors and matrices
352 // eps is the stopping criterion
354 // solution will be put in \alpha, objective value will be put in obj
357 void CvSVMSolver::clear()
// Releases the child memory storage and resets the strategy callbacks;
// other member resets from the original body are missing in this extract.
364 cvReleaseMemStorage( &storage );
366 select_working_set_func = 0;
// Default constructor / destructor (bodies not visible here); the
// parameterized constructor simply forwards all arguments to create().
375 CvSVMSolver::CvSVMSolver()
382 CvSVMSolver::~CvSVMSolver()
388 CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
389 int _alpha_count, double* _alpha, double _Cp, double _Cn,
390 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
391 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
394 create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
395 _storage, _kernel, _get_row, _select_working_set, _calc_rho );
399 bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
400 int _alpha_count, double* _alpha, double _Cp, double _Cn,
401 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
402 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
// Initializes the SMO solver: copies the problem dimensions, allocates
// the gradient/status buffers from a child storage, picks default
// strategy callbacks appropriate to the SVM type when none are supplied,
// and sets up the kernel-row LRU cache. NOTE(review): several assignment
// lines (samples/y/alpha/C, return value) are missing from this extract.
407 CV_FUNCNAME( "CvSVMSolver::create" );
415 sample_count = _sample_count;
416 var_count = _var_count;
419 alpha_count = _alpha_count;
// Termination criteria come from the kernel's parameter block.
425 eps = kernel->params->term_crit.epsilon;
426 max_iter = kernel->params->term_crit.max_iter;
427 storage = cvCreateChildMemStorage( _storage );
429 b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
430 alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
431 G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
// Two scratch rows of 2*sample_count Qfloats (SVR rows are twice as long).
432 for( i = 0; i < 2; i++ )
433 buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
434 svm_type = kernel->params->svm_type;
// nu-SVM variants need the paired working-set selection / rho computation.
436 select_working_set_func = _select_working_set;
437 if( !select_working_set_func )
438 select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
439 &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
441 calc_rho_func = _calc_rho;
443 calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
444 &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
446 get_row_func = _get_row;
448 get_row_func = params->svm_type == CvSVM::EPS_SVR ||
449 params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
450 params->svm_type == CvSVM::C_SVC ||
451 params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
452 &CvSVMSolver::get_row_one_class;
454 cache_line_size = sample_count*sizeof(Qfloat);
455 // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 40Mb)
456 // (assuming that for large training sets ~25% of Q matrix is used)
457 cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
459 // the size of Q matrix row headers
460 rows_hdr_size = sample_count*sizeof(rows[0]);
461 if( rows_hdr_size > storage->block_size )
462 CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
// lru_list is a circular doubly-linked sentinel; row headers start empty.
464 lru_list.prev = lru_list.next = &lru_list;
465 rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
466 memset( rows, 0, rows_hdr_size );
476 float* CvSVMSolver::get_row_base( int i, bool* _existed )
// Returns (and caches) the kernel row for sample i, computing it on
// demand. Indices >= sample_count (the mirrored SVR half) alias back
// onto [0, sample_count). NOTE(review): several control-flow lines are
// missing from this extract.
478 int i1 = i < sample_count ? i : i - sample_count;
479 CvSVMKernelRow* row = rows + i1;
480 bool existed = row->data != 0;
// Either reuse this row's own buffer, or — once the cache budget is
// spent — steal the least-recently-used row's buffer (lru_list.prev).
483 if( existed || cache_size <= 0 )
485 CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
486 data = del_row->data;
489 // delete row from the LRU list
491 del_row->prev->next = del_row->next;
492 del_row->next->prev = del_row->prev;
// Otherwise allocate a fresh cache line and charge it to the budget.
496 data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
497 cache_size -= cache_line_size;
500 // insert row into the LRU list
502 row->prev = &lru_list;
503 row->next = lru_list.next;
504 row->prev->next = row->next->prev = row;
// Rows that were not already cached are recomputed via the kernel.
508 kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
518 float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
// For classification the Q-matrix row is the kernel row signed by the
// labels. The two loops apply +_y[j] or -_y[j] to every entry; the test
// that chooses between them (presumably on sample i's label, skipped when
// `existed` already holds the signed row) is missing from this extract —
// confirm against the complete file.
523 int j, len = sample_count;
524 assert( _y && i < sample_count );
528 for( j = 0; j < len; j++ )
529 row[j] = _y[j]*row[j];
533 for( j = 0; j < len; j++ )
534 row[j] = -_y[j]*row[j];
541 float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
547 float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
// Regression doubles the problem: each sample yields a positive and a
// negative dual variable, so dst holds sample_count "positive" entries
// followed by sample_count "negative" ones. For indices in the mirrored
// half the two destination halves are swapped before the copy loop
// (whose body is missing from this extract — presumably dst_pos[j] gets
// row[j] and dst_neg[j] its negation; confirm against the complete file).
549 int j, len = sample_count;
550 Qfloat* dst_pos = dst;
551 Qfloat* dst_neg = dst + len;
555 CV_SWAP( dst_pos, dst_neg, temp );
558 for( j = 0; j < len; j++ )
569 float* CvSVMSolver::get_row( int i, float* dst )
571 bool existed = false;
572 float* row = get_row_base( i, &existed );
573 return (this->*get_row_func)( i, row, dst, existed );
577 #undef is_upper_bound
// alpha_status[i] encodes where alpha_i sits in its box constraint:
//   +1 -> at the upper bound get_C(i), -1 -> at 0, 0 -> free.
578 #define is_upper_bound(i) (alpha_status[i] > 0)
580 #undef is_lower_bound
581 #define is_lower_bound(i) (alpha_status[i] < 0)
584 #define is_free(i) (alpha_status[i] == 0)
// Per-class penalty: C[1] for positive samples, C[0] for negative ones.
587 #define get_C(i) (C[y[i]>0])
589 #undef update_alpha_status
590 #define update_alpha_status(i) \
591 alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)
593 #undef reconstruct_gradient
594 #define reconstruct_gradient() /* empty for now */
597 bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
// Core SMO loop (derived from libsvm's Solver::Solve): initializes the
// gradient G from the non-zero alphas, then repeatedly selects a
// violating pair (i,j), solves the 2-variable subproblem analytically,
// clips it into the feasible box and updates G incrementally. Finishes
// by computing rho/r via the strategy callback and the dual objective.
// NOTE(review): many lines (braces, clip assignments, exit paths) are
// missing from this extract.
602 // 1. initialize gradient and alpha status
603 for( i = 0; i < alpha_count; i++ )
605 update_alpha_status(i);
// Bail out on a numerically degenerate initial gradient.
607 if( fabs(G[i]) > 1e200 )
611 for( i = 0; i < alpha_count; i++ )
613 if( !is_lower_bound(i) )
615 const Qfloat *Q_i = get_row( i, buf[0] );
616 double alpha_i = alpha[i];
618 for( j = 0; j < alpha_count; j++ )
619 G[j] += alpha_i*Q_i[j];
623 // 2. optimization loop
626 const Qfloat *Q_i, *Q_j;
628 double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
629 double delta_alpha_i, delta_alpha_j;
// Divergence checks on the gradient and the alphas.
632 for( i = 0; i < alpha_count; i++ )
634 if( fabs(G[i]) > 1e+300 )
637 if( fabs(alpha[i]) > 1e16 )
// Stop when no KKT-violating pair remains or max_iter is exhausted.
642 if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
645 Q_i = get_row( i, buf[0] );
646 Q_j = get_row( j, buf[1] );
651 alpha_i = old_alpha_i = alpha[i];
652 alpha_j = old_alpha_j = alpha[j];
// Opposite-label pair: step along the line alpha_i - alpha_j = const.
656 double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
657 double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
658 double diff = alpha_i - alpha_j;
// Clip back into the feasible region, corner by corner.
662 if( diff > 0 && alpha_j < 0 )
667 else if( diff <= 0 && alpha_i < 0 )
673 if( diff > C_i - C_j && alpha_i > C_i )
676 alpha_j = C_i - diff;
678 else if( diff <= C_i - C_j && alpha_j > C_j )
681 alpha_i = C_j + diff;
// Same-label pair: step along the line alpha_i + alpha_j = const.
686 double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
687 double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
688 double sum = alpha_i + alpha_j;
692 if( sum > C_i && alpha_i > C_i )
697 else if( sum <= C_i && alpha_j < 0)
703 if( sum > C_j && alpha_j > C_j )
708 else if( sum <= C_j && alpha_i < 0 )
// Commit the new pair and propagate the change to the whole gradient.
718 update_alpha_status(i);
719 update_alpha_status(j);
722 delta_alpha_i = alpha_i - old_alpha_i;
723 delta_alpha_j = alpha_j - old_alpha_j;
725 for( k = 0; k < alpha_count; k++ )
726 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
// Final rho/r via the strategy callback, then the dual objective value.
730 (this->*calc_rho_func)( si.rho, si.r );
732 // calculate objective value
733 for( i = 0, si.obj = 0; i < alpha_count; i++ )
734 si.obj += alpha[i] * (G[i] + b[i]);
738 si.upper_bound_p = C[1];
739 si.upper_bound_n = C[0];
745 // return 1 if already optimal, return 0 otherwise
747 CvSVMSolver::select_working_set( int& out_i, int& out_j )
749 // return i,j which maximize -grad(f)^T d , under constraint
750 // if alpha_i == C, d != +1
751 // if alpha_i == 0, d != -1
// Maximal-violating-pair selection: track the best ascent direction on
// each side of the equality constraint; out_i/out_j receive the argmax
// indices (the assignments are missing from this extract).
752 double Gmax1 = -DBL_MAX; // max { -grad(f)_i * d | y_i*d = +1 }
755 double Gmax2 = -DBL_MAX; // max { -grad(f)_i * d | y_i*d = -1 }
760 for( i = 0; i < alpha_count; i++ )
764 if( y[i] > 0 ) // y = +1
766 if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 ) // d = +1
771 if( !is_lower_bound(i) && (t = G[i]) > Gmax2 ) // d = -1
779 if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 ) // d = +1
784 if( !is_lower_bound(i) && (t = G[i]) > Gmax1 ) // d = -1
// Optimal when the maximal KKT violation drops below eps.
795 return Gmax1 + Gmax2 < eps;
800 CvSVMSolver::calc_rho( double& rho, double& r )
// Computes the bias rho: the average of y_i*G_i over free support vectors
// when any exist, otherwise the midpoint of the feasibility interval
// [lb, ub]. (The free-SV accumulation lines are missing in this extract.)
803 double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
805 for( i = 0; i < alpha_count; i++ )
807 double yG = y[i]*G[i];
809 if( is_lower_bound(i) )
816 else if( is_upper_bound(i) )
830 rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
836 CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
838 // return i,j which maximize -grad(f)^T d , under constraint
839 // if alpha_i == C, d != +1
840 // if alpha_i == 0, d != -1
// nu-SVM keeps a separate equality constraint per label sign, so the
// violating pair is searched independently within each class and the
// worse of the two candidate pairs is returned.
841 double Gmax1 = -DBL_MAX; // max { -grad(f)_i * d | y_i = +1, d = +1 }
844 double Gmax2 = -DBL_MAX; // max { -grad(f)_i * d | y_i = +1, d = -1 }
847 double Gmax3 = -DBL_MAX; // max { -grad(f)_i * d | y_i = -1, d = +1 }
850 double Gmax4 = -DBL_MAX; // max { -grad(f)_i * d | y_i = -1, d = -1 }
855 for( i = 0; i < alpha_count; i++ )
859 if( y[i] > 0 ) // y == +1
861 if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 ) // d = +1
866 if( !is_lower_bound(i) && (t = G[i]) > Gmax2 ) // d = -1
874 if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 ) // d = +1
879 if( !is_lower_bound(i) && (t = G[i]) > Gmax4 ) // d = -1
// Optimal when both per-class violations are below eps.
887 if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
890 if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
905 CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
// nu-SVM bias: r1 and r2 are computed per class exactly as in calc_rho
// (free-SV average, else interval midpoint). In libsvm the results are
// combined as r = (r1+r2)/2 and rho = (r1-r2)/2 — the combining lines are
// missing from this extract; confirm against the complete file.
907 int nr_free1 = 0, nr_free2 = 0;
908 double ub1 = DBL_MAX, ub2 = DBL_MAX;
909 double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
910 double sum_free1 = 0, sum_free2 = 0;
915 for( i = 0; i < alpha_count; i++ )
// Positive-class accumulation.
920 if( is_lower_bound(i) )
921 ub1 = MIN( ub1, G_i );
922 else if( is_upper_bound(i) )
923 lb1 = MAX( lb1, G_i );
// Negative-class accumulation.
932 if( is_lower_bound(i) )
933 ub2 = MIN( ub2, G_i );
934 else if( is_upper_bound(i) )
935 lb2 = MAX( lb2, G_i );
944 r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
945 r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
953 ///////////////////////// construct and solve various formulations ///////////////////////
956 bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
957 double _Cp, double _Cn, CvMemStorage* _storage,
958 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
// C-SVC: configure the solver with per-class penalties Cp/Cn, run SMO via
// solve_generic(), and post-process the alphas (the bodies of both loops
// are missing from this extract — presumably alpha/b initialization before
// solving and re-signing alpha[i] by its label afterwards; confirm).
962 if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
963 _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
964 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
967 for( i = 0; i < sample_count; i++ )
973 if( !solve_generic( _si ))
976 for( i = 0; i < sample_count; i++ )
983 bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
984 CvMemStorage* _storage, CvSVMKernel* _kernel,
985 double* _alpha, CvSVMSolutionInfo& _si )
// nu-SVC: the initial alphas distribute nu*l/2 mass over each class
// (clipped to [0,1] per sample), SMO runs with the nu working-set/rho
// strategy, and the solution is rescaled by inv_r so it matches the
// C-SVC form (the inv_r = 1/_si.r assignment is missing here — confirm).
988 double sum_pos, sum_neg, inv_r;
990 if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
991 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
992 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
995 sum_pos = kernel->params->nu * sample_count * 0.5;
996 sum_neg = kernel->params->nu * sample_count * 0.5;
998 for( i = 0; i < sample_count; i++ )
// Positive samples consume the sum_pos budget, negatives sum_neg.
1002 alpha[i] = MIN(1.0, sum_pos);
1003 sum_pos -= alpha[i];
1007 alpha[i] = MIN(1.0, sum_neg);
1008 sum_neg -= alpha[i];
1013 if( !solve_generic( _si ))
1018 for( i = 0; i < sample_count; i++ )
1019 alpha[i] *= y[i]*inv_r;
// Objective and box bounds are rescaled consistently with alpha.
1022 _si.obj *= (inv_r*inv_r);
1023 _si.upper_bound_p = inv_r;
1024 _si.upper_bound_n = inv_r;
1030 bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
1031 CvMemStorage* _storage, CvSVMKernel* _kernel,
1032 double* _alpha, CvSVMSolutionInfo& _si )
// One-class SVM: labels are implicit (+1), the first n = round(nu*l)
// alphas start at 1 and the boundary alpha receives the fractional
// remainder so that sum(alpha) == nu*l exactly.
1035 double nu = _kernel->params->nu;
1037 if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
1038 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
1039 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1042 y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
1043 n = cvRound( nu*sample_count );
1045 for( i = 0; i < sample_count; i++ )
// (y[i] initialization line missing from this extract.)
1049 alpha[i] = i < n ? 1 : 0;
// Fractional remainder goes to alpha[n], or alpha[n-1] when n == l.
1052 if( n < sample_count )
1053 alpha[n] = nu * sample_count - n;
1055 alpha[n-1] = nu * sample_count - (n-1);
1057 return solve_generic(_si);
1061 bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
1062 const float* _y, CvMemStorage* _storage,
1063 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
// eps-SVR: the dual has 2*l variables (alpha_i and alpha_i*). Both halves
// start at 0; the linear term b and labels are set per half (the first
// half's lines, presumably b[i] = p - _y[i] and y[i] = 1, are missing
// from this extract). The caller's alpha receives the difference
// alpha_i - alpha_i* at the end.
1066 double p = _kernel->params->p, C = _kernel->params->C;
1068 if( !create( _sample_count, _var_count, _samples, 0,
1069 _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
1070 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
// y and alpha are allocated internally since the problem is doubled.
1073 y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1074 alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1076 for( i = 0; i < sample_count; i++ )
1082 alpha[i+sample_count] = 0;
1083 b[i+sample_count] = p + _y[i];
1084 y[i+sample_count] = -1;
1087 if( !solve_generic( _si ))
1090 for( i = 0; i < sample_count; i++ )
1091 _alpha[i] = alpha[i] - alpha[i+sample_count];
1097 bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
1098 const float* _y, CvMemStorage* _storage,
1099 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
// nu-SVR: both dual halves start at min(sum, C), where the budget
// sum = C*nu*l/2 is consumed left to right (the sum -= alpha[i] line is
// missing here, as are the first half's b/y assignments — presumably
// b[i] = -_y[i], y[i] = 1; confirm against the complete file). The
// caller's alpha receives alpha_i - alpha_i* at the end.
1102 double C = _kernel->params->C, sum;
1104 if( !create( _sample_count, _var_count, _samples, 0,
1105 _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
1106 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
1109 y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1110 alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1111 sum = C * _kernel->params->nu * sample_count * 0.5;
1113 for( i = 0; i < sample_count; i++ )
1115 alpha[i] = alpha[i + sample_count] = MIN(sum, C);
1121 b[i + sample_count] = _y[i];
1122 y[i + sample_count] = -1;
1125 if( !solve_generic( _si ))
1128 for( i = 0; i < sample_count; i++ )
1129 _alpha[i] = alpha[i] - alpha[i+sample_count];
1135 //////////////////////////////////////////////////////////////////////////////////////////
// Fragments of the CvSVM default constructor / clear() / destructor:
// the default model name and the release of every per-model array and
// the model's memory storage.
1146 default_model_name = "my_svm";
1160 cvFree( &decision_func );
1161 cvReleaseMat( &class_labels );
1162 cvReleaseMat( &class_weights );
1163 cvReleaseMemStorage( &storage );
1164 cvReleaseMat( &var_idx );
// Convenience constructor: sets defaults then trains immediately.
1175 CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
1176 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1185 default_model_name = "my_svm";
1187 train( _train_data, _responses, _var_idx, _sample_idx, _params );
1191 int CvSVM::get_support_vector_count() const
1197 const float* CvSVM::get_support_vector(int i) const
1199 return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1203 bool CvSVM::set_params( const CvSVMParams& _params )
// Validates and stores the training parameters. Each "if( ... )" /
// "else if( bad )" pair either normalizes a parameter that the chosen
// kernel/SVM type does not use (those assignment lines are missing from
// this extract) or rejects an out-of-range value with CV_ERROR.
1207 CV_FUNCNAME( "CvSVM::set_params" );
1211 int kernel_type, svm_type;
1215 kernel_type = params.kernel_type;
1216 svm_type = params.svm_type;
1218 if( kernel_type != LINEAR && kernel_type != POLY &&
1219 kernel_type != SIGMOID && kernel_type != RBF )
1220 CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
// gamma is only meaningful for non-linear kernels.
1222 if( kernel_type == LINEAR )
1224 else if( params.gamma <= 0 )
1225 CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
1227 if( kernel_type != SIGMOID && kernel_type != POLY )
1229 else if( params.coef0 < 0 )
1230 CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
1232 if( kernel_type != POLY )
1234 else if( params.degree <= 0 )
1235 CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
1237 if( svm_type != C_SVC && svm_type != NU_SVC &&
1238 svm_type != ONE_CLASS && svm_type != EPS_SVR &&
1239 svm_type != NU_SVR )
1240 CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
// C applies to C-SVC/SVR types; nu to the nu/one-class types; p to eps-SVR.
1242 if( svm_type == ONE_CLASS || svm_type == NU_SVC )
1244 else if( params.C <= 0 )
1245 CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
1247 if( svm_type == C_SVC || svm_type == EPS_SVR )
1249 else if( params.nu <= 0 || params.nu >= 1 )
1250 CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
1252 if( svm_type != EPS_SVR )
1254 else if( params.p <= 0 )
1255 CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
// class_weights only apply to C-SVC.
1257 if( svm_type != C_SVC )
1258 params.class_weights = 0;
// Clamp the termination criteria into a sane range.
1260 params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
1261 params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
1271 void CvSVM::create_kernel()
1273 kernel = new CvSVMKernel(¶ms,0);
1277 void CvSVM::create_solver( )
1279 solver = new CvSVMSolver;
1283 // switching function
1284 bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1285 const void* _responses, double Cp, double Cn,
1286 CvMemStorage* _storage, double* alpha, double& rho )
// Dispatches to the solver routine matching params.svm_type. _responses
// is reinterpreted as schar* labels for classification and float* targets
// for regression; rho presumably receives si.rho afterwards (that
// assignment is missing from this extract).
1290 //CV_FUNCNAME( "CvSVM::train1" );
1294 CvSVMSolutionInfo si;
1295 int svm_type = params.svm_type;
1299 ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
1300 Cp, Cn, _storage, kernel, alpha, si ) :
1301 svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
1302 _storage, kernel, alpha, si ) :
1303 svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1304 _storage, kernel, alpha, si ) :
1305 svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1306 _storage, kernel, alpha, si ) :
1307 svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1308 _storage, kernel, alpha, si ) : false;
1318 bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
1319 const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
// Trains the decision function(s). One-class/regression types produce a
// single decision function; classification trains one binary classifier
// per class pair (one-vs-one) and then compacts the shared support
// vectors. NOTE(review): control-flow lines (braces, EXIT paths, some
// assignments) are missing from this extract.
1323 CV_FUNCNAME( "CvSVM::do_train" );
1327 CvSVMDecisionFunc* df = 0;
1328 const int sample_size = var_count*sizeof(samples[0][0]);
1331 if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
// Single decision function: run the solver once, then gather the samples
// with non-zero alpha as support vectors.
1335 CV_CALL( decision_func = df =
1336 (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
1339 if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
1340 responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
1343 for( i = 0; i < sample_count; i++ )
1344 sv_count += fabs(alpha[i]) > 0;
1346 sv_total = df->sv_count = sv_count;
1347 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
1348 CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
1350 for( i = k = 0; i < sample_count; i++ )
1352 if( fabs(alpha[i]) > 0 )
1354 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1355 memcpy( sv[k], samples[i], sample_size );
1356 df->alpha[k++] = alpha[i];
// Multi-class branch: one-vs-one over class_count labels.
1362 int class_count = class_labels->cols;
1364 const float** temp_samples = 0;
1365 int* class_ranges = 0;
1367 assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
// Optional per-class weights scale the C penalty individually per class.
1369 if( svm_type == CvSVM::C_SVC && params.class_weights )
1371 const CvMat* cw = params.class_weights;
1373 if( !CV_IS_MAT(cw) || (cw->cols != 1 && cw->rows != 1) ||
1374 cw->rows + cw->cols - 1 != class_count ||
1375 (CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1) )
1376 CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
1377 "containing as many elements as the number of classes" );
1379 CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
1380 CV_CALL( cvConvert( cw, class_weights ));
1381 CV_CALL( cvScale( class_weights, class_weights, params.C ));
// One decision function per class pair, plus scratch tables.
1384 CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
1385 (class_count*(class_count-1)/2)*sizeof(df[0])));
1387 CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
1388 memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
1389 CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
1390 (class_count + 1)*sizeof(class_ranges[0])));
1391 CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
1392 sample_count*sizeof(temp_samples[0])));
1393 CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));
// Group the samples by class so each pair is a contiguous range.
1395 class_ranges[class_count] = 0;
1396 cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
1397 //check that while cross-validation there were the samples from all the classes
1398 if( class_ranges[class_count] <= 0 )
1399 CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
1400 "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
1402 if( svm_type == NU_SVC )
1404 // check if nu is feasible
1405 for(i = 0; i < class_count; i++ )
1407 int ci = class_ranges[i+1] - class_ranges[i];
1408 for( j = i+1; j< class_count; j++ )
1410 int cj = class_ranges[j+1] - class_ranges[j];
1411 if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
1413 // !!!TODO!!! add some diagnostic
1414 EXIT; // exit immediately; will release the model and return NULL pointer
1420 // train n*(n-1)/2 classifiers
1421 for( i = 0; i < class_count; i++ )
1423 for( j = i+1; j < class_count; j++, df++ )
1425 int si = class_ranges[i], ci = class_ranges[i+1] - si;
1426 int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
1427 double Cp = params.C, Cn = Cp;
1428 int k1 = 0, sv_count = 0;
// Build the two-class training subset with labels +1 / -1.
1430 for( k = 0; k < ci; k++ )
1432 temp_samples[k] = samples[si + k];
1436 for( k = 0; k < cj; k++ )
1438 temp_samples[ci + k] = samples[sj + k];
1439 temp_y[ci + k] = -1;
// Per-class weights override the shared C penalty when present.
1444 Cp = class_weights->data.db[i];
1445 Cn = class_weights->data.db[j];
1448 if( !train1( ci + cj, var_count, temp_samples, temp_y,
1449 Cp, Cn, temp_storage, alpha, df->rho ))
1452 for( k = 0; k < ci + cj; k++ )
1453 sv_count += fabs(alpha[k]) > 0;
1455 df->sv_count = sv_count;
1457 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
1458 sv_count*sizeof(df->alpha[0])));
1459 CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
1460 sv_count*sizeof(df->sv_index[0])));
// Record which original samples became support vectors for this pair.
1462 for( k = 0; k < ci; k++ )
1464 if( fabs(alpha[k]) > 0 )
1467 df->sv_index[k1] = si + k;
1468 df->alpha[k1++] = alpha[k];
1472 for( k = 0; k < cj; k++ )
1474 if( fabs(alpha[ci + k]) > 0 )
1477 df->sv_index[k1] = sj + k;
1478 df->alpha[k1++] = alpha[ci + k];
1484 // allocate support vectors and initialize sv_tab
1485 for( i = 0, k = 0; i < sample_count; i++ )
1492 CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
// Copy each marked sample into the compacted support-vector array.
1494 for( i = 0, k = 0; i < sample_count; i++ )
1498 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1499 memcpy( sv[k], samples[i], sample_size );
// Remap every decision function's sv_index from original sample index
// to its index in the compacted support-vector array.
1504 df = (CvSVMDecisionFunc*)decision_func;
1507 for( i = 0; i < class_count; i++ )
1509 for( j = i+1; j < class_count; j++, df++ )
1511 for( k = 0; k < df->sv_count; k++ )
1513 df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
1514 assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
1527 bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
1528 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
// Public training entry point: validates the parameters, extracts and
// normalizes the training data, sizes and allocates the model storage,
// and delegates the actual optimization to do_train(). Temporaries are
// released on every exit path; on failure the partial model is cleared.
1531 CvMat* responses = 0;
1532 CvMemStorage* temp_storage = 0;
1533 const float** samples = 0;
1535 CV_FUNCNAME( "CvSVM::train" );
1539 int svm_type, sample_count, var_count, sample_size;
1540 int block_size = 1 << 16;
1544 CV_CALL( set_params( _params ));
1546 svm_type = _params.svm_type;
1548 /* Prepare training data and related parameters */
// One-class training ignores responses; classification types need
// categorical responses, regression types ordered ones.
1549 CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
1550 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1551 svm_type == CvSVM::C_SVC ||
1552 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1553 CV_VAR_ORDERED, _var_idx, _sample_idx,
1554 false, &samples, &sample_count, &var_count, &var_all,
1555 &responses, &class_labels, &var_idx ));
1558 sample_size = var_count*sizeof(samples[0][0]);
1560 // make the storage block size large enough to fit all
1561 // the temporary vectors and output support vectors.
1562 block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1563 block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1564 block_size = MAX( block_size, sample_size*2 + 1024 );
1566 CV_CALL( storage = cvCreateMemStorage(block_size));
1567 CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
1568 CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1573 if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
1576 ok = true; // model has been trained successfully
// Cleanup: the temporary child storage and response matrix are released
// unconditionally; failure (or a pending CV error) clears the model.
1582 cvReleaseMemStorage( &temp_storage );
1583 cvReleaseMat( &responses );
1586 if( cvGetErrStatus() < 0 || !ok )
1592 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1593 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1594 CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1595 CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
// Trains the SVM while selecting the best parameter combination by
// k_fold cross-validation over the given logarithmic grids
// (C, gamma, p, nu, coef0, degree). A grid with step <= 1 is collapsed
// to the single value taken from params. After the grid search, the
// model is re-trained on the full data with the best parameters found.
1598 CvMat* responses = 0;
1599 CvMat* responses_local = 0;
1600 CvMemStorage* temp_storage = 0;
1601 const float** samples = 0;
1602 const float** samples_local = 0;
1604 CV_FUNCNAME( "CvSVM::train_auto" );
1607 int svm_type, sample_count, var_count, sample_size;
1608 int block_size = 1 << 16;
// Fixed-seed RNG: the random fold shuffle below is deterministic.
1611 CvRNG rng = cvRNG(-1);
1613 // all steps are logarithmic and must be > 1
1614 double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1615 double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
1616 double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1617 float min_error = FLT_MAX, error;
// ONE_CLASS has no labels to cross-validate against: just train once.
1619 if( _params.svm_type == CvSVM::ONE_CLASS )
1621 if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1629 CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1631 CV_CALL(set_params( _params ));
1632 svm_type = _params.svm_type;
1634 // All the parameters except, possibly, <coef0> are positive.
1635 // <coef0> is nonnegative
// A step <= 1 means "do not search this parameter": pin the grid to
// the value already stored in params.
1636 if( C_grid.step <= 1 )
1638 C_grid.min_val = C_grid.max_val = params.C;
1642 CV_CALL(C_grid.check());
1644 if( gamma_grid.step <= 1 )
1646 gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1647 gamma_grid.step = 10;
1650 CV_CALL(gamma_grid.check());
1652 if( p_grid.step <= 1 )
1654 p_grid.min_val = p_grid.max_val = params.p;
1658 CV_CALL(p_grid.check());
1660 if( nu_grid.step <= 1 )
1662 nu_grid.min_val = nu_grid.max_val = params.nu;
1666 CV_CALL(nu_grid.check());
1668 if( coef_grid.step <= 1 )
1670 coef_grid.min_val = coef_grid.max_val = params.coef0;
1671 coef_grid.step = 10;
1674 CV_CALL(coef_grid.check());
1676 if( degree_grid.step <= 1 )
1678 degree_grid.min_val = degree_grid.max_val = params.degree;
1679 degree_grid.step = 10;
1682 CV_CALL(degree_grid.check());
1684 // these parameters are not used:
// Collapse grids of parameters irrelevant to the chosen kernel/svm type
// so the nested search loops below run exactly one iteration for them.
1685 if( params.kernel_type != CvSVM::POLY )
1686 degree_grid.min_val = degree_grid.max_val = params.degree;
1687 if( params.kernel_type == CvSVM::LINEAR )
1688 gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1689 if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1690 coef_grid.min_val = coef_grid.max_val = params.coef0;
1691 if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1692 C_grid.min_val = C_grid.max_val = params.C;
1693 if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1694 nu_grid.min_val = nu_grid.max_val = params.nu;
1695 if( svm_type != CvSVM::EPS_SVR )
1696 p_grid.min_val = p_grid.max_val = params.p;
1698 CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1699 CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1701 /* Prepare training data and related parameters */
1702 CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1703 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1704 svm_type == CvSVM::C_SVC ||
1705 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1706 CV_VAR_ORDERED, _var_idx, _sample_idx,
1707 false, &samples, &sample_count, &var_count, &var_all,
1708 &responses, &class_labels, &var_idx ));
1710 sample_size = var_count*sizeof(samples[0][0]);
1712 // make the storage block size large enough to fit all
1713 // the temporary vectors and output support vectors.
1714 block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1715 block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1716 block_size = MAX( block_size, sample_size*2 + 1024 );
1718 CV_CALL(storage = cvCreateMemStorage(block_size));
1719 CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1720 CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
// Fold sizing: the last fold absorbs the remainder when sample_count is
// not divisible by k_fold.
1726 const int testset_size = sample_count/k_fold;
1727 const int trainset_size = sample_count - testset_size;
1728 const int last_testset_size = sample_count - testset_size*(k_fold-1);
1729 const int last_trainset_size = sample_count - last_testset_size;
1730 const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
1732 size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1733 size_t size = 2*last_trainset_size*sizeof(samples[0]);
// Scratch arrays holding the current fold's training subset.
1735 samples_local = (const float**) cvAlloc( size );
1736 memset( samples_local, 0, size );
1738 responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1739 cvZero( responses_local );
1741 // randomly permute samples and responses
1742 for( i = 0; i < sample_count; i++ )
1744 int i1 = cvRandInt( &rng ) % sample_count;
1745 int i2 = cvRandInt( &rng ) % sample_count;
1750 CV_SWAP( samples[i1], samples[i2], temp );
1752 CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1754 CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1757 int* cls_lbls = class_labels ? class_labels->data.i : 0;
// --- nested grid search: each loop sweeps one parameter geometrically ---
1762 gamma = gamma_grid.min_val;
1765 params.gamma = gamma;
1770 nu = nu_grid.min_val;
1774 coef = coef_grid.min_val;
1777 params.coef0 = coef;
1778 degree = degree_grid.min_val;
1781 params.degree = degree;
1783 float** test_samples_ptr = (float**)samples;
1784 uchar* true_resp = responses->data.ptr;
1785 int test_size = testset_size;
1786 int train_size = trainset_size;
// k-fold loop: for fold k, the training set is everything except the
// k-th contiguous test slice of the permuted data.
1789 for( k = 0; k < k_fold; k++ )
1791 memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1792 memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1793 sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1795 memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
// NOTE(review): the length of the next memcpy uses sizeof(samples[0])
// (a pointer size) although it copies response elements that are
// resp_elem_size bytes each, unlike the parallel copy just above which
// correctly uses resp_elem_size. This looks like a bug -- presumably the
// factor should be resp_elem_size; confirm against the upstream fix.
1796 memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1797 true_resp + resp_elem_size*test_size,
1798 sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1800 if( k == k_fold - 1 )
1802 test_size = last_testset_size;
1803 train_size = last_trainset_size;
1804 responses_local->cols = last_trainset_size;
1807 // Train SVM on <train_size> samples
1808 if( !do_train( svm_type, train_size, var_count,
1809 (const float**)samples_local, responses_local, temp_storage, alpha ) )
1812 // Compute test set error on <test_size> samples
1813 CvMat s = cvMat( 1, var_count, CV_32FC1 );
1814 for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
// Wrap the raw sample pointer in a 1xN header; no data is copied.
1817 s.data.fl = *test_samples_ptr;
1818 resp = predict( &s );
// Regression: accumulate squared error; classification: count misses.
1819 error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1820 : ((int)resp != cls_lbls[*(int*)true_resp]);
1823 if( min_error > error )
1826 best_degree = degree;
// Advance each grid parameter geometrically (steps are multiplicative).
1833 degree *= degree_grid.step;
1835 while( degree < degree_grid.max_val );
1836 coef *= coef_grid.step;
1838 while( coef < coef_grid.max_val );
1841 while( nu < nu_grid.max_val );
1844 while( p < p_grid.max_val );
1845 gamma *= gamma_grid.step;
1847 while( gamma < gamma_grid.max_val );
1850 while( C < C_grid.max_val );
1853 min_error /= (float) sample_count;
// Install the winning parameters and retrain on the full data set.
1856 params.nu = best_nu;
1858 params.gamma = best_gamma;
1859 params.degree = best_degree;
1860 params.coef0 = best_coef;
1862 CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
1868 cvReleaseMemStorage( &temp_storage );
1869 cvReleaseMat( &responses );
1870 cvReleaseMat( &responses_local );
1872 cvFree( &samples_local );
1874 if( cvGetErrStatus() < 0 || !ok )
1880 float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
// Predicts the response for a single sample.
// For EPS_SVR/NU_SVR returns the regression value; for ONE_CLASS returns
// 0/1 (outlier/inlier); for C_SVC/NU_SVC returns the winning class label,
// or -- when returnDFVal is true and there are exactly two classes --
// the raw decision-function value.
1882 bool local_alloc = 0;
1884 float* row_sample = 0;
1887 CV_FUNCNAME( "CvSVM::predict" );
1892 int var_count, buf_sz;
1895 CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
1897 class_count = class_labels ? class_labels->cols :
1898 params.svm_type == ONE_CLASS ? 1 : 0;
// Converts the input to a dense float row, applying var_idx selection.
1900 CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
1901 class_count, 0, &row_sample ));
1903 var_count = get_var_count();
// Kernel-response buffer plus per-class vote counters; small buffers
// live on the stack, larger ones on the heap.
1905 buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
1906 if( buf_sz <= CV_MAX_LOCAL_SIZE )
1908 CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
1912 CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
1914 if( params.svm_type == EPS_SVR ||
1915 params.svm_type == NU_SVR ||
1916 params.svm_type == ONE_CLASS )
// Single decision function: sum_i alpha_i * K(sv_i, x) - rho.
1918 CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1919 int i, sv_count = df->sv_count;
1920 double sum = -df->rho;
1922 kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
1923 for( i = 0; i < sv_count; i++ )
1924 sum += buffer[i]*df->alpha[i];
1926 result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
1928 else if( params.svm_type == C_SVC ||
1929 params.svm_type == NU_SVC )
// Multi-class: evaluate all n*(n-1)/2 pairwise decision functions over
// the shared kernel-response buffer, then take the majority vote.
1931 CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1932 int* vote = (int*)(buffer + sv_total);
1935 memset( vote, 0, class_count*sizeof(vote[0]));
1936 kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
1939 for( i = 0; i < class_count; i++ )
1941 for( j = i+1; j < class_count; j++, df++ )
1944 int sv_count = df->sv_count;
1945 for( k = 0; k < sv_count; k++ )
1946 sum += df->alpha[k]*buffer[df->sv_index[k]];
1948 vote[sum > 0 ? i : j]++;
1952 for( i = 1, k = 0; i < class_count; i++ )
1954 if( vote[i] > vote[k] )
1957 result = returnDFVal && class_count == 2 ? (float)sum : (float)(class_labels->data.i[k]);
1960 CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
1961 "the SVM structure is probably corrupted" );
// Free the converted row only if cvPreparePredictData allocated it.
1965 if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
1966 cvFree( &row_sample );
1975 bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
1976 const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
1978 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
1979 return train(&tdata, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, _params);
1983 bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
1984 const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
1985 CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1986 CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1988 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
1989 return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
1990 sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
1991 nu_grid, coef_grid, degree_grid);
1994 float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const
1996 CvMat sample = _sample;
1997 return predict(&sample, returnDFVal);
2001 void CvSVM::write_params( CvFileStorage* fs ) const
// Serializes the training parameters to file storage. Known enum values
// are written as readable strings; unknown values fall back to raw ints,
// and in that case every numeric parameter is written unconditionally so
// nothing is lost.
2003 //CV_FUNCNAME( "CvSVM::write_params" );
2007 int svm_type = params.svm_type;
2008 int kernel_type = params.kernel_type;
2010 const char* svm_type_str =
2011 svm_type == CvSVM::C_SVC ? "C_SVC" :
2012 svm_type == CvSVM::NU_SVC ? "NU_SVC" :
2013 svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
2014 svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
2015 svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
2016 const char* kernel_type_str =
2017 kernel_type == CvSVM::LINEAR ? "LINEAR" :
2018 kernel_type == CvSVM::POLY ? "POLY" :
2019 kernel_type == CvSVM::RBF ? "RBF" :
2020 kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
2023 cvWriteString( fs, "svm_type", svm_type_str );
2025 cvWriteInt( fs, "svm_type", svm_type );
// Kernel parameters are written into a nested flow map.
2028 cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
2030 if( kernel_type_str )
2031 cvWriteString( fs, "type", kernel_type_str );
2033 cvWriteInt( fs, "type", kernel_type );
// Each kernel parameter is written only when the kernel uses it
// (or always, if the kernel type was unrecognized).
2035 if( kernel_type == CvSVM::POLY || !kernel_type_str )
2036 cvWriteReal( fs, "degree", params.degree );
2038 if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
2039 cvWriteReal( fs, "gamma", params.gamma );
2041 if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
2042 cvWriteReal( fs, "coef0", params.coef0 );
2044 cvEndWriteStruct(fs);
// SVM-type-specific scalar parameters.
2046 if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
2047 svm_type == CvSVM::NU_SVR || !svm_type_str )
2048 cvWriteReal( fs, "C", params.C );
2050 if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
2051 svm_type == CvSVM::NU_SVR || !svm_type_str )
2052 cvWriteReal( fs, "nu", params.nu );
2054 if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
2055 cvWriteReal( fs, "p", params.p );
// Termination criteria: only the components enabled in the type mask
// are written out.
2057 cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
2058 if( params.term_crit.type & CV_TERMCRIT_EPS )
2059 cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
2060 if( params.term_crit.type & CV_TERMCRIT_ITER )
2061 cvWriteInt( fs, "iterations", params.term_crit.max_iter );
2062 cvEndWriteStruct( fs );
2068 void CvSVM::write( CvFileStorage* fs, const char* name ) const
// Serializes the trained model: parameters, class metadata, the joint
// pool of support vectors, and each pairwise decision function (alpha
// coefficients, rho, and -- for classification -- indices into the
// shared support-vector pool).
2070 CV_FUNCNAME( "CvSVM::write" );
2074 int i, var_count = get_var_count(), df_count, class_count;
2075 const CvSVMDecisionFunc* df = decision_func;
2077 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
2081 cvWriteInt( fs, "var_all", var_all );
2082 cvWriteInt( fs, "var_count", var_count );
2084 class_count = class_labels ? class_labels->cols :
2085 params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2089 cvWriteInt( fs, "class_count", class_count );
2092 cvWrite( fs, "class_labels", class_labels );
2095 cvWrite( fs, "class_weights", class_weights );
2099 cvWrite( fs, "var_idx", var_idx );
2101 // write the joint collection of support vectors
2102 cvWriteInt( fs, "sv_total", sv_total );
2103 cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
2104 for( i = 0; i < sv_total; i++ )
// Each support vector is one flow sequence of var_count floats.
2106 cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
2107 cvWriteRawData( fs, sv[i], var_count, "f" );
2108 cvEndWriteStruct( fs );
2111 cvEndWriteStruct( fs );
2113 // write decision functions
// One df per class pair for classification, a single df otherwise.
2114 df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2117 cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
2118 for( i = 0; i < df_count; i++ )
2120 int sv_count = df[i].sv_count;
2121 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2122 cvWriteInt( fs, "sv_count", sv_count );
2123 cvWriteReal( fs, "rho", df[i].rho );
2124 cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
2125 cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
2126 cvEndWriteStruct( fs );
// sv_index is only meaningful for multi-class models; a single-df
// model implicitly uses the whole support-vector pool in order.
2127 if( class_count > 1 )
2129 cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
2130 cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
2131 cvEndWriteStruct( fs );
2134 CV_ASSERT( sv_count == sv_total );
2135 cvEndWriteStruct( fs );
2137 cvEndWriteStruct( fs );
2138 cvEndWriteStruct( fs );
2144 void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
// Reads the training parameters back from file storage (the inverse of
// write_params) and installs them via set_params. svm_type and kernel
// type accept either the string names written by write_params or raw
// integers (for forward compatibility with unknown enum values).
2146 CV_FUNCNAME( "CvSVM::read_params" );
2150 int svm_type, kernel_type;
2151 CvSVMParams _params;
2153 CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
2154 CvFileNode* kernel_node;
2156 CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
2158 if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2159 svm_type = cvReadInt( tmp_node, -1 );
2162 const char* svm_type_str = cvReadString( tmp_node, "" );
2164 strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
2165 strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
2166 strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
2167 strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
2168 strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
// NOTE(review): the error message below says "Missing of invalid";
// presumably "Missing or invalid" was intended (runtime string, left
// unchanged here).
2171 CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
2174 kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
2176 CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
2178 tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
2180 CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
2182 if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2183 kernel_type = cvReadInt( tmp_node, -1 );
2186 const char* kernel_type_str = cvReadString( tmp_node, "" );
2188 strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
2189 strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
2190 strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
2191 strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
2193 if( kernel_type < 0 )
2194 CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
2197 _params.svm_type = svm_type;
2198 _params.kernel_type = kernel_type;
// Numeric parameters default to 0 when absent; write_params only
// stores the ones relevant to the saved svm/kernel type.
2199 _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
2200 _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
2201 _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
2203 _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
2204 _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
2205 _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
2206 _params.class_weights = 0;
2208 tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
// Rebuild the criteria type mask from which components were stored;
// absent term_criteria falls back to the historical default below.
2211 _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
2212 _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
2213 _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
2214 (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
2217 _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
2219 set_params( _params );
2225 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
// Deserializes a complete trained model (the inverse of write):
// parameters, class metadata, the support-vector pool, and the decision
// functions. Every count and node type read from the file is validated
// before use, since the file contents are untrusted.
2227 const double not_found_dbl = DBL_MAX;
2229 CV_FUNCNAME( "CvSVM::read" );
2233 int i, var_count, df_count, class_count;
2234 int block_size = 1 << 16, sv_size;
2235 CvFileNode *sv_node, *df_node;
2236 CvSVMDecisionFunc* df;
2240 CV_ERROR( CV_StsParseError, "The requested element is not found" );
2244 // read SVM parameters
2245 read_params( fs, svm_node );
2247 // and top-level data
2248 sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
2249 var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
2250 var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
2251 class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
2253 if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
2254 CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
2256 CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
2257 CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
2258 CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "var_idx" ));
// Cross-check optional arrays against the counts read above.
2260 if( class_count > 1 && (!class_labels ||
2261 !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
2262 CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
2264 if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
2265 CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
2267 // read support vectors
2268 sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
2269 if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
2270 CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
// Size the model storage the same way training does.
2272 block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
2273 block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
2274 block_size = MAX( block_size, var_all*(int)sizeof(double));
2275 CV_CALL( storage = cvCreateMemStorage( block_size ));
2276 CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
2277 sv_total*sizeof(sv[0]) ));
2279 CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
2280 sv_size = var_count*sizeof(sv[0][0]);
2282 for( i = 0; i < sv_total; i++ )
2284 CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
2285 CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
2286 sv_elem->data.seq->total == var_count) );
2288 CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
2289 CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
2290 CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
2293 // read decision functions
2294 df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2295 df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
2296 if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
2297 df_node->data.seq->total != df_count )
2298 CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
2299 "or has a wrong number of elements" );
2301 CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
2302 cvStartReadSeq( df_node->data.seq, &reader, 0 );
2304 for( i = 0; i < df_count; i++ )
2306 CvFileNode* df_elem = (CvFileNode*)reader.ptr;
2307 CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
2309 int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
2311 CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
2312 df[i].sv_count = sv_count;
// rho defaults to the DBL_MAX sentinel so a missing tag is detectable.
2314 df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
2315 if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
2316 CV_ERROR( CV_StsParseError, "rho is missing" );
2319 CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
2321 CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
2322 sv_count*sizeof(df[i].alpha[0])));
2323 CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(alpha_node->tag) &&
2324 alpha_node->data.seq->total == sv_count) );
2325 CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
// sv_index is only stored for multi-class models (see write()).
2327 if( class_count > 1 )
2329 CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
2331 CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
2332 CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
2333 sv_count*sizeof(df[i].sv_index[0])));
2334 CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(index_node->tag) &&
2335 index_node->data.seq->total == sv_count) );
2336 CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
2341 CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
2352 icvCloneSVM( const void* _src )
// Deep-copies a CvSVMModel: parameters, class metadata, support vectors
// and decision functions all get fresh allocations in the clone's own
// storage, so the clone is independent of the source. Registered as the
// clone callback for the SVM type (see icvRegisterSVMType).
2354 CvSVMModel* dst = 0;
2356 CV_FUNCNAME( "icvCloneSVM" );
2360 const CvSVMModel* src = (const CvSVMModel*)_src;
2361 int var_count, class_count;
2362 int i, sv_total, df_count;
2365 if( !CV_IS_SVM(src) )
2366 CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2368 // 0. create initial CvSVMModel structure
2369 CV_CALL( dst = icvCreateSVM() );
2370 dst->params = src->params;
// The weight arrays are not shared with the source; they are reset here.
2371 dst->params.weight_labels = 0;
2372 dst->params.weights = 0;
2374 dst->var_all = src->var_all;
2375 if( src->class_labels )
2376 dst->class_labels = cvCloneMat( src->class_labels );
2377 if( src->class_weights )
2378 dst->class_weights = cvCloneMat( src->class_weights );
2380 dst->comp_idx = cvCloneMat( src->comp_idx );
2382 var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2383 class_count = src->class_labels ? src->class_labels->cols :
2384 src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2385 sv_total = dst->sv_total = src->sv_total;
2386 CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2387 CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2388 sv_total*sizeof(dst->sv[0]) ));
2390 sv_size = var_count*sizeof(dst->sv[0][0]);
// Copy every support vector into the clone's storage.
2392 for( i = 0; i < sv_total; i++ )
2394 CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2395 memcpy( dst->sv[i], src->sv[i], sv_size );
2398 df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2400 CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2402 for( i = 0; i < df_count; i++ )
2404 const CvSVMDecisionFunc *sdf =
2405 (const CvSVMDecisionFunc*)src->decision_func+i;
2406 CvSVMDecisionFunc *ddf =
2407 (CvSVMDecisionFunc*)dst->decision_func+i;
2408 int sv_count = sdf->sv_count;
2409 ddf->sv_count = sv_count;
2410 ddf->rho = sdf->rho;
2411 CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2412 sv_count*sizeof(ddf->alpha[0])));
2413 memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
// sv_index exists only for multi-class models (matches write()/read()).
2415 if( class_count > 1 )
2417 CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2418 sv_count*sizeof(ddf->sv_index[0])));
2419 memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
// On any pending error, release the partially-built clone.
2427 if( cvGetErrStatus() < 0 && dst )
2428 icvReleaseSVM( &dst );
2433 static int icvRegisterSVMType()
// Registers the SVM model type with the cxcore type system so that
// generic persistence (cvLoad/cvSave/cvClone/cvRelease) can dispatch to
// the SVM-specific callbacks below by the CV_TYPE_NAME_ML_SVM name.
2436 memset( &info, 0, sizeof(info) );
2439 info.header_size = sizeof( info );
2440 info.is_instance = icvIsSVM;
2441 info.release = (CvReleaseFunc)icvReleaseSVM;
2442 info.read = icvReadSVM;
2443 info.write = icvWriteSVM;
2444 info.clone = icvCloneSVM;
2445 info.type_name = CV_TYPE_NAME_ML_SVM;
2446 cvRegisterType( &info );
// Static initializer: forces the SVM type registration to run at
// module load time (the value itself is unused).
2452 static int svm = icvRegisterSVMType();
2454 /* The function trains an SVM model with optimal parameters, obtained by using cross-validation.
2455 The parameters to be estimated should be indicated by setting their values to FLT_MAX.
2456 The optimal parameters are saved in <model_params> */
2457 CV_IMPL CvStatModel*
2458 cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2459 const CvMat* responses,
2460 CvStatModelParams* model_params,
2461 const CvStatModelParams* cross_valid_params,
2462 const CvMat* comp_idx,
2463 const CvMat* sample_idx,
2464 const CvParamGrid* degree_grid,
2465 const CvParamGrid* gamma_grid,
2466 const CvParamGrid* coef_grid,
2467 const CvParamGrid* C_grid,
2468 const CvParamGrid* nu_grid,
2469 const CvParamGrid* p_grid )
2471 CvStatModel* svm = 0;
2473 CV_FUNCNAME("cvTainSVMCrossValidation");
2476 double degree_step = 7,
2481 p_step = 7; // all steps must be > 1
2482 double degree_begin = 0.01, degree_end = 2;
2483 double g_begin = 1e-5, g_end = 0.5;
2484 double coef_begin = 0.1, coef_end = 300;
2485 double C_begin = 0.1, C_end = 6000;
2486 double nu_begin = 0.01, nu_end = 0.4;
2487 double p_begin = 0.01, p_end = 100;
2489 double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2491 double best_rate = 0;
2492 double best_degree = degree_begin;
2493 double best_gamma = g_begin;
2494 double best_coef = coef_begin;
2495 double best_C = C_begin;
2496 double best_nu = nu_begin;
2497 double best_p = p_begin;
2499 CvSVMModelParams svm_params, *psvm_params;
2500 CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2501 int svm_type, kernel;
2505 CV_ERROR( CV_StsBadArg, "" );
2507 CV_ERROR( CV_StsBadArg, "" );
2509 svm_params = *(CvSVMModelParams*)model_params;
2510 psvm_params = (CvSVMModelParams*)model_params;
2511 svm_type = svm_params.svm_type;
2512 kernel = svm_params.kernel_type;
2514 svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2515 svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2516 svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2517 svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2518 svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2519 svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
2523 if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2524 degree_grid->step == 0) )
2526 if( degree_grid->min_val > degree_grid->max_val )
2527 CV_ERROR( CV_StsBadArg,
2528 "low bound of grid should be less then the upper one");
2529 if( degree_grid->step <= 1 )
2530 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2531 degree_begin = degree_grid->min_val;
2532 degree_end = degree_grid->max_val;
2533 degree_step = degree_grid->step;
2537 degree_begin = degree_end = svm_params.degree;
2541 if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2542 gamma_grid->step == 0) )
2544 if( gamma_grid->min_val > gamma_grid->max_val )
2545 CV_ERROR( CV_StsBadArg,
2546 "low bound of grid should be less then the upper one");
2547 if( gamma_grid->step <= 1 )
2548 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2549 g_begin = gamma_grid->min_val;
2550 g_end = gamma_grid->max_val;
2551 g_step = gamma_grid->step;
2555 g_begin = g_end = svm_params.gamma;
2559 if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2560 coef_grid->step == 0) )
2562 if( coef_grid->min_val > coef_grid->max_val )
2563 CV_ERROR( CV_StsBadArg,
2564 "low bound of grid should be less then the upper one");
2565 if( coef_grid->step <= 1 )
2566 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2567 coef_begin = coef_grid->min_val;
2568 coef_end = coef_grid->max_val;
2569 coef_step = coef_grid->step;
2573 coef_begin = coef_end = svm_params.coef0;
2577 if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2579 if( C_grid->min_val > C_grid->max_val )
2580 CV_ERROR( CV_StsBadArg,
2581 "low bound of grid should be less then the upper one");
2582 if( C_grid->step <= 1 )
2583 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2584 C_begin = C_grid->min_val;
2585 C_end = C_grid->max_val;
2586 C_step = C_grid->step;
2590 C_begin = C_end = svm_params.C;
2594 if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2596 if( nu_grid->min_val > nu_grid->max_val )
2597 CV_ERROR( CV_StsBadArg,
2598 "low bound of grid should be less then the upper one");
2599 if( nu_grid->step <= 1 )
2600 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2601 nu_begin = nu_grid->min_val;
2602 nu_end = nu_grid->max_val;
2603 nu_step = nu_grid->step;
2607 nu_begin = nu_end = svm_params.nu;
2611 if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2613 if( p_grid->min_val > p_grid->max_val )
2614 CV_ERROR( CV_StsBadArg,
2615 "low bound of grid should be less then the upper one");
2616 if( p_grid->step <= 1 )
2617 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2618 p_begin = p_grid->min_val;
2619 p_end = p_grid->max_val;
2620 p_step = p_grid->step;
2624 p_begin = p_end = svm_params.p;
2626 // these parameters are not used:
2627 if( kernel != CvSVM::POLY )
2628 degree_begin = degree_end = svm_params.degree;
2630 if( kernel == CvSVM::LINEAR )
2631 g_begin = g_end = svm_params.gamma;
2633 if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2634 coef_begin = coef_end = svm_params.coef0;
2636 if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2637 C_begin = C_end = svm_params.C;
2639 if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2640 nu_begin = nu_end = svm_params.nu;
2642 if( svm_type != CvSVM::EPS_SVR )
2643 p_begin = p_end = svm_params.p;
2645 is_regression = cv_params->is_regression;
2646 best_rate = is_regression ? FLT_MAX : 0;
2648 assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2649 assert( p_step > 1 && C_step > 1 && nu_step > 1 );
2651 for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2653 svm_params.degree = degree;
2654 //printf("degree = %.3f\n", degree );
2655 for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2657 svm_params.gamma = gamma;
2658 //printf(" gamma = %.3f\n", gamma );
2659 for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2661 svm_params.coef0 = coef;
2662 //printf(" coef = %.3f\n", coef );
2663 for( C = C_begin; C <= C_end; C *= C_step )
2666 //printf(" C = %.3f\n", C );
2667 for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2670 //printf(" nu = %.3f\n", nu );
2671 for( p = p_begin; p <= p_end; p *= p_step )
2675 //printf(" p = %.3f\n", p );
2677 CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2678 cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
2680 well = rate > best_rate && !is_regression || rate < best_rate && is_regression;
2681 if( well || (rate == best_rate && C < best_C) )
2684 best_degree = degree;
2691 //printf(" rate = %.2f\n", rate );
2698 //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2699 // best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
2701 psvm_params->C = best_C;
2702 psvm_params->nu = best_nu;
2703 psvm_params->p = best_p;
2704 psvm_params->gamma = best_gamma;
2705 psvm_params->degree = best_degree;
2706 psvm_params->coef0 = best_coef;
2708 CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));