?? smo.java
字號:
{
tar = Float.valueOf(v.lastElement().toString()).intValue();
target.add(new Integer (tar));
v.remove(v.size()-1);
n = v.size();
}
if (is_sparse_data && is_binary )
{
sparse_binary_vector x = new sparse_binary_vector();
for (int i=0; i<n; i++)
{
if (object2float(v.elementAt(i)) < 1 || object2float(v.elementAt(i)) > d)
{
int line2 = n_lines +1;
System.out.println("error: line " + line2 + ": attribute index "+ (int)object2float(v.elementAt(i))+ " out of range.\n");
System.exit(1);
}
x.id.add(new Integer((int)object2float(v.elementAt(i)) -1));
}
sparse_binary_points.add(x);
}
else if (is_sparse_data && !is_binary)
{
sparse_vector x = new sparse_vector();
if (this.is_libsvm_file)
{
for (int i=0; i<n; i+=2)
{
if (object2float(v.elementAt(i)) < 1 || object2float(v.elementAt(i)) > d)
{
int line3 = n_lines +1;
System.out.println("data file error: line " + line3 + ": attribute index " + (int)object2float(v.elementAt(i)) + " out of range.\n");
System.exit(1);
}
int id = (int)object2float(v.elementAt(i)) -1;
float value = (float)object2float(v.elementAt(i+1));
x.id.add(new Integer(id));
x.val.add(new Float(value));
}
sparse_points.add(x);
}
else
{
for (int i=0; i<n; i+=2)
{
if (object2float(v.elementAt(i)) < 1 || object2float(v.elementAt(i)) > d)
{
int line3 = n_lines +1;
System.out.println("data file error: line " + line3 + ": attribute index " + (int)object2float(v.elementAt(i)) + " out of range.\n");
System.exit(1);
}
int id = (int)object2float(v.elementAt(i)) -1;
float value = (float)object2float(v.elementAt(i+1));
x.id.add(new Integer(id));
x.val.add(new Float(value));
}
sparse_points.add(x);
}
}
else
{
if (v.size() != d)
{
int line4 = n_lines +1;
System.out.println("Data file error: line "+line4+ " has "+ v.size() +" attributes; should be d=" + d);
System.exit(1);
}
for ( int i=0; i<d; i++)
{
dense_points[N][i] = object2float(v.elementAt(i));
}
N= N+1;
}
}
}
catch(Exception e)
{
e.printStackTrace();
}
return n_lines;
}
/**
 * Writes the model to {@code os} as plain text, one value per line, in the order that
 * read_svm consumes it: d, is_sparse_data, is_binary, is_linear_kernel, b.
 * For a linear kernel the d weights follow. Otherwise the output is:
 * two_sigma_squared, the number of examples in [0, end_support_i) with alpha > 0,
 * those alphas, and then one line per such example.
 *
 * NOTE(review): sparse example lines are written label-first, while dense lines are
 * written label-last. The visible part of read_data takes the LAST token as the label
 * (lastElement) in at least one branch. Confirm that reading the model back (-a mode)
 * parses sparse support vectors with the intended format.
 */
void write_svm(PrintStream os)
{
    os.println(d);
    os.println(is_sparse_data);
    os.println(is_binary);
    os.println(is_linear_kernel);
    os.println(b);

    if (is_linear_kernel)
    {
        for (int k = 0; k < d; k++)
            os.println(object2float(w.elementAt(k)));
        return;
    }

    os.println(two_sigma_squared);

    // Count first: the count line must precede the alpha lines.
    int supportCount = 0;
    for (int k = 0; k < end_support_i; k++)
    {
        if (object2float(alph.elementAt(k)) > 0)
            supportCount++;
    }
    os.println(supportCount);

    for (int k = 0; k < end_support_i; k++)
    {
        float alpha = object2float(alph.elementAt(k));
        if (alpha > 0)
            os.println(alpha);
    }

    for (int k = 0; k < end_support_i; k++)
    {
        // Written as !(x > 0) rather than x <= 0 so NaN alphas are skipped exactly as before.
        if (!(object2float(alph.elementAt(k)) > 0))
            continue;

        if (is_sparse_data && is_binary)
        {
            os.print(object2int(target.elementAt(k)));
            os.print(" ");
            sparse_binary_vector row = (sparse_binary_vector) sparse_binary_points.elementAt(k);
            for (int j = 0; j < row.id.size(); j++)
            {
                // Stored ids are 0-based; the file format is 1-based.
                os.print(object2int(row.id.elementAt(j)) + 1);
                os.print(" ");
            }
        }
        else if (is_sparse_data && !is_binary)
        {
            os.print(object2int(target.elementAt(k)));
            os.print(" ");
            sparse_vector row = (sparse_vector) sparse_points.elementAt(k);
            for (int j = 0; j < row.id.size(); j++)
            {
                int oneBasedId = object2int(row.id.elementAt(j)) + 1;
                float featureValue = object2float(row.val.elementAt(j));
                os.print(oneBasedId + " " + featureValue + " ");
            }
        }
        else
        {
            for (int j = 0; j < d; j++)
            {
                os.print(dense_points[k][j]);
                os.print(" ");
            }
            os.print(object2int(target.elementAt(k)));
        }
        os.print("\n");
    }
}
/**
 * Loads a model in the text layout produced by write_svm.
 *
 * Layout: d, is_sparse_data, is_binary, is_linear_kernel and b, one per line. Then either
 * d weight values (linear kernel), or two_sigma_squared, the support-vector count, that
 * many alpha values, and the support-vector lines themselves, which are parsed by read_data.
 *
 * Parsing stops at the first malformed or missing line. Every later field's meaning
 * depends on its position in the stream, so continuing would misread the rest of the model.
 *
 * @param is model stream positioned at its first line
 * @return read_data's result for the support-vector section (nonlinear kernel);
 *         0 for a linear kernel or when the model could not be parsed
 */
int read_svm(DataInputStream is)
{
    try
    {
        d = Integer.parseInt(read_model_line(is));
        is_sparse_data = Boolean.valueOf(read_model_line(is)).booleanValue();
        is_binary = Boolean.valueOf(read_model_line(is)).booleanValue();
        is_linear_kernel = Boolean.valueOf(read_model_line(is)).booleanValue();
        b = Float.parseFloat(read_model_line(is));
    }
    catch (Exception e)
    {
        // Without a valid header the remaining lines cannot be interpreted.
        e.printStackTrace();
        return 0;
    }

    if (is_linear_kernel)
    {
        resize(w, d, 2);
        try
        {
            for (int i = 0; i < d; i++)
                w.set(i, Float.valueOf(Float.parseFloat(read_model_line(is))));
        }
        catch (Exception e)
        {
            // Report once. A truncated file used to raise one exception per remaining weight.
            e.printStackTrace();
        }
        return 0;
    }

    try
    {
        two_sigma_squared = Float.parseFloat(read_model_line(is));
        int n_support_vectors = Integer.parseInt(read_model_line(is));
        resize(alph, n_support_vectors, 2);
        for (int i = 0; i < n_support_vectors; i++)
            alph.set(i, Float.valueOf(Float.parseFloat(read_model_line(is))));
    }
    catch (Exception e)
    {
        // The support-vector lines follow the alphas. After a failure here the stream
        // position is unknown, so handing it to read_data would misread the vectors.
        e.printStackTrace();
        return 0;
    }
    return read_data(is);
}

/**
 * Returns the next line of the model stream with surrounding whitespace removed.
 * Trimming matters because Boolean.valueOf("true ") is false and Integer.parseInt
 * rejects padding.
 *
 * @throws java.io.EOFException if the stream ends before the model is complete
 */
private String read_model_line(DataInputStream is) throws java.io.IOException
{
    // DataInputStream.readLine is deprecated but is the only line reader this
    // parameter type offers; the model file is plain ASCII.
    String line = is.readLine();
    if (line == null)
        throw new java.io.EOFException("unexpected end of model file");
    return line.trim();
}
/**
 * Fraction of the examples in [first_test_i, N) that are misclassified.
 * An example counts as misclassified when the sign of learned_func disagrees with the
 * sign of its target label. In float arithmetic this is NaN (0/0) when the range is empty.
 */
float error_rate()
{
    int mistakes = 0;
    int evaluated = 0;
    int idx = first_test_i;
    while (idx < N)
    {
        boolean predictedPositive = learned_func(idx, learned_func_flag) > 0;
        boolean labelPositive = object2int(target.elementAt(idx)) > 0;
        if (predictedPositive ^ labelPositive)
            mistakes++;
        evaluated++;
        idx++;
    }
    return (float) mistakes / (float) evaluated;
}
/**
 * Dot product of examples i and j, choosing the storage-specific implementation by flag.
 * main sets the flag as follows: 1 = sparse binary, 2 = sparse non-binary, 3 = dense.
 * Any other flag yields 0.
 */
float dot_product_func(int i, int j, int flag)
{
    switch (flag)
    {
        case 1:
            return dot_product_sparse_binary(i, j);
        case 2:
            return dot_product_sparse_nonbinary(i, j);
        case 3:
            return dot_product_dense(i, j);
        default:
            return 0;
    }
}
/**
 * Decision-function value for example i, choosing the implementation by flag.
 * main sets the flag as follows: 1 = linear/sparse binary, 2 = linear/sparse non-binary,
 * 3 = linear/dense, 4 = nonlinear kernel. Any other flag yields 0.
 */
float learned_func(int i, int flag)
{
    switch (flag)
    {
        case 1:
            return learned_func_linear_sparse_binary(i);
        case 2:
            return learned_func_linear_sparse_nonbinary(i);
        case 3:
            return learned_func_linear_dense(i);
        case 4:
            return learned_func_nonlinear(i);
        default:
            return 0;
    }
}
/**
 * Kernel value K(i, j).
 * flag 1 (linear kernel, as set by main) is the plain dot product, computed via
 * dot_product_flag. flag 2 is rbf_kernel. Any other flag yields 0.
 */
float kernel_func(int i, int j, int flag)
{
    switch (flag)
    {
        case 1:
            return dot_product_func(i, j, this.dot_product_flag);
        case 2:
            return rbf_kernel(i, j);
        default:
            return 0;
    }
}
/**
 * Makes {@code v} hold exactly {@code newSize} elements, like C++ vector::resize.
 * A longer vector is truncated. A shorter one is padded with zeros: Integer 0 when
 * type == 1, Float 0 when type == 2. For any other type a shorter vector is left unchanged.
 */
void resize(Vector v, int newSize, int type)
{
    int original = v.size();
    if (original > newSize)
    {
        v.setSize(newSize);
        return;
    }

    // Boxed zeros are immutable, so one instance can fill every new slot.
    // valueOf replaces the Integer(int)/Float(float) constructors, which are deprecated for removal.
    Object zero;
    if (type == 1)
        zero = Integer.valueOf(0);
    else if (type == 2)
        zero = Float.valueOf(0f);
    else
        return;

    v.ensureCapacity(newSize);
    for (int i = original; i < newSize; i++)
        v.add(zero);
}
/**
 * Pre-allocates room for about {@code size} elements in {@code v} without adding any,
 * like the C++ vector::reserve these helpers are named after.
 *
 * main calls the reserve helpers just before read_data, and read_data appends with add().
 * The earlier implementation inserted {@code size} placeholder zeros at indices 0..size-1,
 * so the real labels ended up behind dummy entries. In -a mode it also pushed the
 * support-vector labels already loaded by read_svm out of their slots.
 *
 * @param type unused; kept so existing call sites compile (it used to pick Integer or
 *             Float placeholders)
 */
void reserve(Vector v, int size, int type)
{
    v.ensureCapacity(size);
}

/** Capacity-only reservation for a vector of sparse_vector; adds no elements. */
void reserveSparse(Vector v, int size)
{
    v.ensureCapacity(size);
}

/** Capacity-only reservation for a vector of sparse_binary_vector; adds no elements. */
void reserveSparseBinary(Vector v, int size)
{
    v.ensureCapacity(size);
}

/**
 * Intentionally a no-op, kept for call-site compatibility. A Java float[][] is
 * zero-filled when it is allocated, so there is nothing to reserve.
 *
 * The earlier version zeroed rows 0..size-1. In -a mode those rows already hold the dense
 * support vectors: read_svm calls read_data, which stores rows at dense_points[N] and
 * advances N before main calls this method.
 */
void reserve(float[][] array, int size)
{
}
/**
 * Command-line driver.
 *
 * By default it trains on the data file with SMO and saves the model to the model file.
 * With -a it loads the model file and evaluates it on the data file. In both modes it:
 * - prints the error rate over examples [first_test_i, N),
 * - writes one learned_func value per such example to the output file,
 * - prints the elapsed time.
 *
 * The accepted options are those in the GetOpt spec string below. An unrecognised
 * option prints the usage text and exits with status 2.
 */
public static void main(String[] args)
{
    long time, newTime;
    time = System.currentTimeMillis();
    try
    {
        String data_file_name = "java-svm.data";
        String svm_file_name = "java-svm.model";
        String output_file_name = "java-svm.output";
        Smo my = new Smo();
        int numChanged = 0;
        int examineAll = 0;

        // ---- command-line options ----
        {
            GetOpt go = new GetOpt(args, "n:d:c:t:e:p:f:m:o:r:lsbai");
            go.optErr = true;
            int ch = -1;
            int errflg = 0;
            while ((ch = go.getopt()) != go.optEOF)
            {
                switch (ch)
                {
                    case 'n':
                        my.N = go.processArg(go.optArgGet(), my.N);
                        break;
                    case 'd':
                        my.d = go.processArg(go.optArgGet(), my.d);
                        break;
                    case 'c':
                        my.C = go.processArg(go.optArgGet(), my.C);
                        break;
                    case 't':
                        my.tolerance = go.processArg(go.optArgGet(), my.tolerance);
                        break;
                    case 'e':
                        my.eps = go.processArg(go.optArgGet(), my.eps);
                        break;
                    case 'p':
                        my.two_sigma_squared = go.processArg(go.optArgGet(), my.two_sigma_squared);
                        break;
                    case 'f':
                        data_file_name = go.optArgGet();
                        break;
                    case 'm':
                        svm_file_name = go.optArgGet();
                        break;
                    case 'o':
                        output_file_name = go.optArgGet();
                        break;
                    case 'r':
                        // Accepted, but the seed argument is ignored.
                        System.out.println("Random");
                        break;
                    case 'l':
                        my.is_linear_kernel = true;
                        break;
                    case 's':
                        my.is_sparse_data = true;
                        break;
                    case 'b':
                        // Binary data is always stored sparsely.
                        my.is_binary = true;
                        my.is_sparse_data = true;
                        break;
                    case 'a':
                        my.is_test_only = true;
                        break;
                    case 'i':
                        my.is_libsvm_file = true;
                        break;
                    case '?':
                        errflg++;
                }
            }
            if (errflg > 0)
            {
                // In Java, args[0] is the first user argument, not the program name,
                // so the program name is spelled out here.
                System.out.println("usage: java Smo " +
                        "\n-f data_file_name\n" +
                        "-m svm_file_name\n" +
                        "-o output_file_name\n" +
                        "-n N\n" +
                        "-d d\n" +
                        "-c C\n" +
                        "-t tolerance\n" +
                        "-e epsilon\n" +
                        "-p two_sigma_squared\n" +
                        "-l (is_linear_kernel)\n" +
                        "-s (is_sparse_data)\n" +
                        "-b (is_binary)\n" +
                        "-a (is_test_only)\n" +
                        "-i (is_libsvm_file)\n");
                System.exit(2);
            }
        }

        // ---- load the saved model (test-only mode) and the data file ----
        {
            int n = 0;
            if (my.is_test_only)
            {
                try
                {
                    DataInputStream svm_file = new DataInputStream(new FileInputStream(svm_file_name));
                    try
                    {
                        my.end_support_i = my.first_test_i = n = my.read_svm(svm_file);
                    }
                    finally
                    {
                        svm_file.close();
                    }
                }
                catch (Exception e)
                {
                    e.printStackTrace();
                }
            }
            if (my.N > 0)
            {
                my.reserve(my.target, my.N, 1);
                if (my.is_sparse_data && my.is_binary)
                    my.reserveSparseBinary(my.sparse_binary_points, my.N);
                else if (my.is_sparse_data && !my.is_binary)
                    my.reserveSparse(my.sparse_points, my.N);
                else
                    my.reserve(my.dense_points, my.N);
            }
            System.out.println(data_file_name);
            DataInputStream data_file = new DataInputStream(new FileInputStream(data_file_name));
            try
            {
                n = my.read_data(data_file);
            }
            finally
            {
                data_file.close();
            }
            if (my.is_test_only)
            {
                // Test examples follow the support vectors loaded from the model.
                my.N = my.first_test_i + n;
            }
            else
            {
                my.N = n;
                my.first_test_i = 0;
                my.end_support_i = my.N;
            }
        }

        if (!my.is_test_only)
        {
            my.resize(my.alph, my.end_support_i, 2);
            my.b = 0;
            my.resize(my.error_cache, my.N, 2);
            if (my.is_linear_kernel)
                my.resize(my.w, my.d, 2);
        }

        // ---- select implementations for the chosen kernel and storage format ----
        if (my.is_linear_kernel && my.is_sparse_data && my.is_binary)
            my.learned_func_flag = 1;
        if (my.is_linear_kernel && my.is_sparse_data && !my.is_binary)
            my.learned_func_flag = 2;
        if (my.is_linear_kernel && !my.is_sparse_data)
            my.learned_func_flag = 3;
        if (!my.is_linear_kernel)
            my.learned_func_flag = 4;
        if (my.is_sparse_data && my.is_binary)
            my.dot_product_flag = 1;
        if (my.is_sparse_data && !my.is_binary)
            my.dot_product_flag = 2;
        if (!my.is_sparse_data)
            my.dot_product_flag = 3;
        if (my.is_linear_kernel)
            my.kernel_flag = 1;
        if (!my.is_linear_kernel)
            my.kernel_flag = 2;

        // ---- nonlinear kernel: cache all pairwise dot products ----
        if (!my.is_linear_kernel)
        {
            my.resize(my.precomputed_self_dot_product, my.N, 2);
            for (int i = 0; i < my.N; i++)
            {
                for (int j = 0; j < my.N; j++)
                {
                    if (i != j)
                    {
                        my.precomputed_dot_product[i][j] = my.dot_product_func(i, j, my.dot_product_flag);
                    }
                    else
                    {
                        float temp = my.dot_product_func(i, i, my.dot_product_flag);
                        my.precomputed_self_dot_product.set(i, new Float(temp));
                        my.precomputed_dot_product[i][i] = temp;
                    }
                }
            }
        }

        // ---- training ----
        if (!my.is_test_only)
        {
            numChanged = 0;
            examineAll = 1;
            // Alternate between full sweeps and sweeps over non-bound alphas (0 < alpha < C)
            // until a full sweep changes nothing.
            while (numChanged > 0 || examineAll > 0)
            {
                numChanged = 0;
                if (examineAll > 0)
                {
                    for (int k = 0; k < my.N; k++)
                        numChanged += my.examineExample(k);
                }
                else
                {
                    for (int k = 0; k < my.N; k++)
                    {
                        if (my.object2float(my.alph.elementAt(k)) != 0 && my.object2float(my.alph.elementAt(k)) != my.C)
                            numChanged += my.examineExample(k);
                    }
                }
                if (examineAll == 1)
                    examineAll = 0;
                else if (numChanged == 0)
                    examineAll = 1;

                int non_bound_support = 0;
                int bound_support = 0;
                for (int i = 0; i < my.N; i++)
                {
                    if (my.object2float(my.alph.elementAt(i)) > 0)
                    {
                        if (my.object2float(my.alph.elementAt(i)) < my.C)
                            non_bound_support++;
                        else
                            bound_support++;
                    }
                }
                System.out.println("non_bound= " + non_bound_support + "\t" + "bound_support= " + bound_support);
            }

            if (svm_file_name != null)
            {
                try
                {
                    PrintStream svm_file = new PrintStream(new FileOutputStream(svm_file_name));
                    try
                    {
                        my.write_svm(svm_file);
                    }
                    finally
                    {
                        svm_file.close();
                    }
                }
                catch (Exception e)
                {
                    e.printStackTrace();
                }
            }
            System.out.println("Threshold=" + my.b);
        }

        // ---- evaluation and output ----
        System.out.println("Error_rate=" + my.error_rate());
        newTime = System.currentTimeMillis();
        try
        {
            PrintStream output_file = new PrintStream(new FileOutputStream(output_file_name));
            try
            {
                for (int i = my.first_test_i; i < my.N; i++)
                    output_file.println(my.learned_func(i, my.learned_func_flag));
            }
            finally
            {
                output_file.close();
            }
        }
        catch (Exception e)
        {
            e.printStackTrace();
        }
        System.out.println("Time cost = " + (newTime - time) * 1.0 / 1000);
    }
    catch (Exception e)
    {
        e.printStackTrace();
    }
}
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -