// smo.java
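// Excerpt from the inner class BinarySMO. The index sets below follow
// Keerthi et al.'s notation: m_I0 = {i : 0 < alpha_i < C},
// m_I1 = {i : y_i = 1, alpha_i = 0}, m_I2 = {i : y_i = -1, alpha_i = C},
// m_I3 = {i : y_i = 1, alpha_i = C}, m_I4 = {i : y_i = -1, alpha_i = 0}.
// This code runs at the end of takeStep(), after new multipliers a1 and a2
// have been computed for the working pair (i1, i2).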
if ((y2 == -1) && (a2 == C2)) {
m_I2.insert(i2);
} else {
m_I2.delete(i2);
}
if ((y2 == 1) && (a2 == C2)) {
m_I3.insert(i2);
} else {
m_I3.delete(i2);
}
if ((y2 == -1) && (a2 == 0)) {
m_I4.insert(i2);
} else {
m_I4.delete(i2);
}
// Update weight vector to reflect change a1 and a2, if linear SVM
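// For a linear machine the weight vector is maintained explicitly:
// w += y1 * (a1 - alph1) * x1 + y2 * (a2 - alph2) * x2,
// where alph1/alph2 are the old multipliers and a1/a2 the new ones.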
if (!m_useRBF && m_exponent == 1.0) {
Instance inst1 = m_data.instance(i1);
for (int p1 = 0; p1 < inst1.numValues(); p1++) {
if (inst1.index(p1) != m_data.classIndex()) {
m_weights[inst1.index(p1)] +=
y1 * (a1 - alph1) * inst1.valueSparse(p1);
}
}
Instance inst2 = m_data.instance(i2);
for (int p2 = 0; p2 < inst2.numValues(); p2++) {
if (inst2.index(p2) != m_data.classIndex()) {
m_weights[inst2.index(p2)] +=
y2 * (a2 - alph2) * inst2.valueSparse(p2);
}
}
}
// Update error cache using new Lagrange multipliers
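// For every point j whose error is cached (i.e. j in I0), the change in
// the two multipliers shifts the output, and hence the error, by
// E_j += y1 * (a1 - alph1) * K(x1, x_j) + y2 * (a2 - alph2) * K(x2, x_j).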
for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
if ((j != i1) && (j != i2)) {
m_errors[j] +=
y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) +
y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2));
}
}
// Update error cache for i1 and i2
m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
// Update array with Lagrange multipliers
m_alpha[i1] = a1;
m_alpha[i2] = a2;
// Update thresholds
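// Following Keerthi et al., m_bUp tracks the smallest cached error and
// m_bLow the largest; the scan covers I0, plus i1 and i2 if they have
// left I0. Optimality is reached when m_bLow <= m_bUp + 2 * m_tol.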
m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;
m_iLow = -1; m_iUp = -1;
for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
if (m_errors[j] < m_bUp) {
m_bUp = m_errors[j]; m_iUp = j;
}
if (m_errors[j] > m_bLow) {
m_bLow = m_errors[j]; m_iLow = j;
}
}
if (!m_I0.contains(i1)) {
if (m_I3.contains(i1) || m_I4.contains(i1)) {
if (m_errors[i1] > m_bLow) {
m_bLow = m_errors[i1]; m_iLow = i1;
}
} else {
if (m_errors[i1] < m_bUp) {
m_bUp = m_errors[i1]; m_iUp = i1;
}
}
}
if (!m_I0.contains(i2)) {
if (m_I3.contains(i2) || m_I4.contains(i2)) {
if (m_errors[i2] > m_bLow) {
m_bLow = m_errors[i2]; m_iLow = i2;
}
} else {
if (m_errors[i2] < m_bUp) {
m_bUp = m_errors[i2]; m_iUp = i2;
}
}
}
if ((m_iLow == -1) || (m_iUp == -1)) {
throw new Exception("This should never happen!");
}
// Made some progress.
return true;
}
/**
* Quick and dirty check whether the quadratic programming problem is solved.
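* For each training point the relevant KKT condition is checked:
* alpha_i = 0 implies y_i * f(x_i) >= 1, 0 < alpha_i < C_i implies
* y_i * f(x_i) = 1, and alpha_i = C_i implies y_i * f(x_i) <= 1,
* where C_i is m_C scaled by the instance weight.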
*/
protected void checkClassifier() throws Exception {
double sum = 0;
for (int i = 0; i < m_alpha.length; i++) {
if (m_alpha[i] > 0) {
sum += m_class[i] * m_alpha[i];
}
}
System.err.println("Sum of y(i) * alpha(i): " + sum);
for (int i = 0; i < m_alpha.length; i++) {
double output = SVMOutput(i, m_data.instance(i));
if (Utils.eq(m_alpha[i], 0)) {
if (Utils.sm(m_class[i] * output, 1)) {
System.err.println("KKT condition 1 violated: " + m_class[i] * output);
}
}
if (Utils.gr(m_alpha[i], 0) &&
Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) {
if (!Utils.eq(m_class[i] * output, 1)) {
System.err.println("KKT condition 2 violated: " + m_class[i] * output);
}
}
if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) {
if (Utils.gr(m_class[i] * output, 1)) {
System.err.println("KKT condition 3 violated: " + m_class[i] * output);
}
}
}
}
}
/** The filter to apply to the training data */
public static final int FILTER_NORMALIZE = 0;
public static final int FILTER_STANDARDIZE = 1;
public static final int FILTER_NONE = 2;
public static final Tag [] TAGS_FILTER = {
new Tag(FILTER_NORMALIZE, "Normalize training data"),
new Tag(FILTER_STANDARDIZE, "Standardize training data"),
new Tag(FILTER_NONE, "No normalization/standardization"),
};
/** The binary classifier(s) */
protected BinarySMO[][] m_classifiers = null;
/** The exponent for the polynomial kernel. */
protected double m_exponent = 1.0;
/** Use lower-order terms? */
protected boolean m_lowerOrder = false;
/** Gamma for the RBF kernel. */
protected double m_gamma = 0.01;
/** The complexity parameter. */
protected double m_C = 1.0;
/** Epsilon for rounding. */
protected double m_eps = 1.0e-12;
/** Tolerance for accuracy of result. */
protected double m_tol = 1.0e-3;
/** Whether to normalize/standardize/neither */
protected int m_filterType = FILTER_NORMALIZE;
/** Feature-space normalization? */
protected boolean m_featureSpaceNormalization = false;
/** Use RBF kernel? (default: poly) */
protected boolean m_useRBF = false;
/** The size of the cache (a prime number) */
protected int m_cacheSize = 250007;
/** The filter used to make attributes numeric. */
protected NominalToBinary m_NominalToBinary;
/** The filter used to standardize/normalize all values. */
protected Filter m_Filter = null;
/** The filter used to get rid of missing values. */
protected ReplaceMissingValues m_Missing;
/** Only numeric attributes in the dataset? */
protected boolean m_onlyNumeric;
/** The class index from the training data */
protected int m_classIndex = -1;
/** The class attribute */
protected Attribute m_classAttribute;
/** Turn off all checks and conversions? Turning them off assumes
that data is purely numeric, doesn't contain any missing values,
and has a nominal class. Turning them off also means that
no header information will be stored if the machine is linear.
Finally, it also assumes that no instance has a weight equal to 0.*/
protected boolean m_checksTurnedOff;
/** Precision constant for updating sets */
protected static double m_Del = 1000 * Double.MIN_VALUE;
/** Whether logistic models are to be fit */
protected boolean m_fitLogisticModels = false;
/** The number of folds for the internal cross-validation */
protected int m_numFolds = -1;
/** The random number seed */
protected int m_randomSeed = 1;
/**
* Turns off checks for missing values, etc. Use with caution.
*/
public void turnChecksOff() {
m_checksTurnedOff = true;
}
/**
* Turns on checks for missing values, etc.
*/
public void turnChecksOn() {
m_checksTurnedOff = false;
}
/**
* Method for building the classifier. Implements a one-against-one
* wrapper for multi-class problems.
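* For a problem with k classes, k * (k - 1) / 2 pairwise BinarySMO models
* are built, one per pair of class values.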
*
* @param insts the set of training instances
* @exception Exception if the classifier can't be built successfully
*/
public void buildClassifier(Instances insts) throws Exception {
if (!m_checksTurnedOff) {
if (insts.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
if (insts.classAttribute().isNumeric()) {
throw new UnsupportedClassTypeException("SMO can't handle a numeric class! Use "
+ "SMOreg for performing regression.");
}
insts = new Instances(insts);
insts.deleteWithMissingClass();
if (insts.numInstances() == 0) {
throw new Exception("No training instances without a missing class!");
}
/* Removes all the instances with weight equal to 0.
MUST be done since condition (8) of Keerthi et al.'s paper
assumes Ci > 0 (see equation (3a)). */
Instances data = new Instances(insts, insts.numInstances());
for(int i = 0; i < insts.numInstances(); i++){
if(insts.instance(i).weight() > 0)
data.add(insts.instance(i));
}
if (data.numInstances() == 0) {
throw new Exception("No training instances left after removing " +
"instance with either a weight null or a missing class!");
}
insts = data;
}
m_onlyNumeric = true;
if (!m_checksTurnedOff) {
for (int i = 0; i < insts.numAttributes(); i++) {
if (i != insts.classIndex()) {
if (!insts.attribute(i).isNumeric()) {
m_onlyNumeric = false;
break;
}
}
}
}
if (!m_checksTurnedOff) {
m_Missing = new ReplaceMissingValues();
m_Missing.setInputFormat(insts);
insts = Filter.useFilter(insts, m_Missing);
} else {
m_Missing = null;
}
if (!m_onlyNumeric) {
m_NominalToBinary = new NominalToBinary();
m_NominalToBinary.setInputFormat(insts);
insts = Filter.useFilter(insts, m_NominalToBinary);
} else {
m_NominalToBinary = null;
}
if (m_filterType == FILTER_STANDARDIZE) {
m_Filter = new Standardize();
m_Filter.setInputFormat(insts);
insts = Filter.useFilter(insts, m_Filter);
} else if (m_filterType == FILTER_NORMALIZE) {
m_Filter = new Normalize();
m_Filter.setInputFormat(insts);
insts = Filter.useFilter(insts, m_Filter);
} else {
m_Filter = null;
}
m_classIndex = insts.classIndex();
m_classAttribute = insts.classAttribute();
// Generate subsets representing each class
Instances[] subsets = new Instances[insts.numClasses()];
for (int i = 0; i < insts.numClasses(); i++) {
subsets[i] = new Instances(insts, insts.numInstances());
}
for (int j = 0; j < insts.numInstances(); j++) {
Instance inst = insts.instance(j);
subsets[(int)inst.classValue()].add(inst);
}
for (int i = 0; i < insts.numClasses(); i++) {
subsets[i].compactify();
}
// Build the binary classifiers
Random rand = new Random(m_randomSeed);
m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()];
for (int i = 0; i < insts.numClasses(); i++) {
for (int j = i + 1; j < insts.numClasses(); j++) {
m_classifiers[i][j] = new BinarySMO();
Instances data = new Instances(insts, insts.numInstances());
for (int k = 0; k < subsets[i].numInstances(); k++) {
data.add(subsets[i].instance(k));
}
for (int k = 0; k < subsets[j].numInstances(); k++) {
data.add(subsets[j].instance(k));
}
data.compactify();
data.randomize(rand);
m_classifiers[i][j].buildClassifier(data, i, j,
m_fitLogisticModels,
m_numFolds, m_randomSeed);
}
}
}
/**
* Estimates class probabilities for given instance.
*/
public double[] distributionForInstance(Instance inst) throws Exception {
// Filter instance
if (!m_checksTurnedOff) {
m_Missing.input(inst);
m_Missing.batchFinished();
inst = m_Missing.output();
}
if (!m_onlyNumeric) {
m_NominalToBinary.input(inst);
m_NominalToBinary.batchFinished();
inst = m_NominalToBinary.output();
}
if (m_Filter != null) {
m_Filter.input(inst);
m_Filter.batchFinished();
inst = m_Filter.output();
}
if (!m_fitLogisticModels) {
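// Pairwise voting: each trained binary model casts one vote for the
// class it predicts, and the vote counts are normalized into a distribution.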
double[] result = new double[inst.numClasses()];
for (int i = 0; i < inst.numClasses(); i++) {
for (int j = i + 1; j < inst.numClasses(); j++) {
if ((m_classifiers[i][j].m_alpha != null) ||
(m_classifiers[i][j].m_sparseWeights != null)) {
double output = m_classifiers[i][j].SVMOutput(-1, inst);
if (output > 0) {
result[j] += 1;
} else {
result[i] += 1;
}
}
}
}
Utils.normalize(result);
return result;
} else {
// We only need to do pairwise coupling if there are more
// than two classes.
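// With only one pair of classes there is a single model, so its raw SVM
// output can be fed straight through the fitted logistic model.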
if (inst.numClasses() == 2) {
double[] newInst = new double[2];
newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst);
newInst[1] = Instance.missingValue();
return m_classifiers[0][1].m_logistic.
distributionForInstance(new Instance(1, newInst));
}
double[][] r = new double[inst.numClasses()][inst.numClasses()];
double[][] n = new double[inst.numClasses()][inst.numClasses()];
for (int i = 0; i < inst.numClasses(); i++) {
for (int j = i + 1; j < inst.numClasses(); j++) {
if ((m_classifiers[i][j].m_alpha != null) ||
(m_classifiers[i][j].m_sparseWeights != null)) {
double[] newInst = new double[2];
newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst);
newInst[1] = Instance.missingValue();
r[i][j] = m_classifiers[i][j].m_logistic.
distributionForInstance(new Instance(1, newInst))[0];
n[i][j] = m_classifiers[i][j].m_sumOfWeights;
}
}
}
return pairwiseCoupling(n, r);
}
}
/**
* Implements pairwise coupling.
*
* @param n the sum of weights used to train each model
* @param r the probability estimate from each model
* @return the coupled estimates
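*
* Finds class probabilities p such that r[i][j] is approximately
* p[i] / (p[i] + p[j]), using the iterative scheme of Hastie and
* Tibshirani (1998), with each pairwise estimate weighted by the
* amount of training data (n[i][j]) behind it.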
*/
public double[] pairwiseCoupling(double[][] n, double[][] r) {
// Initialize p and u array
double[] p = new double[r.length];
for (int i = 0; i < p.length; i++) {
p[i] = 1.0 / (double)p.length;
}
double[][] u = new double[r.length][r.length];
for (int i = 0; i < r.length; i++) {
for (int j = i + 1; j < r.length; j++) {
u[i][j] = 0.5;
}
}
// firstSum never changes across iterations
double[] firstSum = new double[p.length];
for (int i = 0; i < p.length; i++) {
for (int j = i + 1; j < p.length; j++) {
firstSum[i] += n[i][j] * r[i][j];
firstSum[j] += n[i][j] * (1 - r[i][j]);
}
}
// Iterate the coupling updates until the estimates stop changing
boolean changed;
do {
changed = false;
double[] secondSum = new double[p.length];
for (int i = 0; i < p.length; i++) {
for (int j = i + 1; j < p.length; j++) {
secondSum[i] += n[i][j] * u[i][j];
secondSum[j] += n[i][j] * (1 - u[i][j]);
}
}
for (int i = 0; i < p.length; i++) {
if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
if (p[i] > 0) {
changed = true;
}
p[i] = 0;
} else {
double factor = firstSum[i] / secondSum[i];
double pOld = p[i];
p[i] *= factor;
if (Math.abs(pOld - p[i]) > 1.0e-3) {
changed = true;
}
}
}
Utils.normalize(p);
for (int i = 0; i < p.length; i++) {
for (int j = i + 1; j < p.length; j++) {
u[i][j] = p[i] / (p[i] + p[j]);
}
}
} while (changed);
return p;
}