?? ssvm2_mex.c~
字號:
/*-----------------------------------------------------------------------
ssvm2_mex.c: MEX-file for single-class SVM with L2-soft margin.
Compile:
mex -O -I../../kernels -outdir ../ ssvm2_mex.c ../../kernels/kernel_fun.c npa.c kozinec.c mdm.c
Synopsis:
[Alpha,exitflag,UB,LB,t,kercnt,margin,trnerr,History] =
ssvm2_mex(data,ker,arg,C,solver,tmax,tolabs,tolrel)
Input:
data [dim x num_data] Input vectors.
ker [string] Kernel identifier.
arg [1xnarg] Kernel argument(s).
C [1x1] Regularization constant k'(i,j) = k(i,j) + delta(i,j)/(2*C);
C = 0 means that no diagonal addend is used, i.e. k'(i,j) = k(i,j).
solver [string] Solver; options are 'mdm','kozinec' or 'npa'.
tmax [1x1] Maximal number of iterations.
tolabs [1x1] Absolute tolerance stopping condition.
tolrel [1x1] Relative tolerance stopping condition.
Output:
Alpha [num_data x 1] Weights.
exitflag [1x1] Indicates which stopping condition was used:
UB <= tolabs -> exitflag = 1 Absolute tolerance.
(UB-LB)/(LB+1) <= tolrel -> exitflag = 2 Relative tolerance.
t >= tmax -> exitflag = 0 Maximal number of iterations.
UB [1x1] Upper bound on the optimal solution.
LB [1x1] Lower bound on the optimal solution.
t [1x1] Number of iterations.
kercnt [1x1] Number of kernel evaluations.
margin [1x1] Achieved margin.
trnerr [1x1] Training error.
History [2x(t+1)] UB and LB with respect to number of iterations.
Modifications:
15-jun-2004, VF
23-Jan-2004, VF
22-Jan-2004, VF
14-Oct-2003, VF
-------------------------------------------------------------------- */
#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
#include "kernel_fun.h"
#include "kozinec.h"
#include "mdm.h"
#include "npa.h"
/* Linear offset of element (ROW,COL) in a column-major matrix with DIM rows.
   Every argument and the whole expansion are parenthesized so that
   expression arguments (e.g. INDEX(r, c+1, n)) expand correctly. */
#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
/* Diagonal addend of the kernel matrix: kernel_fce() adds this value to
   kernel(i,i). mexFunction() sets it to 1/(2*C), or to 0 when C == 0. */
double kernel_diag;
/* ==============================================================
 Kernel function handed to the solvers.

 Returns kernel(i,j); for diagonal entries (i == j) the global
 kernel_diag is added on top of it.
============================================================== */
double kernel_fce( long i, long j)
{
  /* 1 on the diagonal, 0 elsewhere; kept as a multiplicative weight
     so the arithmetic matches kernel + (i==j)*kernel_diag exactly */
  double diag_weight = ( i == j ) ? 1.0 : 0.0;
  double base_value = kernel( i, j );

  return base_value + diag_weight * kernel_diag;
}
/* ==============================================================
 Main MEX function - interface to Matlab.

 Inputs  (prhs, exactly 8): data, ker, arg, C, solver, tmax, tolabs, tolrel.
 Outputs (plhs, at most 9):
   plhs[0] Alpha     plhs[1] exitflag  plhs[2] UB
   plhs[3] LB        plhs[4] t         plhs[5] kercnt
   plhs[6] margin    plhs[7] trnerr    plhs[8] History

 plhs[0] is always created; plhs[1..8] are created only when the
 caller asked for them (nlhs), so nothing is written past the
 requested outputs.

 Raises a Matlab error (mexErrMsgTxt) on a wrong number of
 arguments, empty data, an improper kernel identifier, a non-string
 solver argument or an unknown solver name.
============================================================== */
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
{
  double *Alpha;
  double LB;
  double UB;
  double margin;
  double trnerr;
  double tolabs;
  double tolrel;
  double C;
  double *History;
  double *tmp_ptr;
  double tmp;
  double scalar_out[7];   /* values for plhs[1] .. plhs[7] */
  long num_data;
  long i;
  long t;
  long tmax;
  int exitflag = 0;
  int k;
  char solver[20];
  int buf_len;

  /* == Processing of inputs == */
  if(nrhs != 8) mexErrMsgTxt("Incorrect number of input arguments.");
  if(nlhs > 9) mexErrMsgTxt("Too many output arguments.");

  /* get input data and parameters; dataA, dataB, dim, ker, arg1 and
     ker_cnt are not declared in this file (presumably globals of
     kernel_fun.h read by kernel()) */
  dataA = mxGetPr(prhs[0]);     /* pointer at patterns */
  dataB = mxGetPr(prhs[0]);     /* pointer at patterns */
  dim = mxGetM(prhs[0]);        /* data dimension */
  num_data = mxGetN(prhs[0]);   /* number of data */
  if( num_data < 1 )
    mexErrMsgTxt("Input data are empty.");

  ker_cnt = 0;                  /* counter of kernel evaluations */
  ker = kernel_id( prhs[1] );
  if( ker == -1 )
    mexErrMsgTxt("Improper kernel identifier.");

  arg1 = mxGetPr(prhs[2]);      /* kernel arg */
  C = mxGetScalar(prhs[3]);     /* regularization constant */

  /* take solver string; anything longer than the buffer is truncated
     and therefore cannot match any of the known solver names below */
  if( mxIsChar( prhs[4] ) != 1)
    mexErrMsgTxt("solver must be string.");
  buf_len = (mxGetM(prhs[4]) * mxGetN(prhs[4])) + 1;
  if( buf_len > (int)sizeof(solver) ) buf_len = (int)sizeof(solver);
  mxGetString( prhs[4], solver, buf_len );

  /* tmax = Inf means "no iteration limit" */
  tmax = mxIsInf( mxGetScalar(prhs[5])) ? INT_MAX : (long)mxGetScalar(prhs[5]);
  tolabs = mxGetScalar(prhs[6]);
  tolrel = mxGetScalar(prhs[7]);

  /* C == 0 switches the diagonal addend off */
  if( C!=0 ) kernel_diag = 1/(2*C); else kernel_diag = 0;

  /* == call SVM solver == */
  if( strcmp( solver, "kozinec" )==0 ) {
    exitflag = single_kozinec( &kernel_fce, num_data, tmax, tolabs, tolrel,
                               &Alpha, &UB, &LB, &t, &History );
  } else if ( strcmp( solver, "npa" )==0 ) {
    exitflag = single_npa( &kernel_fce, num_data, tmax, tolabs, tolrel,
                           &Alpha, &UB, &LB, &t, &History );
  } else if ( strcmp( solver, "mdm" )==0 ) {
    exitflag = single_mdm( &kernel_fce, num_data, tmax, tolabs, tolrel,
                           &Alpha, &UB, &LB, &t, &History );
  } else {
    mexErrMsgTxt("Unknown solver identifier.");
  }

  /* == Outputs == */

  /* Alpha: always returned (plhs[0] may be assigned even if nlhs == 0) */
  plhs[0] = mxCreateDoubleMatrix(num_data,1,mxREAL);
  tmp_ptr = mxGetPr(plhs[0]);
  for( i = 0; i < num_data; i++ ) {
    tmp_ptr[i] = Alpha[i];
  }

  /* compute margin; with C == 0 there is no diagonal addend, so the
     Alpha[i]/(2*C) term is dropped instead of dividing by zero */
  tmp = 0;
  for( i = 0; i < num_data; i++ ) {
    if( C != 0 ) {
      tmp += Alpha[i]*(1-Alpha[i]/(2*C));
    } else {
      tmp += Alpha[i];
    }
  }
  if( tmp ) margin = 1/sqrt(tmp); else margin = -1;

  /* compute training error: fraction of points with Alpha[i] >= 2*C.
     NOTE(review): for C == 0 the test would hold for every
     non-negative Alpha, so the error is reported as 0 in that case
     (no soft margin) - confirm this is the intended meaning. */
  tmp = 0;
  if( C != 0 ) {
    for( i = 0; i < num_data; i++ ) {
      if( Alpha[i] >= 2*C) tmp++;
    }
  }
  trnerr = tmp/((double) num_data);

  /* scalar outputs plhs[1] .. plhs[7], created only when requested */
  scalar_out[0] = (double)exitflag;
  scalar_out[1] = UB;
  scalar_out[2] = LB;
  scalar_out[3] = (double)t;
  scalar_out[4] = (double)ker_cnt;
  scalar_out[5] = margin;
  scalar_out[6] = trnerr;
  for( k = 1; k <= 7 && k < nlhs; k++ ) {
    plhs[k] = mxCreateDoubleMatrix(1,1,mxREAL);
    (*mxGetPr(plhs[k])) = scalar_out[k-1];
  }

  /* History: 2 x (t+1) matrix copied column by column from the solver */
  if( nlhs > 8 ) {
    plhs[8] = mxCreateDoubleMatrix(2,t+1,mxREAL);
    tmp_ptr = mxGetPr( plhs[8] );
    for( i = 0; i <= t; i++ ) {
      tmp_ptr[INDEX(0,i,2)] = History[INDEX(0,i,2)];
      tmp_ptr[INDEX(1,i,2)] = History[INDEX(1,i,2)];
    }
  }

  /* release the buffers handed back by the solver */
  mxFree( Alpha );
  mxFree( History );
  return;
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -