svm.java
Font size:
{
if(G[i] > Gmax2)
{
Gmax2 = G[i];
Gmax2_idx = i;
}
}
}
else // y = -1
{
if(!is_upper_bound(i)) // d = +1
{
if(-G[i] > Gmax2)
{
Gmax2 = -G[i];
Gmax2_idx = i;
}
}
if(!is_lower_bound(i)) // d = -1
{
if(G[i] > Gmax1)
{
Gmax1 = G[i];
Gmax1_idx = i;
}
}
}
}
if(Gmax1+Gmax2 < eps)
return 1;
working_set[0] = Gmax1_idx;
working_set[1] = Gmax2_idx;
return 0;
}
/**
 * Shrinks the active set, and possibly undoes shrinking once.
 *
 * Thresholds Gm1/Gm2 come from the pair currently chosen by
 * select_working_set. Active variables that sit at a bound and fail the
 * keep-test against those thresholds are swapped to the tail
 * (positions >= active_size) and active_size is decremented.
 *
 * When the gap -(Gm1 + Gm2) has dropped to 10*eps and shrinking has not yet
 * been undone, the full gradient is rebuilt via reconstruct_gradient() and
 * every inactive variable that passes the re-activation test is swapped
 * back into the active region. This happens at most once (unshrinked flag).
 *
 * Returns immediately when select_working_set reports convergence
 * (non-zero result).
 */
void do_shrinking()
{
    int[] pair = new int[2];
    if (select_working_set(pair) != 0)
        return;
    double Gm1 = -y[pair[1]] * G[pair[1]];
    double Gm2 = y[pair[0]] * G[pair[0]];

    // Shrink phase. pos is not advanced after a swap, so the element
    // that was just moved into slot pos is examined next.
    int pos = 0;
    while (pos < active_size)
    {
        boolean keep;
        if (is_lower_bound(pos))
            keep = (y[pos] == +1) ? (-G[pos] >= Gm1) : (-G[pos] >= Gm2);
        else if (is_upper_bound(pos))
            keep = (y[pos] == +1) ? (G[pos] >= Gm2) : (G[pos] >= Gm1);
        else
            keep = true; // free variables are never shrunk
        if (keep)
        {
            pos++;
        }
        else
        {
            active_size--;
            swap_index(pos, active_size);
        }
    }

    // Unshrink phase: only once, and only near the stopping tolerance.
    if (unshrinked || -(Gm1 + Gm2) > eps * 10)
        return;
    unshrinked = true;
    reconstruct_gradient();

    // Walk the inactive tail from the end. idx is not moved after a swap,
    // so the element brought into slot idx is examined next.
    int idx = l - 1;
    while (idx >= active_size)
    {
        boolean stayOut;
        if (is_lower_bound(idx))
            stayOut = (y[idx] == +1) ? (-G[idx] < Gm1) : (-G[idx] < Gm2);
        else if (is_upper_bound(idx))
            stayOut = (y[idx] == +1) ? (G[idx] < Gm2) : (G[idx] < Gm1);
        else
            stayOut = true; // free variables are left where they are
        if (stayOut)
        {
            idx--;
        }
        else
        {
            swap_index(idx, active_size);
            active_size++;
        }
    }
}
/**
 * Computes the offset rho from the gradient of the active variables.
 *
 * With yG = y_i * G_i:
 * - if at least one active variable is free (at neither bound), the
 *   result is the mean of yG over the free variables;
 * - otherwise it is (ub + lb) / 2, where ub is the minimum yG over
 *   variables at the lower bound with y > 0 or at the upper bound with
 *   y < 0, and lb is the maximum yG over the other bounded variables.
 */
double calculate_rho()
{
    double upper = INF;
    double lower = -INF;
    double freeSum = 0;
    int freeCount = 0;
    for (int t = 0; t < active_size; t++)
    {
        double yg = y[t] * G[t];
        boolean boundsAbove;
        if (is_lower_bound(t))
        {
            boundsAbove = y[t] > 0;
        }
        else if (is_upper_bound(t))
        {
            boundsAbove = y[t] < 0;
        }
        else
        {
            freeCount++;
            freeSum += yg;
            continue;
        }
        if (boundsAbove)
            upper = Math.min(upper, yg);
        else
            lower = Math.max(lower, yg);
    }
    return (freeCount > 0) ? freeSum / freeCount : (upper + lower) / 2;
}
}
//
// Solver for nu-svm classification and regression
//
// additional constraint: e^T \alpha = constant
//
final class Solver_NU extends Solver
{
private SolutionInfo si;
void Solve(int l, Kernel Q, double[] b, byte[] y,
double[] alpha, double Cp, double Cn, double eps,
SolutionInfo si, int shrinking)
{
this.si = si;
super.Solve(l,Q,b,y,alpha,Cp,Cn,eps,si,shrinking);
}
int select_working_set(int[] working_set)
{
// return i,j which maximize -grad(f)^T d , under constraint
// if alpha_i == C, d != +1
// if alpha_i == 0, d != -1
double Gmax1 = -INF; // max { -grad(f)_i * d | y_i = +1, d = +1 }
int Gmax1_idx = -1;
double Gmax2 = -INF; // max { -grad(f)_i * d | y_i = +1, d = -1 }
int Gmax2_idx = -1;
double Gmax3 = -INF; // max { -grad(f)_i * d | y_i = -1, d = +1 }
int Gmax3_idx = -1;
double Gmax4 = -INF; // max { -grad(f)_i * d | y_i = -1, d = -1 }
int Gmax4_idx = -1;
for(int i=0;i<active_size;i++)
{
if(y[i]==+1) // y == +1
{
if(!is_upper_bound(i)) // d = +1
{
if(-G[i] > Gmax1)
{
Gmax1 = -G[i];
Gmax1_idx = i;
}
}
if(!is_lower_bound(i)) // d = -1
{
if(G[i] > Gmax2)
{
Gmax2 = G[i];
Gmax2_idx = i;
}
}
}
else // y == -1
{
if(!is_upper_bound(i)) // d = +1
{
if(-G[i] > Gmax3)
{
Gmax3 = -G[i];
Gmax3_idx = i;
}
}
if(!is_lower_bound(i)) // d = -1
{
if(G[i] > Gmax4)
{
Gmax4 = G[i];
Gmax4_idx = i;
}
}
}
}
if(Math.max(Gmax1+Gmax2,Gmax3+Gmax4) < eps)
return 1;
if(Gmax1+Gmax2 > Gmax3+Gmax4)
{
working_set[0] = Gmax1_idx;
working_set[1] = Gmax2_idx;
}
else
{
working_set[0] = Gmax3_idx;
working_set[1] = Gmax4_idx;
}
return 0;
}
void do_shrinking()
{
double Gmax1 = -INF; // max { -grad(f)_i * d | y_i = +1, d = +1 }
double Gmax2 = -INF; // max { -grad(f)_i * d | y_i = +1, d = -1 }
double Gmax3 = -INF; // max { -grad(f)_i * d | y_i = -1, d = +1 }
double Gmax4 = -INF; // max { -grad(f)_i * d | y_i = -1, d = -1 }
int k;
for(k=0;k<active_size;k++)
{
if(!is_upper_bound(k))
{
if(y[k]==+1)
{
if(-G[k] > Gmax1) Gmax1 = -G[k];
}
else if(-G[k] > Gmax3) Gmax3 = -G[k];
}
if(!is_lower_bound(k))
{
if(y[k]==+1)
{
if(G[k] > Gmax2) Gmax2 = G[k];
}
else if(G[k] > Gmax4) Gmax4 = G[k];
}
}
double Gm1 = -Gmax2;
double Gm2 = -Gmax1;
double Gm3 = -Gmax4;
double Gm4 = -Gmax3;
for(k=0;k<active_size;k++)
{
if(is_lower_bound(k))
{
if(y[k]==+1)
{
if(-G[k] >= Gm1) continue;
}
else if(-G[k] >= Gm3) continue;
}
else if(is_upper_bound(k))
{
if(y[k]==+1)
{
if(G[k] >= Gm2) continue;
}
else if(G[k] >= Gm4) continue;
}
else continue;
--active_size;
swap_index(k,active_size);
--k; // look at the newcomer
}
// unshrink, check all variables again before final iterations
if(unshrinked || Math.max(-(Gm1+Gm2),-(Gm3+Gm4)) > eps*10) return;
unshrinked = true;
reconstruct_gradient();
for(k=l-1;k>=active_size;k--)
{
if(is_lower_bound(k))
{
if(y[k]==+1)
{
if(-G[k] < Gm1) continue;
}
else if(-G[k] < Gm3) continue;
}
else if(is_upper_bound(k))
{
if(y[k]==+1)
{
if(G[k] < Gm2) continue;
}
else if(G[k] < Gm4) continue;
}
else continue;
swap_index(k,active_size);
active_size++;
++k; // look at the newcomer
}
}
double calculate_rho()
{
int nr_free1 = 0,nr_free2 = 0;
double ub1 = INF, ub2 = INF;
double lb1 = -INF, lb2 = -INF;
double sum_free1 = 0, sum_free2 = 0;
for(int i=0;i<active_size;i++)
{
if(y[i]==+1)
{
if(is_lower_bound(i))
ub1 = Math.min(ub1,G[i]);
else if(is_upper_bound(i))
lb1 = Math.max(lb1,G[i]);
else
{
++nr_free1;
sum_free1 += G[i];
}
}
else
{
if(is_lower_bound(i))
ub2 = Math.min(ub2,G[i]);
else if(is_upper_bound(i))
lb2 = Math.max(lb2,G[i]);
else
{
++nr_free2;
sum_free2 += G[i];
}
}
}
double r1,r2;
if(nr_free1 > 0)
r1 = sum_free1/nr_free1;
else
r1 = (ub1+lb1)/2;
if(nr_free2 > 0)
r2 = sum_free2/nr_free2;
else
r2 = (ub2+lb2)/2;
si.r = (r1+r2)/2;
return (r1-r2)/2;
}
}
//
// Q matrices for various formulations
//
class SVC_Q extends Kernel
{
private final byte[] y;
private final Cache cache;
SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
{
super(prob.l, prob.x, param);
y = (byte[])y_.clone();
cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
}
float[] get_Q(int i, int len)
{
float[][] data = new float[1][];
int start;
if((start = cache.get_data(i,data,len)) < len)
{
for(int j=start;j<len;j++)
data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j));
}
return data[0];
}
void swap_index(int i, int j)
{
cache.swap_index(i,j);
super.swap_index(i,j);
do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
}
}
/**
 * Q matrix for the one-class formulation: Q_ij = K(x_i, x_j), with no label
 * factor. Rows are served through a Cache.
 */
class ONE_CLASS_Q extends Kernel
{
    private final Cache cache;

    ONE_CLASS_Q(svm_problem prob, svm_parameter param)
    {
        super(prob.l, prob.x, param);
        // cache_size is scaled by 2^20 (presumably MB to bytes; confirm in Cache).
        int cacheSize = (int)(param.cache_size * (1 << 20));
        cache = new Cache(prob.l, cacheSize);
    }

    /**
     * Returns the first len entries of row i. Entries not already supplied
     * by the cache are filled in with kernel_function.
     */
    float[] get_Q(int i, int len)
    {
        float[][] row = new float[1][];
        int col = cache.get_data(i, row, len);
        while (col < len)
        {
            row[0][col] = (float)kernel_function(i, col);
            col++;
        }
        return row[0];
    }

    /** Swaps variables i and j in the cache first, then in the base Kernel. */
    void swap_index(int i, int j)
    {
        cache.swap_index(i, j);
        super.swap_index(i, j);
    }
}
class SVR_Q extends Kernel
{
private final int l;
private final Cache cache;
private final byte[] sign;
private final int[] index;
private int next_buffer;
private float[][] buffer;
SVR_Q(svm_problem prob, svm_parameter param)
{
super(prob.l, prob.x, param);
l = prob.l;
cache = new Cache(l,(int)(param.cache_size*(1<<20)));
sign = new byte[2*l];
index = new int[2*l];
for(int k=0;k<l;k++)
{
sign[k] = 1;
sign[k+l] = -1;
index[k] = k;
index[k+l] = k;
}
buffer = new float[2][2*l];
next_buffer = 0;
}
void swap_index(int i, int j)
{
do {byte _=sign[i]; sign[i]=sign[j]; sign[j]=_;} while(false);
do {int _=index[i]; index[i]=index[j]; index[j]=_;} while(false);
}
float[] get_Q(int i, int len)
{
float[][] data = new float[1][];
int real_i = index[i];
if(cache.get_data(real_i,data,l) < l)
{
for(int j=0;j<l;j++)
data[0][j] = (float)kernel_function(real_i,j);
}
// reorder and copy
float buf[] = buffer[next_buffer];
next_buffer = 1 - next_buffer;
byte si = sign[i];
for(int j=0;j<len;j++)
buf[j] = si * sign[j] * data[0][index[j]];
return buf;
}
}
public class svm {
//
// construct and solve various formulations
//
/**
 * Solves the C-SVC dual with penalties Cp (y = +1) and Cn (y = -1).
 *
 * Labels are mapped to +1 (prob.y[i] > 0) or -1 (otherwise), alpha starts
 * at 0, and the linear term is all -1. After solving, this prints
 * "nu = sum(alpha) / (param.C * l)" and then multiplies each alpha by its
 * label, so alpha holds y_i * alpha_i on return.
 */
private static void solve_c_svc(svm_problem prob, svm_parameter param,
                                double[] alpha, Solver.SolutionInfo si,
                                double Cp, double Cn)
{
    final int l = prob.l;
    byte[] y = new byte[l];
    double[] minus_ones = new double[l];
    for (int k = 0; k < l; k++)
    {
        alpha[k] = 0;
        minus_ones[k] = -1;
        y[k] = (byte)(prob.y[k] > 0 ? +1 : -1);
    }

    new Solver().Solve(l, new SVC_Q(prob, param, y), minus_ones, y,
                       alpha, Cp, Cn, param.eps, si, param.shrinking);

    // Sum the raw multipliers, then fold the labels into alpha.
    double sum_alpha = 0;
    for (int k = 0; k < l; k++)
    {
        sum_alpha += alpha[k];
        alpha[k] *= y[k];
    }
    System.out.print("nu = "+sum_alpha/(param.C*prob.l)+"\n");
}
/**
 * Solves the nu-SVC dual and rescales the result into C-SVC form.
 *
 * Exits the process with status 1 if nu < 0 or nu*l/2 exceeds the size of
 * the smaller class. Each class starts with nu*l/2 of total alpha, spread
 * greedily with at most 1.0 per variable. After Solver_NU finishes, with
 * r = si.r:
 * - the effective C (1/r) is printed;
 * - alpha becomes y_i * alpha_i / r;
 * - rho and obj are divided by r and r^2;
 * - both upper bounds become 1/r.
 */
private static void solve_nu_svc(svm_problem prob, svm_parameter param,
                                 double[] alpha, Solver.SolutionInfo si)
{
    final int l = prob.l;
    final double nu = param.nu;
    byte[] y = new byte[l];
    int y_pos = 0;
    for (int k = 0; k < l; k++)
    {
        boolean positive = prob.y[k] > 0;
        y[k] = (byte)(positive ? +1 : -1);
        if (positive)
            y_pos++;
    }
    int y_neg = l - y_pos;

    if (nu < 0 || nu*l/2 > Math.min(y_pos, y_neg))
    {
        System.err.print("specified nu is infeasible\n");
        System.exit(1);
    }

    // Remaining alpha budget per class: [0] for y = +1, [1] for y = -1.
    double[] budget = { nu*l/2, nu*l/2 };
    for (int k = 0; k < l; k++)
    {
        int c = (y[k] == +1) ? 0 : 1;
        alpha[k] = Math.min(1.0, budget[c]);
        budget[c] -= alpha[k];
    }

    double[] zeros = new double[l]; // Java zero-initialises new arrays
    Solver_NU s = new Solver_NU();
    s.Solve(l, new SVC_Q(prob, param, y), zeros, y,
            alpha, 1.0, 1.0, param.eps, si, param.shrinking);

    double r = si.r;
    double inv_r = 1/r;
    System.out.print("C = "+inv_r+"\n");
    for (int k = 0; k < l; k++)
        alpha[k] *= y[k]/r;
    si.rho /= r;
    si.obj /= (r*r);
    si.upper_bound_p = inv_r;
    si.upper_bound_n = inv_r;
}
/**
 * Solves the one-class SVM dual.
 *
 * With n = (int)(nu * l):
 * - alpha[0..n-1] start at 1;
 * - alpha[n] takes the fractional remainder nu*l - n;
 * - the remaining alphas start at 0.
 * The initial alpha therefore sums to nu*l. The Solver then runs with a zero
 * linear term and every label set to +1.
 *
 * Exits the process with status 1 when nu is outside (0,1). The
 * !(nu > 0) test also rejects NaN. Previously only the upper end was
 * checked, so a negative nu either made n negative (alpha[n] threw
 * ArrayIndexOutOfBoundsException) or seeded a negative/zero/NaN alpha.
 */
private static void solve_one_class(svm_problem prob, svm_parameter param,
                                    double[] alpha, Solver.SolutionInfo si)
{
    int l = prob.l;
    int n = (int)(param.nu*prob.l); // # of alpha's at upper bound
    if(!(param.nu > 0) || n>=prob.l)
    {
        System.err.print("nu must be in (0,1)\n");
        System.exit(1);
    }

    for(int i=0;i<n;i++)
        alpha[i] = 1;
    alpha[n] = param.nu * prob.l - n;
    for(int i=n+1;i<l;i++)
        alpha[i] = 0;

    double[] zeros = new double[l]; // linear term; Java zero-initialises arrays
    byte[] ones = new byte[l];
    for(int i=0;i<l;i++)
        ones[i] = 1;

    Solver s = new Solver();
    s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones,
            alpha, 1.0, 1.0, param.eps, si, param.shrinking);
}
private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
double[] alpha, Solver.SolutionInfo si)
{
int l = prob.l;
double[] alpha2 = new double[2*l];
double[] linear_term = new double[2*l];
byte[] y = new byte[2*l];
Keyboard shortcuts
Copy code
Ctrl + C
Search code
Ctrl + F
Fullscreen mode
F11
Toggle theme
Ctrl + Shift + D
Show shortcuts
?
Increase font size
Ctrl + =
Decrease font size
Ctrl + -