?? svm.java
字號:
package edu.udo.cs.mySVMdb.SVM;import edu.udo.cs.mySVMdb.Optimizer.*;import edu.udo.cs.mySVMdb.Container.*;import edu.udo.cs.mySVMdb.Kernel.*;import edu.udo.cs.mySVMdb.Util.*;// import java.util.BitSet;import java.lang.Integer;import java.lang.Double;import java.util.Random;import java.lang.Math;public abstract class SVM{ /** * Abstract base class for all SVMs * @author Stefan R黳ing * @version 1.0 */ protected Kernel the_kernel; protected JDBCDatabaseContainer the_container; protected int examples_total; protected int verbosity; protected int working_set_size; protected int parameters_working_set_size; // wss set in parameters protected int target_count; protected double convergence_epsilon; protected double is_zero; protected int shrink_const; protected double lambda_factor; protected int[] at_bound; protected double[] sum; protected boolean[] which_alpha; protected int[] working_set; protected double[] primal; protected double Cpos; protected double Cneg; protected double sum_alpha; protected double lambda_eq; protected double epsilon_pos; protected double epsilon_neg; protected int to_shrink; protected double feasible_epsilon; protected double lambda_WS; protected boolean quadraticLossPos; protected boolean quadraticLossNeg; protected double descend; boolean shrinked; MinHeap heap_min; MaxHeap heap_max; protected quadraticProblem qp; /** * class constructor */ public SVM() { }; /** * Init the SVM * @param Kernel new kernel function. * @param JDBCDatabaseContainer the data container * @exception Exception on any error */ public void init(Kernel new_kernel, JDBCDatabaseContainer new_container) { String dummy; the_kernel = new_kernel; the_container = new_container; examples_total = the_container.count_examples(); try{ verbosity = (new Integer(the_container.get_param("verbosity"))).intValue(); } catch(Exception e){ verbosity = 3; }; try{ working_set_size = (new Integer(the_container.get_param("working_set_size"))).intValue(); if(working_set_size < 2){ working_set_size = 2; }; } catch(Exception e){ working_set_size = 10; // !!! has to be identical to JDBCDatabaseContainer::prepareKisStatement }; parameters_working_set_size = working_set_size; try{ is_zero = (new Double(the_container.get_param("is_zero"))).doubleValue(); if(is_zero <= 0){ is_zero = 1e-10; }; } catch(Exception e){ is_zero = 1e-10; }; try{ convergence_epsilon = (new Double(the_container.get_param("convergence_epsilon"))).doubleValue(); if(convergence_epsilon <= 0){ convergence_epsilon = 1e-3; }; } catch(Exception e){ convergence_epsilon = 1e-3; }; try{ shrink_const = (new Integer(the_container.get_param("shrink_const"))).intValue(); if(shrink_const <= 0){ shrink_const = 50; }; } catch(Exception e){ shrink_const = 50; }; try{ descend = (new Double(the_container.get_param("descend"))).doubleValue(); if(descend < 0){ descend = 1e-15; }; } catch(Exception e){ descend = 1e-15; }; quadraticLossPos = false; try{ dummy = the_container.get_param("quadraticLossPos"); if(dummy.equals("true")){ quadraticLossPos = true; }; } catch(Exception e){ }; quadraticLossNeg = false; try{ dummy = the_container.get_param("quadraticLossNeg"); if(dummy.equals("true")){ quadraticLossNeg = true; }; } catch(Exception e){ }; try{ Cpos = (new Double(the_container.get_param("C"))).doubleValue(); Cneg = Cpos; } catch(Exception e){ Cpos = 1.0; Cneg = 1.0; }; try{ Cpos = (new Double(the_container.get_param("Cpos"))).doubleValue(); } catch(Exception e){ }; try{ Cneg = (new Double(the_container.get_param("Cneg"))).doubleValue(); } catch(Exception e){ }; // better in subclass: try{ epsilon_pos = (new Double(the_container.get_param("epsilon"))).doubleValue(); epsilon_neg = epsilon_pos; } catch(Exception e){ epsilon_pos = 0.0; epsilon_neg = 0.0; }; try{ epsilon_pos = (new Double(the_container.get_param("epsilon_pos"))).doubleValue(); } catch(Exception e){ }; try{ epsilon_neg = (new Double(the_container.get_param("epsilon_neg"))).doubleValue(); } catch(Exception e){ }; try{ dummy = the_container.get_param("balance_cost"); if(dummy.equals("true")){ Cpos *= ((double)the_container.count_pos_examples())/((double)the_container.count_examples()); Cneg *= ((double)(the_container.count_examples()-the_container.count_pos_examples()))/((double)the_container.count_examples()); }; } catch(Exception e){ }; // System.out.println("Cpos "+Cpos); // System.out.println("Cneg "+Cneg); // System.out.println("epsilon_pos "+epsilon_pos); // System.out.println("epsilon_neg "+epsilon_neg); lambda_factor = 1.0; lambda_eq=0; target_count=0; sum_alpha = 0; feasible_epsilon = convergence_epsilon; at_bound = new int[examples_total]; sum = new double[examples_total]; which_alpha = new boolean[examples_total]; primal = new double[working_set_size]; }; /** * Train the SVM * @exception Exception on any error */ public void train() throws Exception { target_count = 0; shrinked = false; init_optimizer(); init_working_set(); int iteration = 0; int max_iterations; try{ max_iterations = (new Integer(the_container.get_param("max_iterations"))).intValue(); } catch(Exception e){ max_iterations=30000; }; boolean converged=false; //long time_train_loop = System.currentTimeMillis(); //long time_dummy = 0; //long time_resetshrink = 0; M:while(iteration < max_iterations){ iteration++; logln(4,"optimizer iteration "+iteration); log(4,"."); optimize(); put_optimizer_values(); converged = convergence(); if(converged){ logln(4,""); // dots project_to_constraint(); if(shrinked){ // check convergence for all alphas logln(2,"***** Checking convergence for all variables"); // time_resetshrink -= System.currentTimeMillis(); reset_shrinked(); // time_resetshrink += System.currentTimeMillis(); converged = convergence(); }; if(converged){ logln(1,"*** Convergence"); break M; }; // set variables free again shrink_const += 10; target_count = 0; for(int i=0;i<examples_total;i++){ at_bound[i]=0; }; }; shrink(); calculate_working_set(); update_working_set(); }; //time_train_loop = (System.currentTimeMillis() - time_train_loop)/1000; int i; if((iteration >= max_iterations) && (! converged)){ logln(1,"*** No convergence: Time up."); if(shrinked){ // set sums for all variables for statistics //time_resetshrink -= System.currentTimeMillis(); reset_shrinked(); //time_resetshrink += System.currentTimeMillis(); }; }; // calculate b double new_b=0; int new_b_count=0; double[] my_sum = sum; double[] my_y = the_container.get_ys(); double[] my_alphas = the_container.get_alphas(); for(i=0;i<examples_total;i++){ if((my_alphas[i]-Cneg < -is_zero) && (my_alphas[i] > is_zero)){ new_b += my_y[i] - my_sum[i]-epsilon_neg; new_b_count++; } else if((my_alphas[i]+Cpos > is_zero) && (my_alphas[i] < -is_zero)){ new_b += my_y[i] - my_sum[i]+epsilon_pos; new_b_count++; }; }; if(new_b_count>0){ the_container.set_b(new_b/((double)new_b_count)); } else{ // unlikely for(i=0;i<examples_total;i++){ if((my_alphas[i]<is_zero) && (my_alphas[i]>-is_zero)) { new_b += my_y[i] - my_sum[i]; new_b_count++; }; }; if(new_b_count>0){ the_container.set_b(new_b/((double)new_b_count)); } else{ // even unlikelier for(i=0;i<examples_total;i++){ new_b += my_y[i] - my_sum[i]; new_b_count++; }; the_container.set_b(new_b/((double)new_b_count)); }; }; if(verbosity>= 2){ logln(2,"Done training: "+iteration+" iterations."); if(verbosity>= 3){ double now_target=0; double now_target_dummy=0; for(i=0;i<examples_total;i++){ now_target_dummy=sum[i]/2-the_container.get_y(i); if(is_alpha_neg(i)){ now_target_dummy+= epsilon_pos; } else{ now_target_dummy-= epsilon_neg; }; now_target+=the_container.get_alpha(i)*now_target_dummy; }; logln(3,"Target function: "+now_target); }; }; print_statistics(); exit_optimizer(); // System.out.println("Time in resetshrink: "+(time_resetshrink/1000)+"s"); // System.out.println("Time in train loop: "+time_train_loop+"s"); }; /** * print statistics about result */ protected void print_statistics() throws Exception { int dim = the_container.get_dim(); int i,j; double alpha; double[] x; int svs=0; int bsv = 0; double mae=0; double mse = 0; int countpos = 0; int countneg = 0; double y; double prediction; double min_lambda = Double.MAX_VALUE; double b = the_container.get_b(); for(i=0;i<examples_total;i++){ if(lambda(i) < min_lambda){ min_lambda = lambda(i); }; y = the_container.get_y(i); prediction = sum[i]+b; mae += Math.abs(y-prediction); mse += (y-prediction)*(y-prediction); alpha = the_container.get_alpha(i); if(y < prediction-epsilon_pos){ countpos++; } else if(y > prediction+epsilon_neg){ countneg++; }; if(alpha != 0){ svs++; if((alpha == Cpos) || (alpha == -Cneg)){ bsv++; }; }; }; mae /= (double)examples_total; mse /= (double)examples_total; min_lambda = -min_lambda; logln(1,"Error on KKT is "+min_lambda); logln(1,svs+" SVs"); logln(1,bsv+" BSVs"); logln(1,"MAE "+mae); logln(1,"MSE "+mse); logln(1,countpos+" pos loss"); logln(1,countneg+" neg loss"); if(verbosity >= 2){ // print hyperplane double[] w = new double[dim]; for(j=0;j<dim;j++) w[j] = 0; for(i=0;i<examples_total;i++){ x = the_container.get_example(i); alpha = the_container.get_alpha(i); for(j=0;j<dim;j++){ w[j] += alpha*x[j]; }; }; double[] Exp = the_container.Exp; double[] Dev = the_container.Dev; if(Exp != null){ for(j=0;j<dim;j++){ if(Dev[j] != 0){ w[j] /= Dev[j]; }; if(0 != Dev[dim]){ w[j] *= Dev[dim]; }; b -= w[j]*Exp[j]; }; b += Exp[dim]; }; logln(2," "); for(j=0;j<dim;j++){ logln(2,"w["+j+"] = "+w[j]); }; logln(2,"b = "+b);
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -