// training.cs
//Copyright (C) 2007 Matthew Johnson
//This program is free software; you can redistribute it and/or modify
//it under the terms of the GNU General Public License as published by
//the Free Software Foundation; either version 2 of the License, or
//(at your option) any later version.
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU General Public License for more details.
//You should have received a copy of the GNU General Public License along
//with this program; if not, write to the Free Software Foundation, Inc.,
//51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
using System;
using System.Collections.Generic;
using System.IO;
namespace SVM
{
/// <summary>
/// Class containing the routines to train SVM models.
/// </summary>
public static class Training
{
    /// <summary>
    /// Runs cross validation through <c>Procedures.svm_cross_validation</c> and reduces the
    /// per-sample results to a single score.
    /// </summary>
    /// <param name="problem">The training data</param>
    /// <param name="parameters">The parameters to test</param>
    /// <param name="nr_fold">The number of folds, passed through unchanged</param>
    /// <returns>
    /// For EPSILON_SVR and NU_SVR: the Pearson correlation coefficient between the values in
    /// the target array and <c>problem.Y</c>, or 0 when either series has no variance.
    /// For every other SVM type: the fraction of samples whose target value equals its label.
    /// </returns>
    private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold)
    {
        int count = problem.Count;
        // Presumably filled by svm_cross_validation with one predicted value per sample.
        double[] target = new double[count];
        Procedures.svm_cross_validation(problem, parameters, nr_fold, target);

        if (parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR)
        {
            double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
            for (int i = 0; i < count; i++)
            {
                double y = problem.Y[i];
                double v = target[i];
                sumv += v;
                sumy += y;
                sumvv += v * v;
                sumyy += y * y;
                sumvy += v * y;
            }

            // Scaled (co)variances: each is count^2 times the population statistic, and the
            // common factor cancels in the ratio below.
            double covariance = count * sumvy - sumv * sumy;
            double varianceV = count * sumvv - sumv * sumv;
            double varianceY = count * sumyy - sumy * sumy;

            // A constant series (or an empty problem) has no defined correlation. Written as
            // negated comparisons so that NaN inputs also take this branch.
            if (!(varianceV > 0) || !(varianceY > 0))
                return 0;

            return covariance / (Math.Sqrt(varianceV) * Math.Sqrt(varianceY));
        }

        int total_correct = 0;
        for (int i = 0; i < count; i++)
        {
            // Exact comparison is intentional on this path: a prediction only counts as
            // correct when it is identical to the stored label.
            if (target[i] == problem.Y[i])
                ++total_correct;
        }
        return (double)total_correct / count;
    }

    /// <summary>
    /// Legacy. Allows use as if this was svm_train. See libsvm documentation for details on which arguments to pass.
    /// </summary>
    /// <remarks>
    /// When <c>-v</c> is given, cross validation is performed and its score is discarded;
    /// otherwise a model is trained and written to the model file name derived from the arguments.
    /// </remarks>
    /// <param name="args">Options (each followed by its value), then the input file name and an optional model file name</param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is null</exception>
    /// <exception cref="ArgumentException">An option is unknown, lacks a value, or no input file is given</exception>
    [Obsolete("Provided only for legacy compatibility, use the other Train() methods")]
    public static void Train(params string[] args)
    {
        if (args == null)
            throw new ArgumentNullException("args");

        Parameter parameters;
        Problem problem;
        bool crossValidation;
        int nrfold;
        string modelFilename;
        parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename);
        if (crossValidation)
            PerformCrossValidation(problem, parameters, nrfold);
        else
            Model.Write(modelFilename, Train(problem, parameters));
    }

    /// <summary>
    /// Performs cross validation.
    /// </summary>
    /// <param name="problem">The training data</param>
    /// <param name="parameters">The parameters to test</param>
    /// <param name="nrfold">The number of cross validations to use; must be at least 2</param>
    /// <returns>The cross validation score</returns>
    /// <exception cref="ArgumentException"><paramref name="nrfold"/> is less than 2</exception>
    public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold)
    {
        // Same rule (and message) the command-line path applies to the -v option.
        if (nrfold < 2)
            throw new ArgumentException("n-fold cross validation: n must >= 2");

        // NOTE(review): anything returned by svm_check_parameter is not inspected here;
        // confirm that it reports invalid parameters by throwing.
        Procedures.svm_check_parameter(problem, parameters);
        return doCrossValidation(problem, parameters, nrfold);
    }

    /// <summary>
    /// Trains a model using the provided training data and parameters.
    /// </summary>
    /// <param name="problem">The training data</param>
    /// <param name="parameters">The parameters to use</param>
    /// <returns>A trained SVM Model</returns>
    public static Model Train(Problem problem, Parameter parameters)
    {
        // NOTE(review): anything returned by svm_check_parameter is not inspected here;
        // confirm that it reports invalid parameters by throwing.
        Procedures.svm_check_parameter(problem, parameters);
        return Procedures.svm_train(problem, parameters);
    }

    /// <summary>
    /// Parses svm_train style arguments: a run of "-x value" option pairs, followed by the
    /// input file name and an optional model file name.
    /// </summary>
    /// <param name="args">The raw arguments</param>
    /// <param name="parameters">Receives a new Parameter object updated with every parsed option</param>
    /// <param name="problem">Receives the problem read from the input file</param>
    /// <param name="crossValidation">Set to true when the -v option is present</param>
    /// <param name="nrfold">Receives the value of the -v option, or 0 when it is absent</param>
    /// <param name="modelFilename">
    /// Receives the argument after the input file name if there is one; otherwise the input
    /// file name with everything up to its last '/' removed and ".model" appended
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is null</exception>
    /// <exception cref="ArgumentException">
    /// An option is unknown, an option has no value, the -v value is below 2, or no input file is given
    /// </exception>
    private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename)
    {
        if (args == null)
            throw new ArgumentNullException("args");

        // Option values are machine-formatted, so they must not depend on the current culture
        // (under a comma-decimal culture "0.5" would otherwise be misread).
        System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

        int i;
        parameters = new Parameter();
        // default values
        crossValidation = false;
        nrfold = 0;

        // parse options
        for (i = 0; i < args.Length; i++)
        {
            string option = args[i];
            // The first argument that does not start with '-' ends the option list.
            if (string.IsNullOrEmpty(option) || option[0] != '-')
                break;
            if (option.Length < 2)
                throw new ArgumentException("Unknown Parameter");

            // Every option consumes the following argument as its value.
            ++i;
            if (i >= args.Length)
                throw new ArgumentException("Missing value for option " + option);
            string value = args[i];

            switch (option[1])
            {
                case 's':
                    parameters.SvmType = (SvmType)int.Parse(value, invariant);
                    break;
                case 't':
                    parameters.KernelType = (KernelType)int.Parse(value, invariant);
                    break;
                case 'd':
                    parameters.Degree = int.Parse(value, invariant);
                    break;
                case 'g':
                    parameters.Gamma = double.Parse(value, invariant);
                    break;
                case 'r':
                    parameters.Coefficient0 = double.Parse(value, invariant);
                    break;
                case 'n':
                    parameters.Nu = double.Parse(value, invariant);
                    break;
                case 'm':
                    parameters.CacheSize = double.Parse(value, invariant);
                    break;
                case 'c':
                    parameters.C = double.Parse(value, invariant);
                    break;
                case 'e':
                    parameters.EPS = double.Parse(value, invariant);
                    break;
                case 'p':
                    parameters.P = double.Parse(value, invariant);
                    break;
                case 'h':
                    parameters.Shrinking = int.Parse(value, invariant) == 1;
                    break;
                case 'b':
                    parameters.Probability = int.Parse(value, invariant) == 1;
                    break;
                case 'v':
                    crossValidation = true;
                    nrfold = int.Parse(value, invariant);
                    if (nrfold < 2)
                    {
                        throw new ArgumentException("n-fold cross validation: n must >= 2");
                    }
                    break;
                case 'w':
                {
                    // "-w<label> <weight>": the label is the text that follows "-w".
                    // Parse both values first so a malformed option leaves parameters untouched.
                    int weightLabel = int.Parse(option.Substring(2), invariant);
                    double weight = double.Parse(value, invariant);

                    int newCount = parameters.WeightCount + 1;
                    parameters.WeightCount = newCount;

                    // Array.Resize keeps the existing entries and also accepts a null array.
                    int[] labels = parameters.WeightLabels;
                    Array.Resize(ref labels, newCount);
                    labels[newCount - 1] = weightLabel;
                    parameters.WeightLabels = labels;

                    double[] weights = parameters.Weights;
                    Array.Resize(ref weights, newCount);
                    weights[newCount - 1] = weight;
                    parameters.Weights = weights;
                    break;
                }
                default:
                    throw new ArgumentException("Unknown Parameter");
            }
        }

        // determine filenames
        if (i >= args.Length || string.IsNullOrEmpty(args[i]))
            throw new ArgumentException("No input file specified");
        problem = Problem.Read(args[i]);

        // A gamma of 0 means "not specified": default to 1 / problem.MaxIndex.
        if (parameters.Gamma == 0)
            parameters.Gamma = 1.0 / problem.MaxIndex;

        if (i < args.Length - 1)
            modelFilename = args[i + 1];
        else
        {
            // Strip everything up to the last '/' and append ".model".
            int p = args[i].LastIndexOf('/') + 1;
            modelFilename = args[i].Substring(p) + ".model";
        }
    }
}
}
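// Keyboard shortcuts (residue from the web code viewer; not part of the program):
//   Copy code:            Ctrl + C
//   Search code:          Ctrl + F
//   Full-screen mode:     F11
//   Toggle theme:         Ctrl + Shift + D
//   Show shortcuts:       ?
//   Increase font size:   Ctrl + =
//   Decrease font size:   Ctrl + -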