?? bsvm2_mex.c
字號:
long i,j ; /* common use loop variables */
long inx1, inx2;
long NA;
long tmax; /* input arg - max number of iteration */
long t; /* output arg - number of iterations */
long verb; /* input argument */
double C; /* input arg - regularization const */
double tolrel; /* input arg */
double tolabs; /* input arg */
double trnerr; /* output arg */
double *tmp_ptr;
double *tmp_ptr1;
double *tmp_ptr2;
double *vector_c; /* auxiliary */
double *Alpha; /* solution vector */
double *History; /* output arg */
double *diagK; /* cache for diagonal of virtual K matrix */
/*------------------------------------------------------------------- */
/* Take input arguments */
/*------------------------------------------------------------------- */
if( nrhs != 11) mexErrMsgTxt("Incorrect number of input arguments.");
dataA = mxGetPr(prhs[0]); /* pointers at data */
dataB = dataA;
dim = mxGetM(prhs[0]); /* data dimension */
num_data = mxGetN(prhs[0]); /* number of data */
labels = mxGetPr(prhs[1]); /* pointer at data labels */
/* take kernel identifier and its argument */
ker = kernel_id( prhs[2] );
if( ker == -1 ) mexErrMsgTxt("Improper kernel identifier.");
arg1 = mxGetPr(prhs[3]);
C = mxGetScalar(prhs[4]); /* regularization constant */
/* take string identifier QP solver to be used */
if( mxIsChar( prhs[5] ) != 1) mexErrMsgTxt("solver must be string.");
buf_len = (mxGetM(prhs[5]) * mxGetN(prhs[5])) + 1;
buf_len = (buf_len > 20) ? 20 : buf_len;
mxGetString( prhs[5], solver, buf_len );
/* maximal allowed number of iterations */
tmax = mxIsInf( mxGetScalar(prhs[6])) ? INT_MAX : (long)mxGetScalar(prhs[6]);
tolabs = mxGetScalar(prhs[7]); /* abs. precision defining stopping cond*/
tolrel = mxGetScalar(prhs[8]); /* rel. precision defining stopping cond*/
Cache_Size = (long)mxGetScalar(prhs[9]); /* cache size */
if( Cache_Size < 1 ) mexErrMsgTxt("Cache must be greater than 1.");
if( Cache_Size > num_data ) Cache_Size = num_data;
verb = (long)mxGetScalar(prhs[10]); /* verbosity on/off */
/*------------------------------------------------------------------- */
/* Inicialization (caches, etc.) */
/*------------------------------------------------------------------- */
/* constant added to diagonal of separable problem */
if( C!=0 ) kernel_diag = 1/(2*C); else kernel_diag = 0;
/* num_classes = max( labels ) */
num_classes = MINUS_INF;
for( i = 0; i < num_data; i++ ) {
if( labels[i] > num_classes ) num_classes = (long)labels[i];
}
/* computes number of virtual "single-class" examples */
num_virt_data = (num_classes-1)*num_data;
ker_cnt = 0; /* counter of kernel evaluations */
access_cnt = 0; /* counter for access to the kernel matrix */
/* allocattes and precomputes diagonal of virtual K matrix */
diagK = mxCalloc(num_virt_data, sizeof(double));
if( diagK == NULL ) mexErrMsgTxt("Not enough memory.");
for(i = 0; i < num_virt_data; i++ ) {
diagK[i] = kernel_fce(i,i);
}
/* allocates memory for kernel cache */
kernel_columns = mxCalloc(Cache_Size, sizeof(double*));
if( kernel_columns == NULL ) mexErrMsgTxt("Not enough memory.");
cache_index = mxCalloc(Cache_Size, sizeof(double));
if( cache_index == NULL ) mexErrMsgTxt("Not enough memory.");
for(i = 0; i < Cache_Size; i++ ) {
kernel_columns[i] = mxCalloc(num_data, sizeof(double));
if(kernel_columns[i] == NULL) mexErrMsgTxt("Not enough memory.");
cache_index[i] = -2;
}
first_kernel_inx = 0;
/* allocates memory for three virtual kernel matrix columns */
for(i = 0; i < 3; i++ ) {
virt_columns[i] = mxCalloc(num_virt_data, sizeof(double));
if(virt_columns[i] == NULL) mexErrMsgTxt("Not enough memory.");
}
first_virt_inx = 0;
/* Solution vector */
Alpha = mxCalloc(num_virt_data, sizeof(double));
if( Alpha == NULL ) mexErrMsgTxt("Not enough memory.");
/* Vector c; for this problem set to zero */
vector_c = mxCalloc(num_virt_data, sizeof(double));
if( vector_c == NULL ) mexErrMsgTxt("Not enough memory.");
for(i = 0; i < num_virt_data; i++ ) vector_c[i] = 0;
/*------------------------------------------------------------------- */
/* Call QP solver */
/*------------------------------------------------------------------- */
if ( strcmp( solver, "mdm" )==0 ) {
exitflag = qpc_mdm( &get_col, diagK, vector_c, num_virt_data, tmax,
tolabs, tolrel, Alpha, &t, &History, verb );
} else if ( strcmp( solver, "imdm" )==0 ) {
exitflag = qpc_imdm( &get_col, diagK, vector_c, num_virt_data, tmax,
tolabs, tolrel, Alpha, &t, &History, verb );
} else if ( strcmp( solver, "iimdm" )==0 ) {
exitflag = qpc_iimdm( &get_col, diagK, vector_c, num_virt_data, tmax,
tolabs, tolrel, Alpha, &t, &History, verb );
} else if ( strcmp( solver, "keerthi" )==0 ) {
exitflag = qpc_keerthi( &get_col, diagK, vector_c, num_virt_data, tmax,
tolabs, tolrel, Alpha, &t, &History, verb );
} else if ( strcmp( solver, "kowalczyk" )==0 ) {
exitflag = qpc_kowalczyk( &get_col, diagK, vector_c, num_virt_data, tmax,
tolabs, tolrel, Alpha, &t, &History, verb );
} else if ( strcmp( solver, "kozinec" )==0 ) {
exitflag = qpc_kozinec( &get_col, diagK, vector_c, num_virt_data, tmax,
tolabs, tolrel, Alpha, &t, &History, verb );
} else {
mexErrMsgTxt("Unknown solver identifier.");
}
/*------------------------------------------------------------------- */
/* Generate outputs */
/*------------------------------------------------------------------- */
/* matrix Alpha [num_classes x num_data] */
plhs[0] = mxCreateDoubleMatrix(num_classes,num_data,mxREAL);
tmp_ptr1 = mxGetPr(plhs[0]);
/* bias vector b [num_classes x 1] */
plhs[1] = mxCreateDoubleMatrix(num_classes,1,mxREAL);
tmp_ptr2 = mxGetPr(plhs[1]);
for( i=0; i < num_classes; i++ ) {
for( j=0; j < num_virt_data; j++ ) {
get_indices2( &inx1, &inx2, j );
tmp_ptr1[(inx1*num_classes)+i] +=
Alpha[j]*(KDELTA(labels[inx1],i+1)+KDELTA(i+1,inx2));
tmp_ptr2[i] += Alpha[j]*(KDELTA(labels[inx1],i+1)-KDELTA(i+1,inx2));
}
}
/* exit_flag [1x1] */
plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[2])) = (double)exitflag;
/* kercnt [1x1] */
plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[3])) = (double)ker_cnt;
/* access [1x1] */
plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[4])) = (double)access_cnt;
/* trnerr [1x1] */
err_bit = mxCalloc(num_data, sizeof(int));
if( err_bit == NULL ) mexErrMsgTxt("Not enough memory.");
for( i=0; i < num_classes; i++ ) {
for( j=0; j < num_virt_data; j++ ) {
get_indices2( &inx1, &inx2, j );
if( Alpha[j] > 2*C ) err_bit[inx1] = 1;
}
}
for( trnerr = 0, i = 0; i < num_data; i++ ) trnerr += err_bit[i];
trnerr = trnerr/num_data;
plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[5])) = trnerr;
/* t [1x1] */
plhs[6] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[6])) = (double)t;
/* NA [1x1] */
for( NA = 0, j=0; j < num_virt_data; j++ ) {
if( Alpha[j] > 0 ) NA++;
}
plhs[7] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[7])) = (double)NA;
/* UB [1x1] */
plhs[8] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[8])) = History[INDEX(1,t,2)];
/* LB [1x1] */
plhs[9] = mxCreateDoubleMatrix(1,1,mxREAL);
*(mxGetPr(plhs[9])) = History[INDEX(0,t,2)];
/* History [2 x (t+1)] */
plhs[10] = mxCreateDoubleMatrix(2,t+1,mxREAL);
tmp_ptr = mxGetPr( plhs[10] );
for( i = 0; i <= t; i++ ) {
tmp_ptr[INDEX(0,i,2)] = History[INDEX(0,i,2)];
tmp_ptr[INDEX(1,i,2)] = History[INDEX(1,i,2)];
}
/*------------------------------------------------------------------- */
/* Free used memory */
/*------------------------------------------------------------------- */
mxFree( vector_c );
mxFree( Alpha );
mxFree( History );
mxFree( diagK );
for(i = 0; i < Cache_Size; i++ ) mxFree(kernel_columns[i]);
for(i = 0; i < 3; i++ ) mxFree(virt_columns[i]);
mxFree( kernel_columns );
mxFree( cache_index );
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -