?? decisionstump.java
字號:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    DecisionStump.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;

/**
 <!-- globalinfo-start -->
 * Class for building and using a decision stump. Usually used in conjunction
 * with a boosting algorithm. Does regression (based on mean-squared error) or
 * classification (based on entropy). Missing is treated as a separate value.
 * <p/>
 <!-- globalinfo-end -->
 *
 * Typical usage: <p>
 * <code>java weka.classifiers.trees.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
 * -t training_data </code><p>
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.20 $
 */
public class DecisionStump
  extends Classifier
  implements WeightedInstancesHandler, Sourcable {

  /** for serialization */
  static final long serialVersionUID = 1618384535950391L;

  /** The attribute used for classification. */
  private int m_AttIndex;

  /** The split point (index respectively). For a nominal attribute this is
   *  the value index being tested; for a numeric attribute it is the
   *  threshold of the "&lt;=" test. */
  private double m_SplitPoint;

  /** The distribution of class values or the means in each subset.
   *  Three rows: [0] = left branch, [1] = right branch, [2] = missing
   *  attribute value (row 2 is what toSource/toString emit for "missing"). */
  private double[][] m_Distribution;

  /** The instances used for training (reduced to an empty header after
   *  buildClassifier to save memory). */
  private Instances m_Instances;

  /**
   * Returns a string describing classifier
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Class for building and using a decision stump. Usually used in "
      + "conjunction with a boosting algorithm. Does regression (based on "
      + "mean-squared error) or classification (based on entropy). Missing "
      + "is treated as a separate value.";
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    double bestVal = Double.MAX_VALUE, currVal;
    double bestPoint = -Double.MAX_VALUE;
    int bestAtt = -1, numClasses;

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    double[][] bestDist = new double[3][instances.numClasses()];

    m_Instances = new Instances(instances);

    // Regression uses a single "class" column of means, hence one column.
    if (m_Instances.classAttribute().isNominal()) {
      numClasses = m_Instances.numClasses();
    } else {
      numClasses = 1;
    }

    // For each attribute
    boolean first = true;
    for (int i = 0; i < m_Instances.numAttributes(); i++) {
      if (i != m_Instances.classIndex()) {

	// Reserve space for distribution.
	m_Distribution = new double[3][numClasses];

	// Compute value of criterion for best split on attribute.
	// NOTE: findSplitNominal/findSplitNumeric set m_SplitPoint and
	// fill m_Distribution as side effects (bestPoint/bestDist are
	// copied from them below).
	if (m_Instances.attribute(i).isNominal()) {
	  currVal = findSplitNominal(i);
	} else {
	  currVal = findSplitNumeric(i);
	}
	// "first" guards the degenerate case where every criterion value
	// equals the Double.MAX_VALUE sentinel, so some split is still kept.
	if ((first) || (currVal < bestVal)) {
	  bestVal = currVal;
	  bestAtt = i;
	  bestPoint = m_SplitPoint;
	  for (int j = 0; j < 3; j++) {
	    System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, numClasses);
	  }
	}

	// First attribute has been investigated
	first = false;
      }
    }

    // Set attribute, split point and distribution.
    m_AttIndex = bestAtt;
    m_SplitPoint = bestPoint;
    m_Distribution = bestDist;
    if (m_Instances.classAttribute().isNominal()) {
      for (int i = 0; i < m_Distribution.length; i++) {
	double sumCounts = Utils.sum(m_Distribution[i]);
	if (sumCounts == 0) {
	  // This means there were only missing attribute values:
	  // fall back to the "missing" row's distribution for this subset.
	  System.arraycopy(m_Distribution[2], 0, m_Distribution[i], 0,
			   m_Distribution[2].length);
	  Utils.normalize(m_Distribution[i]);
	} else {
	  Utils.normalize(m_Distribution[i], sumCounts);
	}
      }
    }

    // Save memory: keep only the dataset header (0 instances).
    m_Instances = new Instances(m_Instances, 0);
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    // whichSubset (not visible in this extract) presumably selects row
    // 0, 1 or 2 of m_Distribution — TODO confirm against the full source.
    return m_Distribution[whichSubset(instance)];
  }

  /**
   * Returns the decision tree as Java source code.
   *
   * @param className the classname of the generated code
   * @return the tree as Java source code
   * @throws Exception if something goes wrong
   */
  public String toSource(String className) throws Exception {

    StringBuffer text = new StringBuffer("class ");
    Attribute c = m_Instances.classAttribute();
    text.append(className)
      .append(" {\n" + " public static double classify(Object [] i) {\n");
    text.append(" /* " + m_Instances.attribute(m_AttIndex).name() + " */\n");
    text.append(" if (i[").append(m_AttIndex);
    text.append("] == null) { return ");
    text.append(sourceClass(c, m_Distribution[2])).append(";");
    if (m_Instances.attribute(m_AttIndex).isNominal()) {
      text.append(" } else if (((String)i[").append(m_AttIndex);
      text.append("]).equals(\"");
      text.append(m_Instances.attribute(m_AttIndex).value((int)m_SplitPoint));
      text.append("\")");
    } else {
      text.append(" } else if (((Double)i[").append(m_AttIndex);
      text.append("]).doubleValue() <= ").append(m_SplitPoint);
    }
    text.append(") { return ");
    text.append(sourceClass(c, m_Distribution[0])).append(";");
    text.append(" } else { return ");
    text.append(sourceClass(c, m_Distribution[1])).append(";");
    text.append(" }\n }\n}\n");
    return text.toString();
  }

  /**
   * Returns the value as string out of the given distribution
   *
   * @param c the attribute to get the value for
   * @param dist the distribution to extract the value
   * @return the value
   */
  private String sourceClass(Attribute c, double []dist) {

    // Nominal class: index of the majority class; numeric class: the mean.
    if (c.isNominal()) {
      return Integer.toString(Utils.maxIndex(dist));
    } else {
      return Double.toString(dist[0]);
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString(){

    if (m_Instances == null) {
      return "Decision Stump: No model built yet.";
    }
    try {
      StringBuffer text = new StringBuffer();

      text.append("Decision Stump\n\n");
      text.append("Classifications\n\n");
      Attribute att = m_Instances.attribute(m_AttIndex);
      if (att.isNominal()) {
	text.append(att.name() + " = " + att.value((int)m_SplitPoint) + " : ");
	text.append(printClass(m_Distribution[0]));
	text.append(att.name() + " != " + att.value((int)m_SplitPoint) + " : ");
	text.append(printClass(m_Distribution[1]));
      } else {
	text.append(att.name() + " <= " + m_SplitPoint + " : ");
	text.append(printClass(m_Distribution[0]));
	text.append(att.name() + " > " + m_SplitPoint + " : ");
	text.append(printClass(m_Distribution[1]));
      }
      text.append(att.name() + " is missing : ");
      text.append(printClass(m_Distribution[2]));
      if (m_Instances.classAttribute().isNominal()) {
	text.append("\nClass distributions\n\n");
	if (att.isNominal()) {
	  text.append(att.name() + " = " + att.value((int)m_SplitPoint) + "\n");
	  text.append(printDist(m_Distribution[0]));
	  text.append(att.name() + " != " + att.value((int)m_SplitPoint) + "\n");
	  text.append(printDist(m_Distribution[1]));
	} else {
	  text.append(att.name() + " <= " + m_SplitPoint + "\n");
	  text.append(printDist(m_Distribution[0]));
	  text.append(att.name() + " > " + m_SplitPoint + "\n");
	  text.append(printDist(m_Distribution[1]));
	}
	text.append(att.name() + " is missing\n");
	text.append(printDist(m_Distribution[2]));
      }
      return text.toString();
    } catch (Exception e) {
      return "Can't print decision stump classifier!";
    }
  }

  /**
   * Prints a class distribution.
   *
   * @param dist the class distribution to print
   * @return the distribution as a string
   * @throws Exception if distribution can't be printed
   */
  private String printDist(double[] dist) throws Exception {

    StringBuffer text = new StringBuffer();

    // Only meaningful for a nominal class; returns "" for regression.
    if (m_Instances.classAttribute().isNominal()) {
      for (int i = 0; i < m_Instances.numClasses(); i++) {
	text.append(m_Instances.classAttribute().value(i) + "\t");
      }
      text.append("\n");
      for (int i = 0; i < m_Instances.numClasses(); i++) {
	text.append(dist[i] + "\t");
      }
      text.append("\n");
    }
    return text.toString();
  }

  /**
   * Prints a classification.
   *
   * @param dist the class distribution
   * @return the classificationn as a string
   * @throws Exception if the classification can't be printed
   */
  private String printClass(double[] dist) throws Exception {

    StringBuffer text = new StringBuffer();

    // Majority class label for a nominal class, the stored mean otherwise.
    if (m_Instances.classAttribute().isNominal()) {
      text.append(m_Instances.classAttribute().value(Utils.maxIndex(dist)));
    } else {
      text.append(dist[0]);
    }
    return text.toString() + "\n";
  }

  /**
   * Finds best split for nominal attribute and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @throws Exception if something goes wrong
   */
  private double findSplitNominal(int index) throws Exception {

    if (m_Instances.classAttribute().isNominal()) {
      return findSplitNominalNominal(index);
    } else {
      return findSplitNominalNumeric(index);
      // NOTE(review): the source extract is truncated at this point — the
      // closing braces of findSplitNominal and the remainder of the class
      // (findSplitNominalNominal, findSplitNominalNumeric, findSplitNumeric,
      // whichSubset, ...) are not visible in this chunk and are left as-is.
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -