svm_learn.c
/***********************************************************************/
/*                                                                     */
/*   svm_learn.c                                                       */
/*                                                                     */
/*   Learning module of Support Vector Machine.                        */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 31.10.00                                                    */
/*                                                                     */
/*   Copyright (c) 2000  Universitaet Dortmund - All rights reserved   */
/*                                                                     */
/*   This software is available for non-commercial use only. It must   */
/*   not be modified and distributed without prior permission of the   */
/*   author. The author is not responsible for implications from the   */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

# include "svm_common.h"
# include "svm_learn.h"

/* interface to QP-solver */
double *optimize_qp(QP *, double *, long, double *, LEARN_PARM *);

/*---------------------------------------------------------------------------*/

/* Learns an SVM classification model based on the training data in
   docs/label. The resulting model is returned in the structure model. */

void svm_learn_classification(DOC *docs, double *class, long int totdoc,
                              long int totwords, LEARN_PARM *learn_parm,
                              KERNEL_PARM *kernel_parm,
                              KERNEL_CACHE *kernel_cache, MODEL *model)
     /* docs:         Training vectors (x-part) */
     /* class:        Training labels (y-part, zero if test example for
                      transduction) */
     /* totdoc:       Number of examples in docs/label */
     /* totwords:     Number of features (i.e. highest feature index) */
     /* learn_parm:   Learning parameters */
     /* kernel_parm:  Kernel parameters */
     /* kernel_cache: Initialized cache of size 1*totdoc */
     /* model:        Returns learning result (assumed empty before called) */
{
  long *inconsistent,i,*label;
  long inconsistentnum;
  long misclassified,upsupvecnum;
  double loss,model_length,example_length;
  double maxdiff,*lin,*a,*c;
  long runtime_start,runtime_end;
  long iterations;
  long *unlabeled,transduction;
  long heldout;
  long loo_count=0,loo_count_pos=0,loo_count_neg=0,trainpos=0,trainneg=0;
  long loocomputed=0,runtime_start_loo=0,runtime_start_xa=0;
  double heldout_c=0,r_delta_sq=0,r_delta,r_delta_avg;
  double *xi_fullset;  /* buffer for storing xi on full sample in loo */
  double *a_fullset;   /* buffer for storing alpha on full sample in loo */
  TIMING timing_profile;
  SHRINK_STATE shrink_state;

  runtime_start=get_runtime();
  timing_profile.time_kernel=0;
  timing_profile.time_opti=0;
  timing_profile.time_shrink=0;
  timing_profile.time_update=0;
  timing_profile.time_model=0;
  timing_profile.time_check=0;
  timing_profile.time_select=0;
  kernel_cache_statistic=0;

  learn_parm->totwords=totwords;

  /* make sure -n value is reasonable */
  if((learn_parm->svm_newvarsinqp < 2)
     || (learn_parm->svm_newvarsinqp > learn_parm->svm_maxqpsize)) {
    learn_parm->svm_newvarsinqp=learn_parm->svm_maxqpsize;
  }

  init_shrink_state(&shrink_state,totdoc,(long)20000);

  label = (long *)my_malloc(sizeof(long)*totdoc);
  inconsistent = (long *)my_malloc(sizeof(long)*totdoc);
  unlabeled = (long *)my_malloc(sizeof(long)*totdoc);
  c = (double *)my_malloc(sizeof(double)*totdoc);
  a = (double *)my_malloc(sizeof(double)*totdoc);
  a_fullset = (double *)my_malloc(sizeof(double)*totdoc);
  xi_fullset = (double *)my_malloc(sizeof(double)*totdoc);
  lin = (double *)my_malloc(sizeof(double)*totdoc);
  learn_parm->svm_cost = (double *)my_malloc(sizeof(double)*totdoc);
  model->supvec = (DOC **)my_malloc(sizeof(DOC *)*(totdoc+2));
  model->alpha = (double *)my_malloc(sizeof(double)*(totdoc+2));
  model->index = (long *)my_malloc(sizeof(long)*(totdoc+2));

  model->at_upper_bound=0;
  model->b=0;
  model->supvec[0]=0;  /* element 0 reserved and empty for now */
  model->alpha[0]=0;
  model->lin_weights=NULL;
  model->totwords=totwords;
  model->totdoc=totdoc;
  model->kernel_parm=(*kernel_parm);
  model->sv_num=1;
  model->loo_error=-1;
  model->loo_recall=-1;
  model->loo_precision=-1;
  model->xa_error=-1;
  model->xa_recall=-1;
  model->xa_precision=-1;
  inconsistentnum=0;
  transduction=0;

  r_delta=estimate_r_delta(docs,totdoc,kernel_parm);
  r_delta_sq=r_delta*r_delta;

  r_delta_avg=estimate_r_delta_average(docs,totdoc,kernel_parm);
  if(learn_parm->svm_c == 0.0) {  /* default value for C */
    learn_parm->svm_c=1.0/(r_delta_avg*r_delta_avg);
    if(verbosity>=1)
      printf("Setting default regularization parameter C=%.4f\n",
             learn_parm->svm_c);
  }

  learn_parm->eps=-1.0;  /* equivalent regression epsilon for classification */

  for(i=0;i<totdoc;i++) {  /* various inits */
    docs[i].docnum=i;
    inconsistent[i]=0;
    a[i]=0;
    lin[i]=0;
    c[i]=0.0;
    unlabeled[i]=0;
    if(class[i] == 0) {
      unlabeled[i]=1;
      transduction=1;
    }
    if(class[i] > 0) {
      learn_parm->svm_cost[i]=learn_parm->svm_c*learn_parm->svm_costratio*
        fabs(class[i]);
      label[i]=1;
      trainpos++;
    }
    else if(class[i] < 0) {
      learn_parm->svm_cost[i]=learn_parm->svm_c*fabs(class[i]);
      label[i]=-1;
      trainneg++;
    }
    else {
      learn_parm->svm_cost[i]=0;
    }
  }

  /* caching makes no sense for linear kernel */
  if(kernel_parm->kernel_type == LINEAR) {
    kernel_cache = NULL;
  }

  if(transduction) {
    learn_parm->svm_iter_to_shrink=99999999;
    if(verbosity >= 1)
      printf("\nDeactivating Shrinking due to an incompatibility with the transductive \nlearner in the current version.\n\n");
  }
  if(transduction && learn_parm->compute_loo) {
    learn_parm->compute_loo=0;
    if(verbosity >= 1)
      printf("\nCannot compute leave-one-out estimates for transductive learner.\n\n");
  }
  if(learn_parm->remove_inconsistent && learn_parm->compute_loo) {
    learn_parm->compute_loo=0;
    printf("\nCannot compute leave-one-out estimates when removing inconsistent examples.\n\n");
  }
  if(learn_parm->compute_loo && ((trainpos == 1) || (trainneg == 1))) {
    learn_parm->compute_loo=0;
    printf("\nCannot compute leave-one-out with only one example in one class.\n\n");
  }

  if(verbosity==1) {
    printf("Optimizing");
    fflush(stdout);
  }

  /* train the svm */
  iterations=optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                                     kernel_parm,kernel_cache,&shrink_state,
                                     model,inconsistent,unlabeled,a,lin,c,
                                     &timing_profile,&maxdiff,(long)-1,
                                     (long)1);
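
  /* After training, lin[i] holds the uncentered decision value
     sum_j alpha_j*y_j*K(x_j,x_i) for every training example, so example i
     is classified by the sign of (lin[i]-model->b); the statistics below
     rely on lin[] being up to date. */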
  if(verbosity>=1) {
    if(verbosity==1) printf("done. (%ld iterations)\n",iterations);

    misclassified=0;
    for(i=0;(i<totdoc);i++) {  /* get final statistic */
      if((lin[i]-model->b)*(double)label[i] <= 0.0)
        misclassified++;
    }
    printf("Optimization finished (%ld misclassified, maxdiff=%.5f).\n",
           misclassified,maxdiff);

    runtime_end=get_runtime();
    if(verbosity>=2) {
      printf("Runtime in cpu-seconds: %.2f (%.2f%% for kernel/%.2f%% for optimizer/%.2f%% for final/%.2f%% for update/%.2f%% for model/%.2f%% for check/%.2f%% for select)\n",
        ((float)runtime_end-(float)runtime_start)/100.0,
        (100.0*timing_profile.time_kernel)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_opti)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_shrink)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_update)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_model)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_check)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_select)/(float)(runtime_end-runtime_start));
    }
    else {
      printf("Runtime in cpu-seconds: %.2f\n",
             (runtime_end-runtime_start)/100.0);
    }

    if(learn_parm->remove_inconsistent) {
      inconsistentnum=0;
      for(i=0;i<totdoc;i++)
        if(inconsistent[i])
          inconsistentnum++;
      printf("Number of SV: %ld (plus %ld inconsistent examples)\n",
             model->sv_num-1,inconsistentnum);
    }
    else {
      upsupvecnum=0;
      for(i=1;i<model->sv_num;i++) {
        if(fabs(model->alpha[i]) >=
           (learn_parm->svm_cost[(model->supvec[i])->docnum]-
            learn_parm->epsilon_a))
          upsupvecnum++;
      }
      printf("Number of SV: %ld (including %ld at upper bound)\n",
             model->sv_num-1,upsupvecnum);
    }

    if((verbosity>=1) && (!learn_parm->skip_final_opt_check)) {
      loss=0;
      model_length=0;
      for(i=0;i<totdoc;i++) {
        if((lin[i]-model->b)*(double)label[i] < 1.0-learn_parm->epsilon_crit)
          loss+=1.0-(lin[i]-model->b)*(double)label[i];
        model_length+=a[i]*label[i]*lin[i];
      }
      model_length=sqrt(model_length);
      fprintf(stdout,"L1 loss: loss=%.5f\n",loss);
      fprintf(stdout,"Norm of weight vector: |w|=%.5f\n",model_length);
      example_length=estimate_sphere(model,kernel_parm);
      fprintf(stdout,"Norm of longest example vector: |x|=%.5f\n",
              length_of_longest_document_vector(docs,totdoc,kernel_parm));
      fprintf(stdout,"Estimated VCdim of classifier: VCdim<=%.5f\n",
              estimate_margin_vcdim(model,model_length,example_length,
                                    kernel_parm));
      if((!learn_parm->remove_inconsistent) && (!transduction)) {
        runtime_start_xa=get_runtime();
        if(verbosity>=1) {
          printf("Computing XiAlpha-estimates...");
          fflush(stdout);
        }
        compute_xa_estimates(model,label,unlabeled,totdoc,docs,lin,a,
                             kernel_parm,learn_parm,&(model->xa_error),
                             &(model->xa_recall),&(model->xa_precision));
        if(verbosity>=1) {
          printf("done\n");
        }
        printf("Runtime for XiAlpha-estimates in cpu-seconds: %.2f\n",
               (get_runtime()-runtime_start_xa)/100.0);

        fprintf(stdout,"XiAlpha-estimate of the error: error<=%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_error,learn_parm->rho,learn_parm->xa_depth);
        fprintf(stdout,"XiAlpha-estimate of the recall: recall=>%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_recall,learn_parm->rho,learn_parm->xa_depth);
        fprintf(stdout,"XiAlpha-estimate of the precision: precision=>%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_precision,learn_parm->rho,learn_parm->xa_depth);
      }
      else if(!learn_parm->remove_inconsistent) {
        estimate_transduction_quality(model,label,unlabeled,totdoc,docs,lin);
      }
    }
    if(verbosity>=1) {
      printf("Number of kernel evaluations: %ld\n",kernel_cache_statistic);
    }
  }

  /* leave-one-out testing starts now */
  if(learn_parm->compute_loo) {
    /* save results of training on full dataset for leave-one-out */
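    /* The xi_i and alpha_i from the full training run are saved below and
       used to prune the leave-one-out loop:
         - if rho*alpha_i*R^2 + xi_i < 1, example i provably cannot produce
           a leave-one-out error (printed as "+");
         - if xi_i > 1, it provably does produce one (printed as "-");
         - only the remaining examples require an actual retraining with
           the held-out example's upper bound set to zero. */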
    runtime_start_loo=get_runtime();
    for(i=0;i<totdoc;i++) {
      xi_fullset[i]=1.0-((lin[i]-model->b)*(double)label[i]);
      if(xi_fullset[i]<0) xi_fullset[i]=0;
      a_fullset[i]=a[i];
    }
    if(verbosity>=1) {
      printf("Computing leave-one-out");
    }

    /* repeat this loop for every held-out example */
    for(heldout=0;(heldout<totdoc);heldout++) {
      if(learn_parm->rho*a_fullset[heldout]*r_delta_sq
         +xi_fullset[heldout] < 1.0) {
        /* guaranteed to not produce a leave-one-out error */
        if(verbosity==1) {
          printf("+");
          fflush(stdout);
        }
      }
      else if(xi_fullset[heldout] > 1.0) {
        /* guaranteed to produce a leave-one-out error */
        loo_count++;
        if(label[heldout] > 0) loo_count_pos++; else loo_count_neg++;
        if(verbosity==1) {
          printf("-");
          fflush(stdout);
        }
      }
      else {
        loocomputed++;
        heldout_c=learn_parm->svm_cost[heldout];  /* set upper bound to zero */
        learn_parm->svm_cost[heldout]=0;
        /* make sure heldout example is not currently  */
        /* shrunk away. Assumes that lin is up to date! */
        shrink_state.active[heldout]=1;
        if(verbosity>=2)
          printf("\nLeave-One-Out test on example %ld\n",heldout);
        if(verbosity>=1) {
          printf("(?[%ld]",heldout);
          fflush(stdout);
        }

        optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                                kernel_parm,kernel_cache,&shrink_state,
                                model,inconsistent,unlabeled,a,lin,c,
                                &timing_profile,&maxdiff,heldout,(long)2);

        /* printf("%.20f\n",(lin[heldout]-model->b)*(double)label[heldout]); */

        if(((lin[heldout]-model->b)*(double)label[heldout]) <= 0.0) {
          loo_count++;  /* there was a loo-error */
          if(label[heldout] > 0) loo_count_pos++; else loo_count_neg++;
          if(verbosity>=1) {
            printf("-)");
            fflush(stdout);
          }
        }
        else {
          if(verbosity>=1) {
            printf("+)");
            fflush(stdout);
          }
        }

        /* now we need to restore the original data set */
        learn_parm->svm_cost[heldout]=heldout_c;  /* restore upper bound */
      }
    }  /* end of leave-one-out loop */

    if(verbosity>=1) {
      printf("\nRetrain on full problem");
      fflush(stdout);
    }
    optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                            kernel_parm,kernel_cache,&shrink_state,
                            model,inconsistent,unlabeled,a,lin,c,
                            &timing_profile,&maxdiff,(long)-1,(long)1);
    if(verbosity >= 1)
      printf("done.\n");

    /* after all leave-one-out computed */
    model->loo_error=100.0*loo_count/(double)totdoc;
    model->loo_recall=(1.0-(double)loo_count_pos/(double)trainpos)*100.0;
    model->loo_precision=(trainpos-loo_count_pos)/
      (double)(trainpos-loo_count_pos+loo_count_neg)*100.0;
    if(verbosity >= 1) {
      fprintf(stdout,"Leave-one-out estimate of the error: error=%.2f%%\n",
              model->loo_error);
      fprintf(stdout,"Leave-one-out estimate of the recall: recall=%.2f%%\n",
              model->loo_recall);
      fprintf(stdout,"Leave-one-out estimate of the precision: precision=%.2f%%\n",
              model->loo_precision);
      fprintf(stdout,"Actual leave-one-outs computed: %ld (rho=%.2f)\n",
              loocomputed,learn_parm->rho);
      printf("Runtime for leave-one-out in cpu-seconds: %.2f\n",
             (double)(get_runtime()-runtime_start_loo)/100.0);
    }
  }

  if(learn_parm->alphafile[0])
    write_alphas(learn_parm->alphafile,a,label,totdoc);

  shrink_state_cleanup(&shrink_state);
  free(label);
  free(inconsistent);
  free(unlabeled);
  free(c);
  free(a);
  free(a_fullset);
  free(xi_fullset);
  free(lin);
  free(learn_parm->svm_cost);
}
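
For reference, the following is a minimal, hypothetical calling sketch. Only the svm_learn_classification signature, the label conventions, and the model fields are taken from the listing above; the driver function itself (train_example) is illustrative and not part of SVM-light, and it assumes that the documents, labels, and parameter structures have been loaded and initialized elsewhere.

/* Hypothetical driver, not part of SVM-light. Assumes docs/labels and the
   parameter structs were loaded and filled in elsewhere. */
#include "svm_common.h"
#include "svm_learn.h"

void train_example(DOC *docs, double *labels, long totdoc, long totwords,
                   LEARN_PARM *learn_parm, KERNEL_PARM *kernel_parm,
                   KERNEL_CACHE *kernel_cache)
{
  MODEL model;  /* svm_learn_classification assumes an empty model */

  /* Label convention from the init loop above: >0 positive, <0 negative,
     0 marks an unlabeled example and enables transduction. Passing
     learn_parm->svm_c == 0.0 makes the learner pick its default C. */
  svm_learn_classification(docs, labels, totdoc, totwords,
                           learn_parm, kernel_parm, kernel_cache, &model);

  /* supvec[0]/alpha[0] are reserved, so the model holds sv_num-1
     support vectors. */
  printf("Trained SVM: %ld support vectors, b=%.5f\n",
         model.sv_num - 1, model.b);
}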