?? decisionstump.java
字號:
} }

  /**
   * Finds the best split for a nominal attribute when the class is nominal
   * and returns the value of the splitting criterion.
   * <p>
   * Each attribute value is tried as a "this value versus all other values"
   * split. A candidate is scored by passing a three-row table (matching
   * value, remaining values, missing values) to
   * {@code ContingencyTables.entropyConditionedOnRows}; the smallest score
   * wins. On return {@code m_SplitPoint} holds the index of the winning
   * value (if any candidate scored below {@code Double.MAX_VALUE}) and
   * {@code m_Distribution} is replaced by the table of the winning split.
   * When no instance has a missing value for the attribute, the third row
   * of that table is filled with the class totals over all values.
   * <p>
   * NOTE(review): assumes the caller has allocated {@code m_Distribution}
   * with three rows of at least {@code numClasses} entries - confirm
   * against the calling code.
   *
   * @param index attribute index
   * @return value of criterion for the best split, or
   *         {@code Double.MAX_VALUE} if no candidate scored lower
   * @throws Exception if something goes wrong
   */
  private double findSplitNominalNominal(int index) throws Exception {

    final int numValues = m_Instances.attribute(index).numValues();
    final int numClasses = m_Instances.numClasses();
    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE;
    // One row per attribute value, plus a final row for missing values
    double[][] counts = new double[numValues + 1][numClasses];
    double[] sumCounts = new double[numClasses];
    double[][] bestDist = new double[3][numClasses];
    int numMissing = 0;

    // Accumulate weighted class counts per attribute value
    for (int n = 0; n < numInstances; n++) {
      Instance current = m_Instances.instance(n);
      int row;
      if (current.isMissing(index)) {
        numMissing++;
        row = numValues;
      } else {
        row = (int) current.value(index);
      }
      counts[row][(int) current.classValue()] += current.weight();
    }

    // Class totals over the non-missing rows
    for (int v = 0; v < numValues; v++) {
      double[] rowCounts = counts[v];
      for (int c = 0; c < numClasses; c++) {
        sumCounts[c] += rowCounts[c];
      }
    }

    // The missing-value row is the same for every candidate split
    System.arraycopy(counts[numValues], 0, m_Distribution[2], 0, numClasses);

    // Score each "value v versus the rest" split
    for (int v = 0; v < numValues; v++) {
      for (int c = 0; c < numClasses; c++) {
        m_Distribution[0][c] = counts[v][c];
        m_Distribution[1][c] = sumCounts[c] - counts[v][c];
      }
      double score = ContingencyTables.entropyConditionedOnRows(m_Distribution);
      if (score < bestVal) {
        bestVal = score;
        m_SplitPoint = (double) v;
        for (int r = 0; r < 3; r++) {
          System.arraycopy(m_Distribution[r], 0, bestDist[r], 0, numClasses);
        }
      }
    }

    // No missing values in training data: use the overall class totals
    // for the "missing" branch.
    if (numMissing == 0) {
      System.arraycopy(sumCounts, 0, bestDist[2], 0, numClasses);
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds the best split for a nominal attribute when the class is numeric
   * and returns the value of the splitting criterion.
   * <p>
   * Each attribute value is tried as a "this value versus all other values"
   * split and scored with {@link #variance(double[][], double[], double[])};
   * the smallest score wins. On return {@code m_SplitPoint} holds the index
   * of the winning value (if any) and {@code m_Distribution} holds, for each
   * of the three subsets, the weighted mean class value of that subset, or
   * the weighted mean over all instances when the subset has no weight.
   * If the total weight of all instances is not positive the method returns
   * early and leaves {@code m_Distribution} holding the raw sums.
   * <p>
   * NOTE(review): assumes {@code m_Distribution} arrives zeroed with three
   * rows - confirm against the calling code.
   *
   * @param index attribute index
   * @return value of criterion for the best split, or
   *         {@code Double.MAX_VALUE} if no candidate scored lower
   * @throws Exception if something goes wrong
   */
  private double findSplitNominalNumeric(int index) throws Exception {

    final int numValues = m_Instances.attribute(index).numValues();
    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE;
    double[] sumsSquaresPerValue = new double[numValues];
    double[] sumsPerValue = new double[numValues];
    double[] weightsPerValue = new double[numValues];
    double[] sumsSquares = new double[3];
    double[] sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];
    double totalSumOfWeights = 0;
    double totalSum = 0;

    // Accumulate weighted sums per attribute value (missing values go
    // straight into subset 2)
    for (int n = 0; n < numInstances; n++) {
      Instance current = m_Instances.instance(n);
      boolean missing = current.isMissing(index);
      double y = current.classValue();
      double w = current.weight();
      if (missing) {
        m_Distribution[2][0] += y * w;
        sumsSquares[2] += y * y * w;
        sumOfWeights[2] += w;
      } else {
        int v = (int) current.value(index);
        weightsPerValue[v] += w;
        sumsPerValue[v] += y * w;
        sumsSquaresPerValue[v] += y * y * w;
      }
      totalSumOfWeights += w;
      totalSum += y * w;
    }

    // Nothing to split if the data carries no weight
    if (totalSumOfWeights <= 0) {
      return bestVal;
    }

    // Totals over the non-missing instances only
    double totalSumOfWeightsW = 0;
    double totalSumSquaresW = 0;
    double totalSumW = 0;
    for (int v = 0; v < numValues; v++) {
      totalSumOfWeightsW += weightsPerValue[v];
      totalSumSquaresW += sumsSquaresPerValue[v];
      totalSumW += sumsPerValue[v];
    }

    // Score each "value v versus the rest" split
    for (int v = 0; v < numValues; v++) {
      m_Distribution[0][0] = sumsPerValue[v];
      sumsSquares[0] = sumsSquaresPerValue[v];
      sumOfWeights[0] = weightsPerValue[v];

      m_Distribution[1][0] = totalSumW - sumsPerValue[v];
      sumsSquares[1] = totalSumSquaresW - sumsSquaresPerValue[v];
      sumOfWeights[1] = totalSumOfWeightsW - weightsPerValue[v];

      double score = variance(m_Distribution, sumsSquares, sumOfWeights);
      if (score < bestVal) {
        bestVal = score;
        m_SplitPoint = (double) v;
        for (int r = 0; r < 3; r++) {
          // Empty subsets fall back to the overall weighted mean
          bestDist[r][0] = (sumOfWeights[r] > 0)
            ? m_Distribution[r][0] / sumOfWeights[r]
            : totalSum / totalSumOfWeights;
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds the best split for a numeric attribute and returns the value of
   * the splitting criterion, delegating to the nominal-class or
   * numeric-class variant depending on the type of the class attribute.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @throws Exception if something goes wrong
   */
  private double findSplitNumeric(int index) throws Exception {

    return m_Instances.classAttribute().isNominal()
      ? findSplitNumericNominal(index)
      : findSplitNumericNumeric(index);
  }

  /**
   * Finds the best split for a numeric attribute when the class is nominal
   * and returns the value of the splitting criterion.
   * <p>
   * The instances are sorted on the attribute, and every midpoint between
   * two consecutive instances with strictly increasing attribute values is
   * scored with {@code ContingencyTables.entropyConditionedOnRows}; the
   * smallest score wins. On return {@code m_SplitPoint} holds the winning
   * cut point (if any) and {@code m_Distribution} is replaced by the class
   * counts of the winning split, or by the counts before any split if no
   * candidate was accepted. When no instance has a missing value for the
   * attribute, the third row is filled with the class totals.
   * <p>
   * NOTE(review): the loop bound presumes {@code m_Instances.sort(index)}
   * places instances with a missing value at the end, and the counting
   * presumes {@code m_Distribution} arrives zeroed with three rows -
   * confirm both.
   *
   * @param index attribute index
   * @return value of criterion for the best split, or
   *         {@code Double.MAX_VALUE} if no candidate scored lower
   * @throws Exception if something goes wrong
   */
  private double findSplitNumericNominal(int index) throws Exception {

    final int numClasses = m_Instances.numClasses();
    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE;
    int numMissing = 0;
    double[] sum = new double[numClasses];
    double[][] bestDist = new double[3][numClasses];

    // Start with every non-missing instance in subset 1 and every
    // missing one in subset 2
    for (int n = 0; n < numInstances; n++) {
      Instance current = m_Instances.instance(n);
      if (current.isMissing(index)) {
        m_Distribution[2][(int) current.classValue()] += current.weight();
        numMissing++;
      } else {
        m_Distribution[1][(int) current.classValue()] += current.weight();
      }
    }
    System.arraycopy(m_Distribution[1], 0, sum, 0, numClasses);

    // Save current distribution as best distribution
    for (int r = 0; r < 3; r++) {
      System.arraycopy(m_Distribution[r], 0, bestDist[r], 0, numClasses);
    }

    // Sort instances
    m_Instances.sort(index);

    // Move instances one at a time from subset 1 to subset 0 and score
    // each position where the attribute value changes
    final int lastCandidate = numInstances - (numMissing + 1);
    for (int n = 0; n < lastCandidate; n++) {
      Instance current = m_Instances.instance(n);
      Instance next = m_Instances.instance(n + 1);

      int cls = (int) current.classValue();
      m_Distribution[0][cls] += current.weight();
      m_Distribution[1][cls] -= current.weight();

      double lower = current.value(index);
      double upper = next.value(index);
      if (lower < upper) {
        double cutPoint = (lower + upper) / 2.0;
        double score = ContingencyTables.entropyConditionedOnRows(m_Distribution);
        if (score < bestVal) {
          m_SplitPoint = cutPoint;
          bestVal = score;
          for (int r = 0; r < 3; r++) {
            System.arraycopy(m_Distribution[r], 0, bestDist[r], 0, numClasses);
          }
        }
      }
    }

    // No missing values in training data: use the overall class totals
    // for the "missing" branch.
    if (numMissing == 0) {
      System.arraycopy(sum, 0, bestDist[2], 0, numClasses);
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds the best split for a numeric attribute when the class is numeric
   * and returns the value of the splitting criterion.
   * <p>
   * The instances are sorted on the attribute, and every midpoint between
   * two consecutive instances with strictly increasing attribute values is
   * scored with {@link #variance(double[][], double[], double[])}; the
   * smallest score wins. On return {@code m_SplitPoint} holds the winning
   * cut point (if any) and {@code m_Distribution} holds, for each of the
   * three subsets, the weighted mean class value of that subset, or the
   * weighted mean over all instances when the subset has no weight. If no
   * candidate is accepted, {@code m_Distribution} ends up all zeros. If
   * the total weight is not positive the method returns early and leaves
   * {@code m_Distribution} holding the raw sums.
   * <p>
   * NOTE(review): the loop bound presumes {@code m_Instances.sort(index)}
   * places instances with a missing value at the end, and the summation
   * presumes {@code m_Distribution} arrives zeroed with three rows -
   * confirm both.
   *
   * @param index attribute index
   * @return value of criterion for the best split, or
   *         {@code Double.MAX_VALUE} if no candidate scored lower
   * @throws Exception if something goes wrong
   */
  private double findSplitNumericNumeric(int index) throws Exception {

    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE;
    int numMissing = 0;
    double[] sumsSquares = new double[3];
    double[] sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];
    double totalSum = 0;
    double totalSumOfWeights = 0;

    // Start with every non-missing instance in subset 1 and every
    // missing one in subset 2
    for (int n = 0; n < numInstances; n++) {
      Instance current = m_Instances.instance(n);
      boolean missing = current.isMissing(index);
      double y = current.classValue();
      double w = current.weight();
      int subset;
      if (missing) {
        subset = 2;
        numMissing++;
      } else {
        subset = 1;
      }
      m_Distribution[subset][0] += y * w;
      sumsSquares[subset] += y * y * w;
      sumOfWeights[subset] += w;

      totalSumOfWeights += w;
      totalSum += y * w;
    }

    // Nothing to split if the data carries no weight
    if (totalSumOfWeights <= 0) {
      return bestVal;
    }

    // Sort instances
    m_Instances.sort(index);

    // Move instances one at a time from subset 1 to subset 0 and score
    // each position where the attribute value changes
    final int lastCandidate = numInstances - (numMissing + 1);
    for (int n = 0; n < lastCandidate; n++) {
      Instance current = m_Instances.instance(n);
      Instance next = m_Instances.instance(n + 1);

      double y = current.classValue();
      double w = current.weight();

      m_Distribution[0][0] += y * w;
      sumsSquares[0] += y * y * w;
      sumOfWeights[0] += w;

      m_Distribution[1][0] -= y * w;
      sumsSquares[1] -= y * y * w;
      sumOfWeights[1] -= w;

      double lower = current.value(index);
      double upper = next.value(index);
      if (lower < upper) {
        double cutPoint = (lower + upper) / 2.0;
        double score = variance(m_Distribution, sumsSquares, sumOfWeights);
        if (score < bestVal) {
          m_SplitPoint = cutPoint;
          bestVal = score;
          for (int r = 0; r < 3; r++) {
            // Empty subsets fall back to the overall weighted mean
            bestDist[r][0] = (sumOfWeights[r] > 0)
              ? m_Distribution[r][0] / sumOfWeights[r]
              : totalSum / totalSumOfWeights;
          }
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Computes variance for subsets: for every subset with positive weight,
   * adds the sum of squares minus the squared sum divided by the weight.
   * Subsets whose weight is not positive contribute nothing.
   *
   * @param s per-subset sums; only element {@code [i][0]} of each row is read
   * @param sS per-subset sums of squares
   * @param sumOfWeights per-subset sums of weights
   * @return the variance
   */
  private double variance(double[][] s,double[] sS,double[] sumOfWeights) {

    double total = 0;

    for (int i = 0; i < s.length; i++) {
      double weight = sumOfWeights[i];
      if (weight > 0) {
        total += sS[i] - ((s[i][0] * s[i][0]) / weight);
      }
    }

    return total;
  }

  /**
   * Returns the subset an instance falls into: 2 if its value for the
   * split attribute is missing; otherwise 0 if a nominal value equals the
   * split point or a numeric value is less than or equal to it, and 1 in
   * all remaining cases.
   *
   * @param instance the instance to check
   * @return the subset the instance falls into
   * @throws Exception if something goes wrong
   */
  private int whichSubset(Instance instance) throws Exception {

    if (instance.isMissing(m_AttIndex)) {
      return 2;
    }

    boolean inFirstSubset;
    if (instance.attribute(m_AttIndex).isNominal()) {
      inFirstSubset = ((int) instance.value(m_AttIndex) == m_SplitPoint);
    } else {
      inFirstSubset = (instance.value(m_AttIndex) <= m_SplitPoint);
    }
    return inFirstSubset ? 0 : 1;
  }

  /**
   * Main method for testing this class. Evaluates a new
   * {@code DecisionStump} with the given options and prints the result to
   * standard output; if an exception occurs, its message is printed to
   * standard error instead.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      final Classifier scheme = new DecisionStump();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -