Update to 2.0.0 tree from current Fremantle build
[opencv] / src / ml / mlsvm.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 /****************************************************************************************\
44                                 COPYRIGHT NOTICE
45                                 ----------------
46
47   The code has been derived from libsvm library (version 2.6)
48   (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
49
50   Here is the orignal copyright:
51 ------------------------------------------------------------------------------------------
52     Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
53     All rights reserved.
54
55     Redistribution and use in source and binary forms, with or without
56     modification, are permitted provided that the following conditions
57     are met:
58
59     1. Redistributions of source code must retain the above copyright
60     notice, this list of conditions and the following disclaimer.
61
62     2. Redistributions in binary form must reproduce the above copyright
63     notice, this list of conditions and the following disclaimer in the
64     documentation and/or other materials provided with the distribution.
65
66     3. Neither name of copyright holders nor the names of its contributors
67     may be used to endorse or promote products derived from this software
68     without specific prior written permission.
69
70
71     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
72     ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
73     LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
74     A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
75     CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
76     EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
77     PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
78     PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
79     LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
80     NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
81     SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82 \****************************************************************************************/
83
84 using namespace cv;
85
86 #define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
87
88 #include <stdarg.h>
89 #include <ctype.h>
90
91 #if _MSC_VER >= 1200
92 #pragma warning( disable: 4514 ) /* unreferenced inline functions */
93 #endif
94
95 #if 1
96 typedef float Qfloat;
97 #define QFLOAT_TYPE CV_32F
98 #else
99 typedef double Qfloat;
100 #define QFLOAT_TYPE CV_64F
101 #endif
102
103 // Param Grid
104 bool CvParamGrid::check() const
105 {
106     bool ok = false;
107
108     CV_FUNCNAME( "CvParamGrid::check" );
109     __BEGIN__;
110
111     if( min_val > max_val )
112         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
113     if( min_val < DBL_EPSILON )
114         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
115     if( step < 1. + FLT_EPSILON )
116         CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
117
118     ok = true;
119
120     __END__;
121
122     return ok;
123 }
124
125 CvParamGrid CvSVM::get_default_grid( int param_id )
126 {
127     CvParamGrid grid;
128     if( param_id == CvSVM::C )
129     {
130         grid.min_val = 0.1;
131         grid.max_val = 500;
132         grid.step = 5; // total iterations = 5
133     }
134     else if( param_id == CvSVM::GAMMA )
135     {
136         grid.min_val = 1e-5;
137         grid.max_val = 0.6;
138         grid.step = 15; // total iterations = 4
139     }
140     else if( param_id == CvSVM::P )
141     {
142         grid.min_val = 0.01;
143         grid.max_val = 100;
144         grid.step = 7; // total iterations = 4
145     }
146     else if( param_id == CvSVM::NU )
147     {
148         grid.min_val = 0.01;
149         grid.max_val = 0.2;
150         grid.step = 3; // total iterations = 3
151     }
152     else if( param_id == CvSVM::COEF )
153     {
154         grid.min_val = 0.1;
155         grid.max_val = 300;
156         grid.step = 14; // total iterations = 3
157     }
158     else if( param_id == CvSVM::DEGREE )
159     {
160         grid.min_val = 0.01;
161         grid.max_val = 4;
162         grid.step = 7; // total iterations = 3
163     }
164     else
165         cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
166             "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
167     return grid;
168 }
169
170 // SVM training parameters
171 CvSVMParams::CvSVMParams() :
172     svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
173     gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
174 {
175     term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
176 }
177
178
179 CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
180     double _degree, double _gamma, double _coef0,
181     double _Con, double _nu, double _p,
182     CvMat* _class_weights, CvTermCriteria _term_crit ) :
183     svm_type(_svm_type), kernel_type(_kernel_type),
184     degree(_degree), gamma(_gamma), coef0(_coef0),
185     C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
186 {
187 }
188
189
190 /////////////////////////////////////// SVM kernel ///////////////////////////////////////
191
192 CvSVMKernel::CvSVMKernel()
193 {
194     clear();
195 }
196
197
198 void CvSVMKernel::clear()
199 {
200     params = 0;
201     calc_func = 0;
202 }
203
204
205 CvSVMKernel::~CvSVMKernel()
206 {
207 }
208
209
210 CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
211 {
212     clear();
213     create( _params, _calc_func );
214 }
215
216
217 bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
218 {
219     clear();
220     params = _params;
221     calc_func = _calc_func;
222
223     if( !calc_func )
224         calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
225                     params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
226                     params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
227                     &CvSVMKernel::calc_linear;
228
229     return true;
230 }
231
232
233 void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
234                                      const float* another, Qfloat* results,
235                                      double alpha, double beta )
236 {
237     int j, k;
238     for( j = 0; j < vcount; j++ )
239     {
240         const float* sample = vecs[j];
241         double s = 0;
242         for( k = 0; k <= var_count - 4; k += 4 )
243             s += sample[k]*another[k] + sample[k+1]*another[k+1] +
244                  sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
245         for( ; k < var_count; k++ )
246             s += sample[k]*another[k];
247         results[j] = (Qfloat)(s*alpha + beta);
248     }
249 }
250
251
252 void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
253                                const float* another, Qfloat* results )
254 {
255     calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
256 }
257
258
259 void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
260                              const float* another, Qfloat* results )
261 {
262     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
263     calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
264     cvPow( &R, &R, params->degree );
265 }
266
267
268 void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
269                                 const float* another, Qfloat* results )
270 {
271     int j;
272     calc_non_rbf_base( vcount, var_count, vecs, another, results,
273                        -2*params->gamma, -2*params->coef0 );
274     // TODO: speedup this
275     for( j = 0; j < vcount; j++ )
276     {
277         Qfloat t = results[j];
278         double e = exp(-fabs(t));
279         if( t > 0 )
280             results[j] = (Qfloat)((1. - e)/(1. + e));
281         else
282             results[j] = (Qfloat)((e - 1.)/(e + 1.));
283     }
284 }
285
286
287 void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
288                             const float* another, Qfloat* results )
289 {
290     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
291     double gamma = -params->gamma;
292     int j, k;
293
294     for( j = 0; j < vcount; j++ )
295     {
296         const float* sample = vecs[j];
297         double s = 0;
298
299         for( k = 0; k <= var_count - 4; k += 4 )
300         {
301             double t0 = sample[k] - another[k];
302             double t1 = sample[k+1] - another[k+1];
303
304             s += t0*t0 + t1*t1;
305
306             t0 = sample[k+2] - another[k+2];
307             t1 = sample[k+3] - another[k+3];
308
309             s += t0*t0 + t1*t1;
310         }
311
312         for( ; k < var_count; k++ )
313         {
314             double t0 = sample[k] - another[k];
315             s += t0*t0;
316         }
317         results[j] = (Qfloat)(s*gamma);
318     }
319
320     cvExp( &R, &R );
321 }
322
323
324 void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
325                         const float* another, Qfloat* results )
326 {
327     const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
328     int j;
329     (this->*calc_func)( vcount, var_count, vecs, another, results );
330     for( j = 0; j < vcount; j++ )
331     {
332         if( results[j] > max_val )
333             results[j] = max_val;
334     }
335 }
336
337
338 // Generalized SMO+SVMlight algorithm
339 // Solves:
340 //
341 //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
342 //
343 //      y^T \alpha = \delta
344 //      y_i = +1 or -1
345 //      0 <= alpha_i <= Cp for y_i = 1
346 //      0 <= alpha_i <= Cn for y_i = -1
347 //
348 // Given:
349 //
350 //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
351 //  l is the size of vectors and matrices
352 //  eps is the stopping criterion
353 //
354 // solution will be put in \alpha, objective value will be put in obj
355 //
356
357 void CvSVMSolver::clear()
358 {
359     G = 0;
360     alpha = 0;
361     y = 0;
362     b = 0;
363     buf[0] = buf[1] = 0;
364     cvReleaseMemStorage( &storage );
365     kernel = 0;
366     select_working_set_func = 0;
367     calc_rho_func = 0;
368
369     rows = 0;
370     samples = 0;
371     get_row_func = 0;
372 }
373
374
375 CvSVMSolver::CvSVMSolver()
376 {
377     storage = 0;
378     clear();
379 }
380
381
382 CvSVMSolver::~CvSVMSolver()
383 {
384     clear();
385 }
386
387
388 CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
389                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
390                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
391                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
392 {
393     storage = 0;
394     create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
395             _storage, _kernel, _get_row, _select_working_set, _calc_rho );
396 }
397
398
399 bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
400                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
401                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
402                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
403 {
404     bool ok = false;
405     int i, svm_type;
406
407     CV_FUNCNAME( "CvSVMSolver::create" );
408
409     __BEGIN__;
410
411     int rows_hdr_size;
412
413     clear();
414
415     sample_count = _sample_count;
416     var_count = _var_count;
417     samples = _samples;
418     y = _y;
419     alpha_count = _alpha_count;
420     alpha = _alpha;
421     kernel = _kernel;
422
423     C[0] = _Cn;
424     C[1] = _Cp;
425     eps = kernel->params->term_crit.epsilon;
426     max_iter = kernel->params->term_crit.max_iter;
427     storage = cvCreateChildMemStorage( _storage );
428
429     b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
430     alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
431     G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
432     for( i = 0; i < 2; i++ )
433         buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
434     svm_type = kernel->params->svm_type;
435
436     select_working_set_func = _select_working_set;
437     if( !select_working_set_func )
438         select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
439         &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
440
441     calc_rho_func = _calc_rho;
442     if( !calc_rho_func )
443         calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
444             &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
445
446     get_row_func = _get_row;
447     if( !get_row_func )
448         get_row_func = params->svm_type == CvSVM::EPS_SVR ||
449                        params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
450                        params->svm_type == CvSVM::C_SVC ||
451                        params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
452                        &CvSVMSolver::get_row_one_class;
453
454     cache_line_size = sample_count*sizeof(Qfloat);
455     // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 64Kb)
456     // (assuming that for large training sets ~25% of Q matrix is used)
457     cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
458
459     // the size of Q matrix row headers
460     rows_hdr_size = sample_count*sizeof(rows[0]);
461     if( rows_hdr_size > storage->block_size )
462         CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
463
464     lru_list.prev = lru_list.next = &lru_list;
465     rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
466     memset( rows, 0, rows_hdr_size );
467
468     ok = true;
469
470     __END__;
471
472     return ok;
473 }
474
475
476 float* CvSVMSolver::get_row_base( int i, bool* _existed )
477 {
478     int i1 = i < sample_count ? i : i - sample_count;
479     CvSVMKernelRow* row = rows + i1;
480     bool existed = row->data != 0;
481     Qfloat* data;
482
483     if( existed || cache_size <= 0 )
484     {
485         CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
486         data = del_row->data;
487         assert( data != 0 );
488
489         // delete row from the LRU list
490         del_row->data = 0;
491         del_row->prev->next = del_row->next;
492         del_row->next->prev = del_row->prev;
493     }
494     else
495     {
496         data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
497         cache_size -= cache_line_size;
498     }
499
500     // insert row into the LRU list
501     row->data = data;
502     row->prev = &lru_list;
503     row->next = lru_list.next;
504     row->prev->next = row->next->prev = row;
505
506     if( !existed )
507     {
508         kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
509     }
510
511     if( _existed )
512         *_existed = existed;
513
514     return row->data;
515 }
516
517
518 float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
519 {
520     if( !existed )
521     {
522         const schar* _y = y;
523         int j, len = sample_count;
524         assert( _y && i < sample_count );
525
526         if( _y[i] > 0 )
527         {
528             for( j = 0; j < len; j++ )
529                 row[j] = _y[j]*row[j];
530         }
531         else
532         {
533             for( j = 0; j < len; j++ )
534                 row[j] = -_y[j]*row[j];
535         }
536     }
537     return row;
538 }
539
540
541 float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
542 {
543     return row;
544 }
545
546
547 float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
548 {
549     int j, len = sample_count;
550     Qfloat* dst_pos = dst;
551     Qfloat* dst_neg = dst + len;
552     if( i >= len )
553     {
554         Qfloat* temp;
555         CV_SWAP( dst_pos, dst_neg, temp );
556     }
557
558     for( j = 0; j < len; j++ )
559     {
560         Qfloat t = row[j];
561         dst_pos[j] = t;
562         dst_neg[j] = -t;
563     }
564     return dst;
565 }
566
567
568
569 float* CvSVMSolver::get_row( int i, float* dst )
570 {
571     bool existed = false;
572     float* row = get_row_base( i, &existed );
573     return (this->*get_row_func)( i, row, dst, existed );
574 }
575
576
577 #undef is_upper_bound
578 #define is_upper_bound(i) (alpha_status[i] > 0)
579
580 #undef is_lower_bound
581 #define is_lower_bound(i) (alpha_status[i] < 0)
582
583 #undef is_free
584 #define is_free(i) (alpha_status[i] == 0)
585
586 #undef get_C
587 #define get_C(i) (C[y[i]>0])
588
589 #undef update_alpha_status
590 #define update_alpha_status(i) \
591     alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)
592
593 #undef reconstruct_gradient
594 #define reconstruct_gradient() /* empty for now */
595
596
597 bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
598 {
599     int iter = 0;
600     int i, j, k;
601
602     // 1. initialize gradient and alpha status
603     for( i = 0; i < alpha_count; i++ )
604     {
605         update_alpha_status(i);
606         G[i] = b[i];
607         if( fabs(G[i]) > 1e200 )
608             return false;
609     }
610
611     for( i = 0; i < alpha_count; i++ )
612     {
613         if( !is_lower_bound(i) )
614         {
615             const Qfloat *Q_i = get_row( i, buf[0] );
616             double alpha_i = alpha[i];
617
618             for( j = 0; j < alpha_count; j++ )
619                 G[j] += alpha_i*Q_i[j];
620         }
621     }
622
623     // 2. optimization loop
624     for(;;)
625     {
626         const Qfloat *Q_i, *Q_j;
627         double C_i, C_j;
628         double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
629         double delta_alpha_i, delta_alpha_j;
630
631 #ifdef _DEBUG
632         for( i = 0; i < alpha_count; i++ )
633         {
634             if( fabs(G[i]) > 1e+300 )
635                 return false;
636
637             if( fabs(alpha[i]) > 1e16 )
638                 return false;
639         }
640 #endif
641
642         if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
643             break;
644
645         Q_i = get_row( i, buf[0] );
646         Q_j = get_row( j, buf[1] );
647
648         C_i = get_C(i);
649         C_j = get_C(j);
650
651         alpha_i = old_alpha_i = alpha[i];
652         alpha_j = old_alpha_j = alpha[j];
653
654         if( y[i] != y[j] )
655         {
656             double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
657             double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
658             double diff = alpha_i - alpha_j;
659             alpha_i += delta;
660             alpha_j += delta;
661
662             if( diff > 0 && alpha_j < 0 )
663             {
664                 alpha_j = 0;
665                 alpha_i = diff;
666             }
667             else if( diff <= 0 && alpha_i < 0 )
668             {
669                 alpha_i = 0;
670                 alpha_j = -diff;
671             }
672
673             if( diff > C_i - C_j && alpha_i > C_i )
674             {
675                 alpha_i = C_i;
676                 alpha_j = C_i - diff;
677             }
678             else if( diff <= C_i - C_j && alpha_j > C_j )
679             {
680                 alpha_j = C_j;
681                 alpha_i = C_j + diff;
682             }
683         }
684         else
685         {
686             double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
687             double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
688             double sum = alpha_i + alpha_j;
689             alpha_i -= delta;
690             alpha_j += delta;
691
692             if( sum > C_i && alpha_i > C_i )
693             {
694                 alpha_i = C_i;
695                 alpha_j = sum - C_i;
696             }
697             else if( sum <= C_i && alpha_j < 0)
698             {
699                 alpha_j = 0;
700                 alpha_i = sum;
701             }
702
703             if( sum > C_j && alpha_j > C_j )
704             {
705                 alpha_j = C_j;
706                 alpha_i = sum - C_j;
707             }
708             else if( sum <= C_j && alpha_i < 0 )
709             {
710                 alpha_i = 0;
711                 alpha_j = sum;
712             }
713         }
714
715         // update alpha
716         alpha[i] = alpha_i;
717         alpha[j] = alpha_j;
718         update_alpha_status(i);
719         update_alpha_status(j);
720
721         // update G
722         delta_alpha_i = alpha_i - old_alpha_i;
723         delta_alpha_j = alpha_j - old_alpha_j;
724
725         for( k = 0; k < alpha_count; k++ )
726             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
727     }
728
729     // calculate rho
730     (this->*calc_rho_func)( si.rho, si.r );
731
732     // calculate objective value
733     for( i = 0, si.obj = 0; i < alpha_count; i++ )
734         si.obj += alpha[i] * (G[i] + b[i]);
735
736     si.obj *= 0.5;
737
738     si.upper_bound_p = C[1];
739     si.upper_bound_n = C[0];
740
741     return true;
742 }
743
744
745 // return 1 if already optimal, return 0 otherwise
746 bool
747 CvSVMSolver::select_working_set( int& out_i, int& out_j )
748 {
749     // return i,j which maximize -grad(f)^T d , under constraint
750     // if alpha_i == C, d != +1
751     // if alpha_i == 0, d != -1
752     double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
753     int Gmax1_idx = -1;
754
755     double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
756     int Gmax2_idx = -1;
757
758     int i;
759
760     for( i = 0; i < alpha_count; i++ )
761     {
762         double t;
763
764         if( y[i] > 0 )    // y = +1
765         {
766             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
767             {
768                 Gmax1 = t;
769                 Gmax1_idx = i;
770             }
771             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
772             {
773                 Gmax2 = t;
774                 Gmax2_idx = i;
775             }
776         }
777         else        // y = -1
778         {
779             if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
780             {
781                 Gmax2 = t;
782                 Gmax2_idx = i;
783             }
784             if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
785             {
786                 Gmax1 = t;
787                 Gmax1_idx = i;
788             }
789         }
790     }
791
792     out_i = Gmax1_idx;
793     out_j = Gmax2_idx;
794
795     return Gmax1 + Gmax2 < eps;
796 }
797
798
799 void
800 CvSVMSolver::calc_rho( double& rho, double& r )
801 {
802     int i, nr_free = 0;
803     double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
804
805     for( i = 0; i < alpha_count; i++ )
806     {
807         double yG = y[i]*G[i];
808
809         if( is_lower_bound(i) )
810         {
811             if( y[i] > 0 )
812                 ub = MIN(ub,yG);
813             else
814                 lb = MAX(lb,yG);
815         }
816         else if( is_upper_bound(i) )
817         {
818             if( y[i] < 0)
819                 ub = MIN(ub,yG);
820             else
821                 lb = MAX(lb,yG);
822         }
823         else
824         {
825             ++nr_free;
826             sum_free += yG;
827         }
828     }
829
830     rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
831     r = 0;
832 }
833
834
835 bool
836 CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
837 {
838     // return i,j which maximize -grad(f)^T d , under constraint
839     // if alpha_i == C, d != +1
840     // if alpha_i == 0, d != -1
841     double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
842     int Gmax1_idx = -1;
843
844     double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
845     int Gmax2_idx = -1;
846
847     double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
848     int Gmax3_idx = -1;
849
850     double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
851     int Gmax4_idx = -1;
852
853     int i;
854
855     for( i = 0; i < alpha_count; i++ )
856     {
857         double t;
858
859         if( y[i] > 0 )    // y == +1
860         {
861             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
862             {
863                 Gmax1 = t;
864                 Gmax1_idx = i;
865             }
866             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
867             {
868                 Gmax2 = t;
869                 Gmax2_idx = i;
870             }
871         }
872         else        // y == -1
873         {
874             if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
875             {
876                 Gmax3 = t;
877                 Gmax3_idx = i;
878             }
879             if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
880             {
881                 Gmax4 = t;
882                 Gmax4_idx = i;
883             }
884         }
885     }
886
887     if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
888         return 1;
889
890     if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
891     {
892         out_i = Gmax1_idx;
893         out_j = Gmax2_idx;
894     }
895     else
896     {
897         out_i = Gmax3_idx;
898         out_j = Gmax4_idx;
899     }
900     return 0;
901 }
902
903
904 void
905 CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
906 {
907     int nr_free1 = 0, nr_free2 = 0;
908     double ub1 = DBL_MAX, ub2 = DBL_MAX;
909     double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
910     double sum_free1 = 0, sum_free2 = 0;
911     double r1, r2;
912
913     int i;
914
915     for( i = 0; i < alpha_count; i++ )
916     {
917         double G_i = G[i];
918         if( y[i] > 0 )
919         {
920             if( is_lower_bound(i) )
921                 ub1 = MIN( ub1, G_i );
922             else if( is_upper_bound(i) )
923                 lb1 = MAX( lb1, G_i );
924             else
925             {
926                 ++nr_free1;
927                 sum_free1 += G_i;
928             }
929         }
930         else
931         {
932             if( is_lower_bound(i) )
933                 ub2 = MIN( ub2, G_i );
934             else if( is_upper_bound(i) )
935                 lb2 = MAX( lb2, G_i );
936             else
937             {
938                 ++nr_free2;
939                 sum_free2 += G_i;
940             }
941         }
942     }
943
944     r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
945     r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
946
947     rho = (r1 - r2)*0.5;
948     r = (r1 + r2)*0.5;
949 }
950
951
952 /*
953 ///////////////////////// construct and solve various formulations ///////////////////////
954 */
955
956 bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
957                                double _Cp, double _Cn, CvMemStorage* _storage,
958                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
959 {
960     int i;
961
962     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
963                  _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
964                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
965         return false;
966
967     for( i = 0; i < sample_count; i++ )
968     {
969         alpha[i] = 0;
970         b[i] = -1;
971     }
972
973     if( !solve_generic( _si ))
974         return false;
975
976     for( i = 0; i < sample_count; i++ )
977         alpha[i] *= y[i];
978
979     return true;
980 }
981
982
983 bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
984                                 CvMemStorage* _storage, CvSVMKernel* _kernel,
985                                 double* _alpha, CvSVMSolutionInfo& _si )
986 {
987     int i;
988     double sum_pos, sum_neg, inv_r;
989
990     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
991                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
992                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
993         return false;
994
995     sum_pos = kernel->params->nu * sample_count * 0.5;
996     sum_neg = kernel->params->nu * sample_count * 0.5;
997
998     for( i = 0; i < sample_count; i++ )
999     {
1000         if( y[i] > 0 )
1001         {
1002             alpha[i] = MIN(1.0, sum_pos);
1003             sum_pos -= alpha[i];
1004         }
1005         else
1006         {
1007             alpha[i] = MIN(1.0, sum_neg);
1008             sum_neg -= alpha[i];
1009         }
1010         b[i] = 0;
1011     }
1012
1013     if( !solve_generic( _si ))
1014         return false;
1015
1016     inv_r = 1./_si.r;
1017
1018     for( i = 0; i < sample_count; i++ )
1019         alpha[i] *= y[i]*inv_r;
1020
1021     _si.rho *= inv_r;
1022     _si.obj *= (inv_r*inv_r);
1023     _si.upper_bound_p = inv_r;
1024     _si.upper_bound_n = inv_r;
1025
1026     return true;
1027 }
1028
1029
1030 bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
1031                                    CvMemStorage* _storage, CvSVMKernel* _kernel,
1032                                    double* _alpha, CvSVMSolutionInfo& _si )
1033 {
1034     int i, n;
1035     double nu = _kernel->params->nu;
1036
1037     if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
1038                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
1039                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1040         return false;
1041
1042     y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
1043     n = cvRound( nu*sample_count );
1044
1045     for( i = 0; i < sample_count; i++ )
1046     {
1047         y[i] = 1;
1048         b[i] = 0;
1049         alpha[i] = i < n ? 1 : 0;
1050     }
1051
1052     if( n < sample_count )
1053         alpha[n] = nu * sample_count - n;
1054     else
1055         alpha[n-1] = nu * sample_count - (n-1);
1056
1057     return solve_generic(_si);
1058 }
1059
1060
1061 bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
1062                                  const float* _y, CvMemStorage* _storage,
1063                                  CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1064 {
1065     int i;
1066     double p = _kernel->params->p, C = _kernel->params->C;
1067
1068     if( !create( _sample_count, _var_count, _samples, 0,
1069                  _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
1070                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1071         return false;
1072
1073     y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1074     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1075
1076     for( i = 0; i < sample_count; i++ )
1077     {
1078         alpha[i] = 0;
1079         b[i] = p - _y[i];
1080         y[i] = 1;
1081
1082         alpha[i+sample_count] = 0;
1083         b[i+sample_count] = p + _y[i];
1084         y[i+sample_count] = -1;
1085     }
1086
1087     if( !solve_generic( _si ))
1088         return false;
1089
1090     for( i = 0; i < sample_count; i++ )
1091         _alpha[i] = alpha[i] - alpha[i+sample_count];
1092
1093     return true;
1094 }
1095
1096
1097 bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
1098                                 const float* _y, CvMemStorage* _storage,
1099                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1100 {
1101     int i;
1102     double C = _kernel->params->C, sum;
1103
1104     if( !create( _sample_count, _var_count, _samples, 0,
1105                  _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
1106                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
1107         return false;
1108
1109     y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1110     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1111     sum = C * _kernel->params->nu * sample_count * 0.5;
1112
1113     for( i = 0; i < sample_count; i++ )
1114     {
1115         alpha[i] = alpha[i + sample_count] = MIN(sum, C);
1116         sum -= alpha[i];
1117
1118         b[i] = -_y[i];
1119         y[i] = 1;
1120
1121         b[i + sample_count] = _y[i];
1122         y[i + sample_count] = -1;
1123     }
1124
1125     if( !solve_generic( _si ))
1126         return false;
1127
1128     for( i = 0; i < sample_count; i++ )
1129         _alpha[i] = alpha[i] - alpha[i+sample_count];
1130
1131     return true;
1132 }
1133
1134
1135 //////////////////////////////////////////////////////////////////////////////////////////
1136
1137 CvSVM::CvSVM()
1138 {
1139     decision_func = 0;
1140     class_labels = 0;
1141     class_weights = 0;
1142     storage = 0;
1143     var_idx = 0;
1144     kernel = 0;
1145     solver = 0;
1146     default_model_name = "my_svm";
1147
1148     clear();
1149 }
1150
1151
1152 CvSVM::~CvSVM()
1153 {
1154     clear();
1155 }
1156
1157
1158 void CvSVM::clear()
1159 {
1160     cvFree( &decision_func );
1161     cvReleaseMat( &class_labels );
1162     cvReleaseMat( &class_weights );
1163     cvReleaseMemStorage( &storage );
1164     cvReleaseMat( &var_idx );
1165     delete kernel;
1166     delete solver;
1167     kernel = 0;
1168     solver = 0;
1169     var_all = 0;
1170     sv = 0;
1171     sv_total = 0;
1172 }
1173
1174
1175 CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
1176     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1177 {
1178     decision_func = 0;
1179     class_labels = 0;
1180     class_weights = 0;
1181     storage = 0;
1182     var_idx = 0;
1183     kernel = 0;
1184     solver = 0;
1185     default_model_name = "my_svm";
1186
1187     train( _train_data, _responses, _var_idx, _sample_idx, _params );
1188 }
1189
1190
1191 int CvSVM::get_support_vector_count() const
1192 {
1193     return sv_total;
1194 }
1195
1196
1197 const float* CvSVM::get_support_vector(int i) const
1198 {
1199     return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1200 }
1201
1202
1203 bool CvSVM::set_params( const CvSVMParams& _params )
1204 {
1205     bool ok = false;
1206
1207     CV_FUNCNAME( "CvSVM::set_params" );
1208
1209     __BEGIN__;
1210
1211     int kernel_type, svm_type;
1212
1213     params = _params;
1214
1215     kernel_type = params.kernel_type;
1216     svm_type = params.svm_type;
1217
1218     if( kernel_type != LINEAR && kernel_type != POLY &&
1219         kernel_type != SIGMOID && kernel_type != RBF )
1220         CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
1221
1222     if( kernel_type == LINEAR )
1223         params.gamma = 1;
1224     else if( params.gamma <= 0 )
1225         CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
1226
1227     if( kernel_type != SIGMOID && kernel_type != POLY )
1228         params.coef0 = 0;
1229     else if( params.coef0 < 0 )
1230         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
1231
1232     if( kernel_type != POLY )
1233         params.degree = 0;
1234     else if( params.degree <= 0 )
1235         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
1236
1237     if( svm_type != C_SVC && svm_type != NU_SVC &&
1238         svm_type != ONE_CLASS && svm_type != EPS_SVR &&
1239         svm_type != NU_SVR )
1240         CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
1241
1242     if( svm_type == ONE_CLASS || svm_type == NU_SVC )
1243         params.C = 0;
1244     else if( params.C <= 0 )
1245         CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
1246
1247     if( svm_type == C_SVC || svm_type == EPS_SVR )
1248         params.nu = 0;
1249     else if( params.nu <= 0 || params.nu >= 1 )
1250         CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
1251
1252     if( svm_type != EPS_SVR )
1253         params.p = 0;
1254     else if( params.p <= 0 )
1255         CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
1256
1257     if( svm_type != C_SVC )
1258         params.class_weights = 0;
1259
1260     params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
1261     params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
1262     ok = true;
1263
1264     __END__;
1265
1266     return ok;
1267 }
1268
1269
1270
1271 void CvSVM::create_kernel()
1272 {
1273     kernel = new CvSVMKernel(&params,0);
1274 }
1275
1276
1277 void CvSVM::create_solver( )
1278 {
1279     solver = new CvSVMSolver;
1280 }
1281
1282
1283 // switching function
1284 bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1285                     const void* _responses, double Cp, double Cn,
1286                     CvMemStorage* _storage, double* alpha, double& rho )
1287 {
1288     bool ok = false;
1289
1290     //CV_FUNCNAME( "CvSVM::train1" );
1291
1292     __BEGIN__;
1293
1294     CvSVMSolutionInfo si;
1295     int svm_type = params.svm_type;
1296
1297     si.rho = 0;
1298
1299     ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
1300                                                   Cp, Cn, _storage, kernel, alpha, si ) :
1301          svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
1302                                                     _storage, kernel, alpha, si ) :
1303          svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1304                                                           _storage, kernel, alpha, si ) :
1305          svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1306                                                       _storage, kernel, alpha, si ) :
1307          svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1308                                                     _storage, kernel, alpha, si ) : false;
1309
1310     rho = si.rho;
1311
1312     __END__;
1313
1314     return ok;
1315 }
1316
1317
1318 bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
1319                     const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
1320 {
1321     bool ok = false;
1322
1323     CV_FUNCNAME( "CvSVM::do_train" );
1324
1325     __BEGIN__;
1326
1327     CvSVMDecisionFunc* df = 0;
1328     const int sample_size = var_count*sizeof(samples[0][0]);
1329     int i, j, k;
1330
1331     if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
1332     {
1333         int sv_count = 0;
1334
1335         CV_CALL( decision_func = df =
1336             (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
1337
1338         df->rho = 0;
1339         if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
1340             responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
1341             EXIT;
1342
1343         for( i = 0; i < sample_count; i++ )
1344             sv_count += fabs(alpha[i]) > 0;
1345
1346         sv_total = df->sv_count = sv_count;
1347         CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
1348         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
1349
1350         for( i = k = 0; i < sample_count; i++ )
1351         {
1352             if( fabs(alpha[i]) > 0 )
1353             {
1354                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1355                 memcpy( sv[k], samples[i], sample_size );
1356                 df->alpha[k++] = alpha[i];
1357             }
1358         }
1359     }
1360     else
1361     {
1362         int class_count = class_labels->cols;
1363         int* sv_tab = 0;
1364         const float** temp_samples = 0;
1365         int* class_ranges = 0;
1366         schar* temp_y = 0;
1367         assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
1368
1369         if( svm_type == CvSVM::C_SVC && params.class_weights )
1370         {
1371             const CvMat* cw = params.class_weights;
1372
1373             if( !CV_IS_MAT(cw) || (cw->cols != 1 && cw->rows != 1) ||
1374                 cw->rows + cw->cols - 1 != class_count ||
1375                 (CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1) )
1376                 CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
1377                     "containing as many elements as the number of classes" );
1378
1379             CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
1380             CV_CALL( cvConvert( cw, class_weights ));
1381             CV_CALL( cvScale( class_weights, class_weights, params.C ));
1382         }
1383
1384         CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
1385             (class_count*(class_count-1)/2)*sizeof(df[0])));
1386
1387         CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
1388         memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
1389         CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
1390                             (class_count + 1)*sizeof(class_ranges[0])));
1391         CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
1392                             sample_count*sizeof(temp_samples[0])));
1393         CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));
1394
1395         class_ranges[class_count] = 0;
1396         cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
1397         //check that while cross-validation there were the samples from all the classes
1398         if( class_ranges[class_count] <= 0 )
1399             CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
1400             "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
1401
1402         if( svm_type == NU_SVC )
1403         {
1404             // check if nu is feasible
1405             for(i = 0; i < class_count; i++ )
1406             {
1407                 int ci = class_ranges[i+1] - class_ranges[i];
1408                 for( j = i+1; j< class_count; j++ )
1409                 {
1410                     int cj = class_ranges[j+1] - class_ranges[j];
1411                     if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
1412                     {
1413                         // !!!TODO!!! add some diagnostic
1414                         EXIT; // exit immediately; will release the model and return NULL pointer
1415                     }
1416                 }
1417             }
1418         }
1419
1420         // train n*(n-1)/2 classifiers
1421         for( i = 0; i < class_count; i++ )
1422         {
1423             for( j = i+1; j < class_count; j++, df++ )
1424             {
1425                 int si = class_ranges[i], ci = class_ranges[i+1] - si;
1426                 int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
1427                 double Cp = params.C, Cn = Cp;
1428                 int k1 = 0, sv_count = 0;
1429
1430                 for( k = 0; k < ci; k++ )
1431                 {
1432                     temp_samples[k] = samples[si + k];
1433                     temp_y[k] = 1;
1434                 }
1435
1436                 for( k = 0; k < cj; k++ )
1437                 {
1438                     temp_samples[ci + k] = samples[sj + k];
1439                     temp_y[ci + k] = -1;
1440                 }
1441
1442                 if( class_weights )
1443                 {
1444                     Cp = class_weights->data.db[i];
1445                     Cn = class_weights->data.db[j];
1446                 }
1447
1448                 if( !train1( ci + cj, var_count, temp_samples, temp_y,
1449                              Cp, Cn, temp_storage, alpha, df->rho ))
1450                     EXIT;
1451
1452                 for( k = 0; k < ci + cj; k++ )
1453                     sv_count += fabs(alpha[k]) > 0;
1454
1455                 df->sv_count = sv_count;
1456
1457                 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
1458                                                 sv_count*sizeof(df->alpha[0])));
1459                 CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
1460                                                 sv_count*sizeof(df->sv_index[0])));
1461
1462                 for( k = 0; k < ci; k++ )
1463                 {
1464                     if( fabs(alpha[k]) > 0 )
1465                     {
1466                         sv_tab[si + k] = 1;
1467                         df->sv_index[k1] = si + k;
1468                         df->alpha[k1++] = alpha[k];
1469                     }
1470                 }
1471
1472                 for( k = 0; k < cj; k++ )
1473                 {
1474                     if( fabs(alpha[ci + k]) > 0 )
1475                     {
1476                         sv_tab[sj + k] = 1;
1477                         df->sv_index[k1] = sj + k;
1478                         df->alpha[k1++] = alpha[ci + k];
1479                     }
1480                 }
1481             }
1482         }
1483
1484         // allocate support vectors and initialize sv_tab
1485         for( i = 0, k = 0; i < sample_count; i++ )
1486         {
1487             if( sv_tab[i] )
1488                 sv_tab[i] = ++k;
1489         }
1490
1491         sv_total = k;
1492         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
1493
1494         for( i = 0, k = 0; i < sample_count; i++ )
1495         {
1496             if( sv_tab[i] )
1497             {
1498                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1499                 memcpy( sv[k], samples[i], sample_size );
1500                 k++;
1501             }
1502         }
1503
1504         df = (CvSVMDecisionFunc*)decision_func;
1505
1506         // set sv pointers
1507         for( i = 0; i < class_count; i++ )
1508         {
1509             for( j = i+1; j < class_count; j++, df++ )
1510             {
1511                 for( k = 0; k < df->sv_count; k++ )
1512                 {
1513                     df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
1514                     assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
1515                 }
1516             }
1517         }
1518     }
1519
1520     ok = true;
1521
1522     __END__;
1523
1524     return ok;
1525 }
1526
1527 bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
1528     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1529 {
1530     bool ok = false;
1531     CvMat* responses = 0;
1532     CvMemStorage* temp_storage = 0;
1533     const float** samples = 0;
1534
1535     CV_FUNCNAME( "CvSVM::train" );
1536
1537     __BEGIN__;
1538
1539     int svm_type, sample_count, var_count, sample_size;
1540     int block_size = 1 << 16;
1541     double* alpha;
1542
1543     clear();
1544     CV_CALL( set_params( _params ));
1545
1546     svm_type = _params.svm_type;
1547
1548     /* Prepare training data and related parameters */
1549     CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
1550                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1551                                  svm_type == CvSVM::C_SVC ||
1552                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1553                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
1554                                  false, &samples, &sample_count, &var_count, &var_all,
1555                                  &responses, &class_labels, &var_idx ));
1556
1557
1558     sample_size = var_count*sizeof(samples[0][0]);
1559
1560     // make the storage block size large enough to fit all
1561     // the temporary vectors and output support vectors.
1562     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1563     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1564     block_size = MAX( block_size, sample_size*2 + 1024 );
1565
1566     CV_CALL( storage = cvCreateMemStorage(block_size));
1567     CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
1568     CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1569
1570     create_kernel();
1571     create_solver();
1572
1573     if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
1574         EXIT;
1575
1576     ok = true; // model has been trained succesfully
1577
1578     __END__;
1579
1580     delete solver;
1581     solver = 0;
1582     cvReleaseMemStorage( &temp_storage );
1583     cvReleaseMat( &responses );
1584     cvFree( &samples );
1585
1586     if( cvGetErrStatus() < 0 || !ok )
1587         clear();
1588
1589     return ok;
1590 }
1591
1592 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1593     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1594     CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1595     CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1596 {
1597     bool ok = false;
1598     CvMat* responses = 0;
1599     CvMat* responses_local = 0;
1600     CvMemStorage* temp_storage = 0;
1601     const float** samples = 0;
1602     const float** samples_local = 0;
1603
1604     CV_FUNCNAME( "CvSVM::train_auto" );
1605     __BEGIN__;
1606
1607     int svm_type, sample_count, var_count, sample_size;
1608     int block_size = 1 << 16;
1609     double* alpha;
1610     int i, k;
1611     CvRNG rng = cvRNG(-1);
1612
1613     // all steps are logarithmic and must be > 1
1614     double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1615     double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
1616     double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1617     float min_error = FLT_MAX, error;
1618
1619     if( _params.svm_type == CvSVM::ONE_CLASS )
1620     {
1621         if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1622             EXIT;
1623         return true;
1624     }
1625
1626     clear();
1627
1628     if( k_fold < 2 )
1629         CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1630
1631     CV_CALL(set_params( _params ));
1632     svm_type = _params.svm_type;
1633
1634     // All the parameters except, possibly, <coef0> are positive.
1635     // <coef0> is nonnegative
1636     if( C_grid.step <= 1 )
1637     {
1638         C_grid.min_val = C_grid.max_val = params.C;
1639         C_grid.step = 10;
1640     }
1641     else
1642         CV_CALL(C_grid.check());
1643
1644     if( gamma_grid.step <= 1 )
1645     {
1646         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1647         gamma_grid.step = 10;
1648     }
1649     else
1650         CV_CALL(gamma_grid.check());
1651
1652     if( p_grid.step <= 1 )
1653     {
1654         p_grid.min_val = p_grid.max_val = params.p;
1655         p_grid.step = 10;
1656     }
1657     else
1658         CV_CALL(p_grid.check());
1659
1660     if( nu_grid.step <= 1 )
1661     {
1662         nu_grid.min_val = nu_grid.max_val = params.nu;
1663         nu_grid.step = 10;
1664     }
1665     else
1666         CV_CALL(nu_grid.check());
1667
1668     if( coef_grid.step <= 1 )
1669     {
1670         coef_grid.min_val = coef_grid.max_val = params.coef0;
1671         coef_grid.step = 10;
1672     }
1673     else
1674         CV_CALL(coef_grid.check());
1675
1676     if( degree_grid.step <= 1 )
1677     {
1678         degree_grid.min_val = degree_grid.max_val = params.degree;
1679         degree_grid.step = 10;
1680     }
1681     else
1682         CV_CALL(degree_grid.check());
1683
1684     // these parameters are not used:
1685     if( params.kernel_type != CvSVM::POLY )
1686         degree_grid.min_val = degree_grid.max_val = params.degree;
1687     if( params.kernel_type == CvSVM::LINEAR )
1688         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1689     if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1690         coef_grid.min_val = coef_grid.max_val = params.coef0;
1691     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1692         C_grid.min_val = C_grid.max_val = params.C;
1693     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1694         nu_grid.min_val = nu_grid.max_val = params.nu;
1695     if( svm_type != CvSVM::EPS_SVR )
1696         p_grid.min_val = p_grid.max_val = params.p;
1697
1698     CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1699     CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1700
1701     /* Prepare training data and related parameters */
1702     CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1703                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1704                                  svm_type == CvSVM::C_SVC ||
1705                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1706                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
1707                                  false, &samples, &sample_count, &var_count, &var_all,
1708                                  &responses, &class_labels, &var_idx ));
1709
1710     sample_size = var_count*sizeof(samples[0][0]);
1711
1712     // make the storage block size large enough to fit all
1713     // the temporary vectors and output support vectors.
1714     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1715     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1716     block_size = MAX( block_size, sample_size*2 + 1024 );
1717
1718     CV_CALL(storage = cvCreateMemStorage(block_size));
1719     CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1720     CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1721
1722     create_kernel();
1723     create_solver();
1724
1725     {
1726     const int testset_size = sample_count/k_fold;
1727     const int trainset_size = sample_count - testset_size;
1728     const int last_testset_size = sample_count - testset_size*(k_fold-1);
1729     const int last_trainset_size = sample_count - last_testset_size;
1730     const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
1731
1732     size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1733     size_t size = 2*last_trainset_size*sizeof(samples[0]);
1734
1735     samples_local = (const float**) cvAlloc( size );
1736     memset( samples_local, 0, size );
1737
1738     responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1739     cvZero( responses_local );
1740
1741     // randomly permute samples and responses
1742     for( i = 0; i < sample_count; i++ )
1743     {
1744         int i1 = cvRandInt( &rng ) % sample_count;
1745         int i2 = cvRandInt( &rng ) % sample_count;
1746         const float* temp;
1747         float t;
1748         int y;
1749
1750         CV_SWAP( samples[i1], samples[i2], temp );
1751         if( is_regression )
1752             CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1753         else
1754             CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1755     }
1756
1757     int* cls_lbls = class_labels ? class_labels->data.i : 0;
1758     C = C_grid.min_val;
1759     do
1760     {
1761       params.C = C;
1762       gamma = gamma_grid.min_val;
1763       do
1764       {
1765         params.gamma = gamma;
1766         p = p_grid.min_val;
1767         do
1768         {
1769           params.p = p;
1770           nu = nu_grid.min_val;
1771           do
1772           {
1773             params.nu = nu;
1774             coef = coef_grid.min_val;
1775             do
1776             {
1777               params.coef0 = coef;
1778               degree = degree_grid.min_val;
1779               do
1780               {
1781                 params.degree = degree;
1782
1783                 float** test_samples_ptr = (float**)samples;
1784                 uchar* true_resp = responses->data.ptr;
1785                 int test_size = testset_size;
1786                 int train_size = trainset_size;
1787
1788                 error = 0;
1789                 for( k = 0; k < k_fold; k++ )
1790                 {
1791                     memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1792                     memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1793                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1794
1795                     memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
1796                     memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1797                         true_resp + resp_elem_size*test_size,
1798                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1799
1800                     if( k == k_fold - 1 )
1801                     {
1802                         test_size = last_testset_size;
1803                         train_size = last_trainset_size;
1804                         responses_local->cols = last_trainset_size;
1805                     }
1806
1807                     // Train SVM on <train_size> samples
1808                     if( !do_train( svm_type, train_size, var_count,
1809                         (const float**)samples_local, responses_local, temp_storage, alpha ) )
1810                         EXIT;
1811
1812                     // Compute test set error on <test_size> samples
1813                     CvMat s = cvMat( 1, var_count, CV_32FC1 );
1814                     for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
1815                     {
1816                         float resp;
1817                         s.data.fl = *test_samples_ptr;
1818                         resp = predict( &s );
1819                         error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1820                             : ((int)resp != cls_lbls[*(int*)true_resp]);
1821                     }
1822                 }
1823                 if( min_error > error )
1824                 {
1825                     min_error   = error;
1826                     best_degree = degree;
1827                     best_gamma  = gamma;
1828                     best_coef   = coef;
1829                     best_C      = C;
1830                     best_nu     = nu;
1831                     best_p      = p;
1832                 }
1833                 degree *= degree_grid.step;
1834               }
1835               while( degree < degree_grid.max_val );
1836               coef *= coef_grid.step;
1837             }
1838             while( coef < coef_grid.max_val );
1839             nu *= nu_grid.step;
1840           }
1841           while( nu < nu_grid.max_val );
1842           p *= p_grid.step;
1843         }
1844         while( p < p_grid.max_val );
1845         gamma *= gamma_grid.step;
1846       }
1847       while( gamma < gamma_grid.max_val );
1848       C *= C_grid.step;
1849     }
1850     while( C < C_grid.max_val );
1851     }
1852
1853     min_error /= (float) sample_count;
1854
1855     params.C      = best_C;
1856     params.nu     = best_nu;
1857     params.p      = best_p;
1858     params.gamma  = best_gamma;
1859     params.degree = best_degree;
1860     params.coef0  = best_coef;
1861
1862     CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
1863
1864     __END__;
1865
1866     delete solver;
1867     solver = 0;
1868     cvReleaseMemStorage( &temp_storage );
1869     cvReleaseMat( &responses );
1870     cvReleaseMat( &responses_local );
1871     cvFree( &samples );
1872     cvFree( &samples_local );
1873
1874     if( cvGetErrStatus() < 0 || !ok )
1875         clear();
1876
1877     return ok;
1878 }
1879
1880 float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
1881 {
1882     bool local_alloc = 0;
1883     float result = 0;
1884     float* row_sample = 0;
1885     Qfloat* buffer = 0;
1886
1887     CV_FUNCNAME( "CvSVM::predict" );
1888
1889     __BEGIN__;
1890
1891     int class_count;
1892     int var_count, buf_sz;
1893
1894     if( !kernel )
1895         CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
1896
1897     class_count = class_labels ? class_labels->cols :
1898                   params.svm_type == ONE_CLASS ? 1 : 0;
1899
1900     CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
1901                                    class_count, 0, &row_sample ));
1902
1903     var_count = get_var_count();
1904
1905     buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
1906     if( buf_sz <= CV_MAX_LOCAL_SIZE )
1907     {
1908         CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
1909         local_alloc = 1;
1910     }
1911     else
1912         CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
1913
1914     if( params.svm_type == EPS_SVR ||
1915         params.svm_type == NU_SVR ||
1916         params.svm_type == ONE_CLASS )
1917     {
1918         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1919         int i, sv_count = df->sv_count;
1920         double sum = -df->rho;
1921
1922         kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
1923         for( i = 0; i < sv_count; i++ )
1924             sum += buffer[i]*df->alpha[i];
1925
1926         result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
1927     }
1928     else if( params.svm_type == C_SVC ||
1929              params.svm_type == NU_SVC )
1930     {
1931         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1932         int* vote = (int*)(buffer + sv_total);
1933         int i, j, k;
1934
1935         memset( vote, 0, class_count*sizeof(vote[0]));
1936         kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
1937         double sum = 0.;
1938
1939         for( i = 0; i < class_count; i++ )
1940         {
1941             for( j = i+1; j < class_count; j++, df++ )
1942             {
1943                 sum = -df->rho;
1944                 int sv_count = df->sv_count;
1945                 for( k = 0; k < sv_count; k++ )
1946                     sum += df->alpha[k]*buffer[df->sv_index[k]];
1947
1948                 vote[sum > 0 ? i : j]++;
1949             }
1950         }
1951
1952         for( i = 1, k = 0; i < class_count; i++ )
1953         {
1954             if( vote[i] > vote[k] )
1955                 k = i;
1956         }
1957         result = returnDFVal && class_count == 2 ? (float)sum : (float)(class_labels->data.i[k]);
1958     }
1959     else
1960         CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
1961                                 "the SVM structure is probably corrupted" );
1962
1963     __END__;
1964
1965     if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
1966         cvFree( &row_sample );
1967
1968     if( !local_alloc )
1969         cvFree( &buffer );
1970
1971     return result;
1972 }
1973
1974
1975 bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
1976                   const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
1977 {
1978     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
1979     return train(&tdata, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, _params);
1980 }
1981
1982
1983 bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
1984                        const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
1985                        CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1986                        CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1987 {
1988     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
1989     return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
1990                       sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
1991                       nu_grid, coef_grid, degree_grid);
1992 }
1993
1994 float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const
1995 {
1996     CvMat sample = _sample; 
1997     return predict(&sample, returnDFVal);
1998 }
1999
2000
2001 void CvSVM::write_params( CvFileStorage* fs ) const
2002 {
2003     //CV_FUNCNAME( "CvSVM::write_params" );
2004
2005     __BEGIN__;
2006
2007     int svm_type = params.svm_type;
2008     int kernel_type = params.kernel_type;
2009
2010     const char* svm_type_str =
2011         svm_type == CvSVM::C_SVC ? "C_SVC" :
2012         svm_type == CvSVM::NU_SVC ? "NU_SVC" :
2013         svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
2014         svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
2015         svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
2016     const char* kernel_type_str =
2017         kernel_type == CvSVM::LINEAR ? "LINEAR" :
2018         kernel_type == CvSVM::POLY ? "POLY" :
2019         kernel_type == CvSVM::RBF ? "RBF" :
2020         kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
2021
2022     if( svm_type_str )
2023         cvWriteString( fs, "svm_type", svm_type_str );
2024     else
2025         cvWriteInt( fs, "svm_type", svm_type );
2026
2027     // save kernel
2028     cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
2029
2030     if( kernel_type_str )
2031         cvWriteString( fs, "type", kernel_type_str );
2032     else
2033         cvWriteInt( fs, "type", kernel_type );
2034
2035     if( kernel_type == CvSVM::POLY || !kernel_type_str )
2036         cvWriteReal( fs, "degree", params.degree );
2037
2038     if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
2039         cvWriteReal( fs, "gamma", params.gamma );
2040
2041     if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
2042         cvWriteReal( fs, "coef0", params.coef0 );
2043
2044     cvEndWriteStruct(fs);
2045
2046     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
2047         svm_type == CvSVM::NU_SVR || !svm_type_str )
2048         cvWriteReal( fs, "C", params.C );
2049
2050     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
2051         svm_type == CvSVM::NU_SVR || !svm_type_str )
2052         cvWriteReal( fs, "nu", params.nu );
2053
2054     if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
2055         cvWriteReal( fs, "p", params.p );
2056
2057     cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
2058     if( params.term_crit.type & CV_TERMCRIT_EPS )
2059         cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
2060     if( params.term_crit.type & CV_TERMCRIT_ITER )
2061         cvWriteInt( fs, "iterations", params.term_crit.max_iter );
2062     cvEndWriteStruct( fs );
2063
2064     __END__;
2065 }
2066
2067
2068 void CvSVM::write( CvFileStorage* fs, const char* name ) const
2069 {
2070     CV_FUNCNAME( "CvSVM::write" );
2071
2072     __BEGIN__;
2073
2074     int i, var_count = get_var_count(), df_count, class_count;
2075     const CvSVMDecisionFunc* df = decision_func;
2076
2077     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
2078
2079     write_params( fs );
2080
2081     cvWriteInt( fs, "var_all", var_all );
2082     cvWriteInt( fs, "var_count", var_count );
2083
2084     class_count = class_labels ? class_labels->cols :
2085                   params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2086
2087     if( class_count )
2088     {
2089         cvWriteInt( fs, "class_count", class_count );
2090
2091         if( class_labels )
2092             cvWrite( fs, "class_labels", class_labels );
2093
2094         if( class_weights )
2095             cvWrite( fs, "class_weights", class_weights );
2096     }
2097
2098     if( var_idx )
2099         cvWrite( fs, "var_idx", var_idx );
2100
2101     // write the joint collection of support vectors
2102     cvWriteInt( fs, "sv_total", sv_total );
2103     cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
2104     for( i = 0; i < sv_total; i++ )
2105     {
2106         cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
2107         cvWriteRawData( fs, sv[i], var_count, "f" );
2108         cvEndWriteStruct( fs );
2109     }
2110
2111     cvEndWriteStruct( fs );
2112
2113     // write decision functions
2114     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2115     df = decision_func;
2116
2117     cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
2118     for( i = 0; i < df_count; i++ )
2119     {
2120         int sv_count = df[i].sv_count;
2121         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2122         cvWriteInt( fs, "sv_count", sv_count );
2123         cvWriteReal( fs, "rho", df[i].rho );
2124         cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
2125         cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
2126         cvEndWriteStruct( fs );
2127         if( class_count > 1 )
2128         {
2129             cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
2130             cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
2131             cvEndWriteStruct( fs );
2132         }
2133         else
2134             CV_ASSERT( sv_count == sv_total );
2135         cvEndWriteStruct( fs );
2136     }
2137     cvEndWriteStruct( fs );
2138     cvEndWriteStruct( fs );
2139
2140     __END__;
2141 }
2142
2143
2144 void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
2145 {
2146     CV_FUNCNAME( "CvSVM::read_params" );
2147
2148     __BEGIN__;
2149
2150     int svm_type, kernel_type;
2151     CvSVMParams _params;
2152
2153     CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
2154     CvFileNode* kernel_node;
2155     if( !tmp_node )
2156         CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
2157
2158     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2159         svm_type = cvReadInt( tmp_node, -1 );
2160     else
2161     {
2162         const char* svm_type_str = cvReadString( tmp_node, "" );
2163         svm_type =
2164             strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
2165             strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
2166             strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
2167             strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
2168             strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
2169
2170         if( svm_type < 0 )
2171             CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
2172     }
2173
2174     kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
2175     if( !kernel_node )
2176         CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
2177
2178     tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
2179     if( !tmp_node )
2180         CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
2181
2182     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2183         kernel_type = cvReadInt( tmp_node, -1 );
2184     else
2185     {
2186         const char* kernel_type_str = cvReadString( tmp_node, "" );
2187         kernel_type =
2188             strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
2189             strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
2190             strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
2191             strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
2192
2193         if( kernel_type < 0 )
2194             CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
2195     }
2196
2197     _params.svm_type = svm_type;
2198     _params.kernel_type = kernel_type;
2199     _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
2200     _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
2201     _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
2202
2203     _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
2204     _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
2205     _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
2206     _params.class_weights = 0;
2207
2208     tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
2209     if( tmp_node )
2210     {
2211         _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
2212         _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
2213         _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
2214                                (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
2215     }
2216     else
2217         _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
2218
2219     set_params( _params );
2220
2221     __END__;
2222 }
2223
2224
2225 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
2226 {
2227     const double not_found_dbl = DBL_MAX;
2228
2229     CV_FUNCNAME( "CvSVM::read" );
2230
2231     __BEGIN__;
2232
2233     int i, var_count, df_count, class_count;
2234     int block_size = 1 << 16, sv_size;
2235     CvFileNode *sv_node, *df_node;
2236     CvSVMDecisionFunc* df;
2237     CvSeqReader reader;
2238
2239     if( !svm_node )
2240         CV_ERROR( CV_StsParseError, "The requested element is not found" );
2241
2242     clear();
2243
2244     // read SVM parameters
2245     read_params( fs, svm_node );
2246
2247     // and top-level data
2248     sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
2249     var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
2250     var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
2251     class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
2252
2253     if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
2254         CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
2255
2256     CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
2257     CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
2258     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "var_idx" ));
2259
2260     if( class_count > 1 && (!class_labels ||
2261         !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
2262         CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
2263
2264     if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
2265         CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
2266
2267     // read support vectors
2268     sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
2269     if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
2270         CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
2271
2272     block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
2273     block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
2274     block_size = MAX( block_size, var_all*(int)sizeof(double));
2275     CV_CALL( storage = cvCreateMemStorage( block_size ));
2276     CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
2277                                 sv_total*sizeof(sv[0]) ));
2278
2279     CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
2280     sv_size = var_count*sizeof(sv[0][0]);
2281
2282     for( i = 0; i < sv_total; i++ )
2283     {
2284         CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
2285         CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
2286                    sv_elem->data.seq->total == var_count) );
2287
2288         CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
2289         CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
2290         CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
2291     }
2292
2293     // read decision functions
2294     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2295     df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
2296     if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
2297         df_node->data.seq->total != df_count )
2298         CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
2299                   "or has a wrong number of elements" );
2300
2301     CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
2302     cvStartReadSeq( df_node->data.seq, &reader, 0 );
2303
2304     for( i = 0; i < df_count; i++ )
2305     {
2306         CvFileNode* df_elem = (CvFileNode*)reader.ptr;
2307         CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
2308
2309         int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
2310         if( sv_count <= 0 )
2311             CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
2312         df[i].sv_count = sv_count;
2313
2314         df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
2315         if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
2316             CV_ERROR( CV_StsParseError, "rho is missing" );
2317
2318         if( !alpha_node )
2319             CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
2320
2321         CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
2322                                         sv_count*sizeof(df[i].alpha[0])));
2323         CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(alpha_node->tag) &&
2324                    alpha_node->data.seq->total == sv_count) );
2325         CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
2326
2327         if( class_count > 1 )
2328         {
2329             CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
2330             if( !index_node )
2331                 CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
2332             CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
2333                                             sv_count*sizeof(df[i].sv_index[0])));
2334             CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(index_node->tag) &&
2335                    index_node->data.seq->total == sv_count) );
2336             CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
2337         }
2338         else
2339             df[i].sv_index = 0;
2340
2341         CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
2342     }
2343
2344     create_kernel();
2345
2346     __END__;
2347 }
2348
2349 #if 0
2350
2351 static void*
2352 icvCloneSVM( const void* _src )
2353 {
2354     CvSVMModel* dst = 0;
2355
2356     CV_FUNCNAME( "icvCloneSVM" );
2357
2358     __BEGIN__;
2359
2360     const CvSVMModel* src = (const CvSVMModel*)_src;
2361     int var_count, class_count;
2362     int i, sv_total, df_count;
2363     int sv_size;
2364
2365     if( !CV_IS_SVM(src) )
2366         CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2367
2368     // 0. create initial CvSVMModel structure
2369     CV_CALL( dst = icvCreateSVM() );
2370     dst->params = src->params;
2371     dst->params.weight_labels = 0;
2372     dst->params.weights = 0;
2373
2374     dst->var_all = src->var_all;
2375     if( src->class_labels )
2376         dst->class_labels = cvCloneMat( src->class_labels );
2377     if( src->class_weights )
2378         dst->class_weights = cvCloneMat( src->class_weights );
2379     if( src->comp_idx )
2380         dst->comp_idx = cvCloneMat( src->comp_idx );
2381
2382     var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2383     class_count = src->class_labels ? src->class_labels->cols :
2384                   src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2385     sv_total = dst->sv_total = src->sv_total;
2386     CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2387     CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2388                                     sv_total*sizeof(dst->sv[0]) ));
2389
2390     sv_size = var_count*sizeof(dst->sv[0][0]);
2391
2392     for( i = 0; i < sv_total; i++ )
2393     {
2394         CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2395         memcpy( dst->sv[i], src->sv[i], sv_size );
2396     }
2397
2398     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2399
2400     CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2401
2402     for( i = 0; i < df_count; i++ )
2403     {
2404         const CvSVMDecisionFunc *sdf =
2405             (const CvSVMDecisionFunc*)src->decision_func+i;
2406         CvSVMDecisionFunc *ddf =
2407             (CvSVMDecisionFunc*)dst->decision_func+i;
2408         int sv_count = sdf->sv_count;
2409         ddf->sv_count = sv_count;
2410         ddf->rho = sdf->rho;
2411         CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2412                                         sv_count*sizeof(ddf->alpha[0])));
2413         memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
2414
2415         if( class_count > 1 )
2416         {
2417             CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2418                                                 sv_count*sizeof(ddf->sv_index[0])));
2419             memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
2420         }
2421         else
2422             ddf->sv_index = 0;
2423     }
2424
2425     __END__;
2426
2427     if( cvGetErrStatus() < 0 && dst )
2428         icvReleaseSVM( &dst );
2429
2430     return dst;
2431 }
2432
2433 static int icvRegisterSVMType()
2434 {
2435     CvTypeInfo info;
2436     memset( &info, 0, sizeof(info) );
2437
2438     info.flags = 0;
2439     info.header_size = sizeof( info );
2440     info.is_instance = icvIsSVM;
2441     info.release = (CvReleaseFunc)icvReleaseSVM;
2442     info.read = icvReadSVM;
2443     info.write = icvWriteSVM;
2444     info.clone = icvCloneSVM;
2445     info.type_name = CV_TYPE_NAME_ML_SVM;
2446     cvRegisterType( &info );
2447
2448     return 1;
2449 }
2450
2451
2452 static int svm = icvRegisterSVMType();
2453
2454 /* The function trains SVM model with optimal parameters, obtained by using cross-validation.
2455 The parameters to be estimated should be indicated by setting theirs values to FLT_MAX.
2456 The optimal parameters are saved in <model_params> */
2457 CV_IMPL CvStatModel*
2458 cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2459             const CvMat* responses,
2460             CvStatModelParams* model_params,
2461             const CvStatModelParams* cross_valid_params,
2462             const CvMat* comp_idx,
2463             const CvMat* sample_idx,
2464             const CvParamGrid* degree_grid,
2465             const CvParamGrid* gamma_grid,
2466             const CvParamGrid* coef_grid,
2467             const CvParamGrid* C_grid,
2468             const CvParamGrid* nu_grid,
2469             const CvParamGrid* p_grid )
2470 {
2471     CvStatModel* svm = 0;
2472
2473     CV_FUNCNAME("cvTainSVMCrossValidation");
2474     __BEGIN__;
2475
2476     double degree_step = 7,
2477                g_step      = 15,
2478                    coef_step   = 14,
2479                    C_step      = 20,
2480                    nu_step     = 5,
2481                    p_step      = 7; // all steps must be > 1
2482     double degree_begin = 0.01, degree_end = 2;
2483     double g_begin      = 1e-5, g_end      = 0.5;
2484     double coef_begin   = 0.1,  coef_end   = 300;
2485     double C_begin      = 0.1,  C_end      = 6000;
2486     double nu_begin     = 0.01,  nu_end    = 0.4;
2487     double p_begin      = 0.01, p_end      = 100;
2488
2489     double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2490
2491         double best_rate    = 0;
2492     double best_degree  = degree_begin;
2493     double best_gamma   = g_begin;
2494     double best_coef    = coef_begin;
2495         double best_C       = C_begin;
2496         double best_nu      = nu_begin;
2497     double best_p       = p_begin;
2498
2499     CvSVMModelParams svm_params, *psvm_params;
2500     CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2501     int svm_type, kernel;
2502     int is_regression;
2503
2504     if( !model_params )
2505         CV_ERROR( CV_StsBadArg, "" );
2506     if( !cv_params )
2507         CV_ERROR( CV_StsBadArg, "" );
2508
2509     svm_params = *(CvSVMModelParams*)model_params;
2510     psvm_params = (CvSVMModelParams*)model_params;
2511     svm_type = svm_params.svm_type;
2512     kernel = svm_params.kernel_type;
2513
2514     svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2515     svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2516     svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2517     svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2518     svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2519     svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
2520
2521     if( degree_grid )
2522     {
2523         if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2524               degree_grid->step == 0) )
2525         {
2526             if( degree_grid->min_val > degree_grid->max_val )
2527                 CV_ERROR( CV_StsBadArg,
2528                 "low bound of grid should be less then the upper one");
2529             if( degree_grid->step <= 1 )
2530                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2531             degree_begin = degree_grid->min_val;
2532             degree_end   = degree_grid->max_val;
2533             degree_step  = degree_grid->step;
2534         }
2535     }
2536     else
2537         degree_begin = degree_end = svm_params.degree;
2538
2539     if( gamma_grid )
2540     {
2541         if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2542               gamma_grid->step == 0) )
2543         {
2544             if( gamma_grid->min_val > gamma_grid->max_val )
2545                 CV_ERROR( CV_StsBadArg,
2546                 "low bound of grid should be less then the upper one");
2547             if( gamma_grid->step <= 1 )
2548                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2549             g_begin = gamma_grid->min_val;
2550             g_end   = gamma_grid->max_val;
2551             g_step  = gamma_grid->step;
2552         }
2553     }
2554     else
2555         g_begin = g_end = svm_params.gamma;
2556
2557     if( coef_grid )
2558     {
2559         if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2560               coef_grid->step == 0) )
2561         {
2562             if( coef_grid->min_val > coef_grid->max_val )
2563                 CV_ERROR( CV_StsBadArg,
2564                 "low bound of grid should be less then the upper one");
2565             if( coef_grid->step <= 1 )
2566                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2567             coef_begin = coef_grid->min_val;
2568             coef_end   = coef_grid->max_val;
2569             coef_step  = coef_grid->step;
2570         }
2571     }
2572     else
2573         coef_begin = coef_end = svm_params.coef0;
2574
2575     if( C_grid )
2576     {
2577         if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2578         {
2579             if( C_grid->min_val > C_grid->max_val )
2580                 CV_ERROR( CV_StsBadArg,
2581                 "low bound of grid should be less then the upper one");
2582             if( C_grid->step <= 1 )
2583                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2584             C_begin = C_grid->min_val;
2585             C_end   = C_grid->max_val;
2586             C_step  = C_grid->step;
2587         }
2588     }
2589     else
2590         C_begin = C_end = svm_params.C;
2591
2592     if( nu_grid )
2593     {
2594         if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2595         {
2596             if( nu_grid->min_val > nu_grid->max_val )
2597                 CV_ERROR( CV_StsBadArg,
2598                 "low bound of grid should be less then the upper one");
2599             if( nu_grid->step <= 1 )
2600                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2601             nu_begin = nu_grid->min_val;
2602             nu_end   = nu_grid->max_val;
2603             nu_step  = nu_grid->step;
2604         }
2605     }
2606     else
2607         nu_begin = nu_end = svm_params.nu;
2608
2609     if( p_grid )
2610     {
2611         if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2612         {
2613             if( p_grid->min_val > p_grid->max_val )
2614                 CV_ERROR( CV_StsBadArg,
2615                 "low bound of grid should be less then the upper one");
2616             if( p_grid->step <= 1 )
2617                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2618             p_begin = p_grid->min_val;
2619             p_end   = p_grid->max_val;
2620             p_step  = p_grid->step;
2621         }
2622     }
2623     else
2624         p_begin = p_end = svm_params.p;
2625
2626     // these parameters are not used:
2627     if( kernel != CvSVM::POLY )
2628         degree_begin = degree_end = svm_params.degree;
2629
2630    if( kernel == CvSVM::LINEAR )
2631         g_begin = g_end = svm_params.gamma;
2632
2633     if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2634         coef_begin = coef_end = svm_params.coef0;
2635
2636     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2637         C_begin = C_end = svm_params.C;
2638
2639     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2640         nu_begin = nu_end = svm_params.nu;
2641
2642     if( svm_type != CvSVM::EPS_SVR )
2643         p_begin = p_end = svm_params.p;
2644
2645     is_regression = cv_params->is_regression;
2646     best_rate = is_regression ? FLT_MAX : 0;
2647
2648     assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2649     assert( p_step > 1 && C_step > 1 && nu_step > 1 );
2650
2651     for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2652     {
2653       svm_params.degree = degree;
2654       //printf("degree = %.3f\n", degree );
2655       for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2656       {
2657         svm_params.gamma = gamma;
2658         //printf("   gamma = %.3f\n", gamma );
2659         for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2660         {
2661           svm_params.coef0 = coef;
2662           //printf("      coef = %.3f\n", coef );
2663           for( C = C_begin; C <= C_end; C *= C_step )
2664           {
2665             svm_params.C = C;
2666             //printf("         C = %.3f\n", C );
2667             for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2668             {
2669               svm_params.nu = nu;
2670               //printf("            nu = %.3f\n", nu );
2671               for( p = p_begin; p <= p_end; p *= p_step )
2672               {
2673                 int well;
2674                 svm_params.p = p;
2675                 //printf("               p = %.3f\n", p );
2676
2677                 CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2678                     cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
2679
2680                 well =  rate > best_rate && !is_regression || rate < best_rate && is_regression;
2681                 if( well || (rate == best_rate && C < best_C) )
2682                 {
2683                     best_rate   = rate;
2684                     best_degree = degree;
2685                     best_gamma  = gamma;
2686                     best_coef   = coef;
2687                     best_C      = C;
2688                     best_nu     = nu;
2689                     best_p      = p;
2690                 }
2691                 //printf("                  rate = %.2f\n", rate );
2692               }
2693             }
2694           }
2695         }
2696       }
2697     }
2698     //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2699       //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
2700
2701     psvm_params->C      = best_C;
2702     psvm_params->nu     = best_nu;
2703     psvm_params->p      = best_p;
2704     psvm_params->gamma  = best_gamma;
2705     psvm_params->degree = best_degree;
2706     psvm_params->coef0  = best_coef;
2707
2708     CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
2709
2710     __END__;
2711
2712     return svm;
2713 }
2714
2715 #endif
2716
2717 /* End of file. */
2718