// learn.cpp
/**
 * Runs a single training pass using the global configuration objects.
 *
 * When parameters->cross_validation is positive the whole job is handed to
 * do_cv(). Otherwise the kernel and the SVM are initialised, a one-line
 * banner with the active nu or C value(s) is printed, and the SVM is trained
 * on training_set.
 *
 * @return the result of do_cv() or of svm->train(training_set).
 */
svm_result train(){
  // Cross-validation is handled completely by do_cv().
  if(parameters->cross_validation > 0){
    return do_cv();
  }

  kernel->init(parameters->kernel_cache,training_set);
  svm->init(kernel,parameters);

  // Banner: nu-based machines report nu, everything else reports C
  // (a single value when both class-wise C values coincide, a pair otherwise).
  const bool nu_based = parameters->is_nu || parameters->is_distribution;
  if(nu_based){
    cout<<"Training started with nu = "<<parameters->nu<<"."<<endl;
  }
  else if(parameters->get_Cpos() == parameters->get_Cneg()){
    cout<<"Training started with C = "<<parameters->get_Cpos()<<"."<<endl;
  }
  else{
    cout<<"Training started with C = ("<<parameters->get_Cpos()
        <<","<<parameters->get_Cneg()<<")."<<endl;
  }

  return svm->train(training_set);
}
/**
 * Picks the scalar that calc_c() uses as its objective.
 *
 * @param result  outcome of one train() call.
 * @return result.pred_accuracy when no cross-validation is configured and
 *         parameters->is_pattern equals 1; result.loss in every other case.
 */
inline
SVMFLOAT to_minimize(svm_result result){
  const bool pattern_without_cv =
    (parameters->cross_validation <= 0) && (1 == parameters->is_pattern);
  if(!pattern_without_cv){
    return result.loss;
  }
  // NOTE(review): the caller minimises this value; minimising something
  // named "pred_accuracy" looks inverted -- confirm what the field holds.
  return result.pred_accuracy;
}
/**
 * Searches for the value of C that minimises to_minimize(train()) and stores
 * it in parameters->realC.
 *
 * Two strategies are selected by parameters->search_c:
 *   'a' / 'm' : scan from parameters->c_min up to parameters->c_max, stepping
 *               additively ('a') or multiplicatively ('m') by
 *               parameters->c_delta. If parameters->search_stop is positive,
 *               the scan ends once that many consecutive steps brought no
 *               improvement.
 *   otherwise : golden-section search on [c_min, c_max], repeated while
 *               c_max - c_min > c_delta * c_min.
 *
 * Between consecutive trainings the alphas of training_set are rescaled by
 * the ratio of the new to the previous C instead of being cleared.
 * parameters->verbosity is lowered by 2 for the nested train() calls and
 * restored before returning.
 *
 * @return for 'a'/'m' the result belonging to the best C found (a
 *         value-initialised svm_result if no candidate was accepted); for
 *         the golden-section search the result of the LAST training, which
 *         need not be the one for the reported C.
 */
svm_result calc_c(){
  const SVMFLOAT lambda = 0.618033989; // (sqrt(5)-1)/2

  // Keep the caller's verbosity for our own output; nested calls run quieter.
  const SVMINT verbosity = parameters->verbosity;
  parameters->verbosity -= 2;

  // Value-initialised so that nothing indeterminate is printed or returned
  // when the search loop below never accepts (or never evaluates) a candidate.
  svm_result the_result = svm_result();

  SVMFLOAT c_min = parameters->c_min;
  SVMFLOAT c_max = parameters->c_max;
  const SVMFLOAT c_delta = parameters->c_delta;

  if(verbosity >= 3){
    cout<<"starting search for C"<<endl;
  }

  if((parameters->search_c == 'a') || (parameters->search_c == 'm')){
    // ---- additive / multiplicative scan over [c_min, c_max] ----
    // NOTE(review): the scan only advances if c_delta > 0 ('a') or
    // c_delta > 1 with c_min > 0 ('m'); presumably validated by the caller.
    SVMFLOAT minimal_value = infinity;
    SVMFLOAT minimal_C = c_min;
    svm_result minimal_result = svm_result();
    SVMINT last_dec = 0; // steps since the objective last decreased
    SVMFLOAT oldC = c_min;

    training_set->clear_alpha();
    while(c_min <= c_max){
      if(verbosity >= 3){
        cout<<"C = "<<c_min<<" :"<<endl;
      }
      parameters->realC = c_min;
      // Re-use the previous solution: rescale the alphas to the new C.
      training_set->scale_alphas(c_min/oldC);
      oldC = c_min;
      if(verbosity >= 4){
        cout<<"C = "<<c_min<<endl;
      }

      the_result = train();
      if(verbosity >= 3){
        cout<<"loss = "<<the_result.loss<<endl;
        if(parameters->is_pattern){
          cout<<"predicted loss = "<<the_result.pred_loss<<endl;
        }
        cout<<"VCdim <= "<<the_result.VCdim<<endl;
      }

      const SVMFLOAT result_value = to_minimize(the_result);
      last_dec++;
      if((result_value < minimal_value) && (! _isnan(result_value))){
        // New best candidate.
        minimal_value = result_value;
        minimal_C = c_min;
        minimal_result = the_result;
        last_dec = 0;
      }

      // Advance to the next candidate.
      if(parameters->search_c == 'a'){
        c_min += c_delta;
      }
      else{
        c_min *= c_delta;
      }

      if((parameters->search_stop > 0) && (last_dec >= parameters->search_stop)){
        // No decrease in loss for search_stop steps: stop. An explicit break
        // ends the scan for every c_max; the former "c_min = 2*c_max"
        // sentinel only did so when c_max was positive.
        break;
      }
    }
    parameters->realC = minimal_C;
    the_result = minimal_result;
  }
  else{
    // ---- method of golden ratio ----
    SVMFLOAT s = lambda*c_min+(1-lambda)*c_max;
    SVMFLOAT t = (1-lambda)*c_min+lambda*c_max;
    SVMFLOAT phi_s;
    SVMFLOAT phi_t;
    SVMFLOAT oldC;

    // Evaluate both interior points once before iterating.
    parameters->realC = s;
    training_set->clear_alpha();
    the_result = train();
    phi_s = to_minimize(the_result);

    parameters->realC = t;
    training_set->scale_alphas(t/s);
    oldC = t;
    the_result = train();
    phi_t = to_minimize(the_result);

    while(c_max - c_min > c_delta*c_min){
      if(verbosity >= 3){
        cout<<"C in ["<<c_min<<","<<c_max<<"]"<<endl;
      }
      if(phi_s < phi_t){
        // Minimum lies in [c_min, t]: old s becomes the new t, compute new s.
        c_max = t;
        t = s;
        phi_t = phi_s;
        s = lambda*c_min+(1-lambda)*c_max;
        parameters->realC = s;
        training_set->scale_alphas(s/oldC);
        oldC = s;
        the_result = train();
        phi_s = to_minimize(the_result);
      }
      else{
        // Minimum lies in [s, c_max]: old t becomes the new s, compute new t.
        c_min = s;
        s = t;
        phi_s = phi_t;
        t = (1-lambda)*c_min+lambda*c_max;
        parameters->realC = t;
        training_set->scale_alphas(t/oldC);
        oldC = t;
        the_result = train();
        phi_t = to_minimize(the_result);
      }
    }

    // Shrink the interval one last time and report its midpoint.
    if(phi_s < phi_t){
      c_max = t;
    }
    else{
      c_min = s;
    }
    parameters->realC = (c_min+c_max)/2;
  }

  // output result
  if(verbosity >= 1){
    cout<<"*** Optimal C is "<<parameters->realC;
    if(parameters->search_c == 'g'){
      cout<<" +/-"<<((c_max-c_min)/2);
    }
    cout<<endl;
  }
  if(verbosity >= 2){
    cout<<"result:"<<endl
        <<"Loss: "<<the_result.loss<<endl;
    if(parameters->Lpos != parameters->Lneg){
      cout<<" Loss+: "<<the_result.loss_pos<<endl;
      cout<<" Loss-: "<<the_result.loss_neg<<endl;
    }
    if(parameters->is_pattern){
      cout<<"predicted Loss: "<<the_result.pred_loss<<endl;
    }
    cout<<"MAE: "<<the_result.MAE<<endl;
    cout<<"MSE: "<<the_result.MSE<<endl;
    cout<<"VCdim <= "<<the_result.VCdim<<endl;
    if(parameters->is_pattern){
      cout<<"Accuracy : "<<the_result.accuracy<<endl
          <<"Precision : "<<the_result.precision<<endl
          <<"Recall : "<<the_result.recall<<endl;
      if(parameters->cross_validation == 0){
        cout<<"predicted Accuracy : "<<the_result.pred_accuracy<<endl
            <<"predicted Precision : "<<the_result.pred_precision<<endl
            <<"predicted Recall : "<<the_result.pred_recall<<endl;
      }
    }
    cout<<"Support Vectors : "<<the_result.number_svs<<endl;
    cout<<"Bounded SVs : "<<the_result.number_bsv<<endl;
    if(parameters->search_c == 'g'){
      cout<<"(WARNING: this is the last result attained and may slightly differ from the result of the optimal C!)"<<endl;
    }
  }

  // Restore the caller's verbosity.
  parameters->verbosity = verbosity;
  return the_result;
}
///////////////////////////////////////////////////////////////
/**
 * Program entry point.
 *
 * 1. Reads the input definitions from STDIN (no arguments, or the argument
 *    "-") and/or from every file named on the command line; "-h", "-help"
 *    or "--help" as first argument prints the command-line help instead.
 * 2. Requires that parameters and a training set were read; falls back to a
 *    kernel_dot_c kernel if none was given.
 * 3. Creates the SVM variant selected by parameters->is_distribution /
 *    is_nu / is_pattern, optionally scales the training set and trains
 *    (after a search for C via calc_c() when parameters->search_c != 'n').
 * 4. Without cross-validation, writes the training set to "<filename>.svm".
 * 5. Runs every test set: sets for which initialised_y() holds are tested,
 *    the others are predicted and written to "<filename>.pred".
 *
 * @return 0 on success; the process exits with status 1 on input errors.
 */
int main(int argc,char* argv[]){
  // Suffixes of the generated files; sizeof() below includes the final NUL.
  const char svm_suffix[] = ".svm";
  const char pred_suffix[] = ".pred";

  cout<<"*** mySVM version "<<mysvmversion<<" ***"<<endl;
  cout.precision(8);

  // read objects
  try{
    if(argc<2){
      cout<<"Reading from STDIN"<<endl;
      // read from cin
      read_input(cin,"mysvm");
    }
    else{
      char* s = argv[1];
      if((0 == strcmp("-h",s)) || (0==strcmp("-help",s)) || (0==strcmp("--help",s))){
        // print out command-line help
        // NOTE(review): unless print_help() terminates the program itself,
        // execution continues into the "no parameters" error below -- confirm.
        print_help();
      }
      else{
        // read in all input files
        for(int i=1;i<argc;i++){
          if(0 == strcmp(argv[i],"-")){
            cout<<"Reading from STDIN"<<endl;
            // read from cin
            read_input(cin,"mysvm");
          }
          else{
            cout<<"Reading "<<argv[i]<<endl;
            ifstream input_file(argv[i]);
            // A failed open sets failbit, not badbit, so bad() would never
            // catch a missing or unreadable file; test the stream state.
            if(!input_file){
              cout<<"ERROR: Could not read file \""<<argv[i]<<"\", exiting."<<endl;
              exit(1);
            }
            read_input(input_file,argv[i]);
            input_file.close();
          }
        }
      }
    }
  }
  catch(general_exception &the_ex){
    cout<<"*** Error while reading input: "<<the_ex.error_msg<<endl;
    exit(1);
  }
  catch(...){
    cout<<"*** Program ended because of unknown error while reading input"<<endl;
    exit(1);
  }

  // Check that everything needed for training has been read.
  if(0 == parameters){
    cout << "*** ERROR: You did not enter the svm parameters"<<endl;
    exit(1);
  }
  if(0 == kernel){
    // No kernel given: use kernel_dot_c.
    kernel = new kernel_dot_c();
  }
  if(0 == training_set){
    cout << "*** ERROR: You did not enter the training set"<<endl;
    exit(1);
  }

  // Instantiate the requested SVM variant.
  if(parameters->is_distribution){
    svm = new svm_distribution_c();
    cout<<"distribution estimation SVM generated"<<endl;
  }
  else if(parameters->is_nu){
    if(parameters->is_pattern){
      svm = new svm_nu_pattern_c();
      cout<<"nu-PSVM generated"<<endl;
    }
    else{
      svm = new svm_nu_regression_c();
      cout<<"nu-RSVM generated"<<endl;
    }
  }
  else if(parameters->is_pattern){
    svm = new svm_pattern_c();
    cout<<"PSVM generated"<<endl;
  }
  else{
    svm = new svm_regression_c();
    cout<<"RSVM generated"<<endl;
  }

  // scale examples
  if(parameters->do_scale){
    training_set->scale();
  }

  // training the svm
  if(parameters->search_c != 'n'){
    // Search for C first, then train once more with the value found.
    calc_c();
    cout<<"re-training without CV and C = "<<parameters->realC<<endl;
    parameters->cross_validation = 0;
    parameters->verbosity -= 1;
    train();
    parameters->verbosity += 1;
  }
  else{
    train();
  }

  if(0 == parameters->cross_validation){
    // save results
    if(parameters->verbosity > 1){
      cout<<"Saving trained SVM to "<<(training_set->get_filename())<<".svm"<<endl;
    }
    char* outname = new char[MAXCHAR];
    // Bounded copy: never write more than MAXCHAR bytes including suffix and
    // terminator (assumes MAXCHAR exceeds the suffix length).
    outname[0] = '\0';
    strncat(outname,training_set->get_filename(),MAXCHAR - sizeof(svm_suffix));
    strcat(outname,svm_suffix);
    // Two-argument form: the protection argument is not part of standard
    // C++ and was the default value in the old library anyway.
    ofstream output_file(outname,ios::out|ios::trunc);
    if(!output_file){
      cout<<"*** ERROR: Could not write file \""<<outname<<"\""<<endl;
    }
    else{
      output_file.precision(16);
      output_file<<*training_set;
    }
    output_file.close();
    delete []outname;
  }

  // testing
  if((parameters->cross_validation > 0) && (0 != test_sets)){
    // test result of cross validation: train new SVM on whole example set
    parameters->cross_validation = 0;
    cout<<"Re-training SVM on whole example set for testing"<<endl;
    train();
  }
  if(0 != test_sets){
    cout<<"----------------------------------------"<<endl;
    cout<<"Starting tests"<<endl;
    example_set_c* next_test;
    SVMINT test_no = 0;
    char* outname = new char[MAXCHAR];
    while(test_sets != 0){
      test_no++;
      next_test = test_sets->the_set;
      if(parameters->do_scale){
        // Apply the scaling of the training set to the test set.
        next_test->scale(training_set->get_exp(),
                         training_set->get_var(),
                         training_set->get_dim());
      }
      if(next_test->initialised_y()){
        cout<<"Testing examples from file "<<(next_test->get_filename())<<endl;
        svm->test(next_test,1);
      }
      else{
        cout<<"Predicting examples from file "<<(next_test->get_filename())<<endl;
        svm->predict(next_test);
        // output to file .pred (bounded copy, see above)
        outname[0] = '\0';
        strncat(outname,next_test->get_filename(),MAXCHAR - sizeof(pred_suffix));
        strcat(outname,pred_suffix);
        ofstream output_file(outname,ios::out);
        if(!output_file){
          cout<<"*** ERROR: Could not write file \""<<outname<<"\""<<endl;
        }
        else{
          output_file<<"@examples"<<endl;
          output_file<<(*next_test);
        }
        output_file.close();
      }
      test_sets = test_sets->next; // skip delete!
    }
    delete []outname;
  }

  delete svm;
  if(parameters->verbosity > 1){
    cout << "mysvm ended successfully."<<endl;
  }
  return 0;
}
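// (Residue of the code-viewer web page this file was copied from; not part of learn.cpp.)
// Keyboard shortcuts of the viewer:
//   Copy code:        Ctrl + C
//   Search code:      Ctrl + F
//   Full-screen mode: F11
//   Switch theme:     Ctrl + Shift + D
//   Show shortcuts:   ?
//   Larger font:      Ctrl + =
//   Smaller font:     Ctrl + -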