// evaluation.java
/**
*
* AgentAcademy - an open source Data Mining framework for
* training intelligent agents
*
* Copyright (C) 2001-2003 AA Consortium.
*
* This library is open source software; you can redistribute it
* and/or modify it under the terms of the GNU Lesser General
* Public License as published by the Free Software Foundation;
* either version 2.0 of the License, or (at your option) any later
* version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*
*/
package org.agentacademy.modules.dataminer.classifiers.evaluation;
/**
* <p>Title: The Data Miner prototype</p>
* <p>Description: A prototype for the DataMiner (DM), the Agent Academy (AA) module responsible for performing data mining on the contents of the Agent Use Repository (AUR). The extracted knowledge is to be sent back to the AUR in the form of a PMML document.</p>
* <p>Copyright: Copyright (c) 2002</p>
* <p>Company: CERTH</p>
* @author asymeon
* @version 0.3
*/
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.util.Enumeration;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.agentacademy.modules.dataminer.core.Drawable;
import org.agentacademy.modules.dataminer.core.Instance;
import org.agentacademy.modules.dataminer.core.Instances;
import org.agentacademy.modules.dataminer.core.Option;
import org.agentacademy.modules.dataminer.core.OptionHandler;
import org.agentacademy.modules.dataminer.core.Range;
import org.agentacademy.modules.dataminer.core.Summarizable;
import org.agentacademy.modules.dataminer.core.Utils;
import org.jdom.Document;
import org.jdom.input.SAXBuilder;
import weka.estimators.Estimator;
import weka.estimators.KernelEstimator;
/**
* Class for evaluating machine learning models. <p>
*
* ------------------------------------------------------------------- <p>
*
* General options when evaluating a learning scheme from the command-line: <p>
*
* -t filename <br>
* Name of the file with the training data. (required) <p>
*
* -T filename <br>
* Name of the file with the test data. If missing a cross-validation
* is performed. <p>
*
* -c index <br>
* Index of the class attribute (1, 2, ...; default: last). <p>
*
* -x number <br>
* The number of folds for the cross-validation (default: 10). <p>
*
* -s seed <br>
* Random number seed for the cross-validation (default: 1). <p>
*
* -m filename <br>
* The name of a file containing a cost matrix. <p>
*
* -l filename <br>
* Loads classifier from the given file. <p>
*
* -d filename <br>
* Saves classifier built from the training data into the given file. <p>
*
* -v <br>
* Outputs no statistics for the training data. <p>
*
* -o <br>
* Outputs statistics only, not the classifier. <p>
*
* -i <br>
* Outputs information-retrieval statistics per class. <p>
*
* -k <br>
* Outputs information-theoretic statistics. <p>
*
* -p range <br>
* Outputs predictions for test instances, along with the attributes in
* the specified range (and nothing else). Use '-p 0' if no attributes are
* desired. <p>
*
* -r <br>
* Outputs cumulative margin distribution (and nothing else). <p>
*
* -g <br>
* Only for classifiers that implement "Graphable." Outputs
* the graph representation of the classifier (and nothing
* else). <p>
*
* ------------------------------------------------------------------- <p>
*
* Example usage as the main method of a classifier (called FunkyClassifier):
* <code> <pre>
* public static void main(String [] args) {
* try {
* Classifier scheme = new FunkyClassifier();
* System.out.println(Evaluation.evaluateModel(scheme, args));
* } catch (Exception e) {
* System.err.println(e.getMessage());
* }
* }
* </pre> </code>
* <p>
*
* ------------------------------------------------------------------ <p>
*
* Example usage from within an application:
* <code> <pre>
* Instances trainInstances = ... instances got from somewhere
* Instances testInstances = ... instances got from somewhere
* Classifier scheme = ... scheme got from somewhere
*
* Evaluation evaluation = new Evaluation(trainInstances);
* evaluation.evaluateModel(scheme, testInstances);
* System.out.println(evaluation.toSummaryString());
* </pre> </code>
*/
public class Evaluation implements Summarizable {
/** The number of classes, taken from the data given to the constructor. */
private int m_NumClasses;
/**
 * The number of folds for a cross-validation. Set to 1 by the constructor
 * and overwritten with the fold count by crossValidateModel.
 */
private int m_NumFolds;
/** The weight of all incorrectly classified instances. */
private double m_Incorrect;
/** The weight of all correctly classified instances. */
private double m_Correct;
/** The weight of all unclassified instances. */
private double m_Unclassified;
/** The weight of all instances that had no class assigned to them. */
private double m_MissingClass;
/** The weight of all instances that had a class assigned to them. */
private double m_WithClass;
/**
 * Array for storing the confusion matrix. Allocated as
 * numClasses x numClasses by the constructor only when the class is
 * nominal; it stays null for a numeric class.
 */
private double [][] m_ConfusionMatrix;
/** The names of the classes (filled in only when the class is nominal). */
private String [] m_ClassNames;
/** Is the class nominal or numeric? */
private boolean m_ClassIsNominal;
/**
 * The prior class distribution, one entry per class.
 * NOTE(review): given m_ClassPriorsSum ("sum of counts"), this presumably
 * holds counts rather than normalised probabilities -- confirm in setPriors.
 */
private double [] m_ClassPriors;
/** The sum of counts for priors */
private double m_ClassPriorsSum;
/** The cost matrix (if given); null when default costs are used. */
private CostMatrix m_CostMatrix;
/** The total cost of predictions (includes instance weights) */
private double m_TotalCost;
/** Sum of errors. */
private double m_SumErr;
/** Sum of absolute errors. */
private double m_SumAbsErr;
/** Sum of squared errors. */
private double m_SumSqrErr;
/** Sum of class values. */
private double m_SumClass;
/** Sum of squared class values. */
private double m_SumSqrClass;
/** Sum of predicted values. */
private double m_SumPredicted;
/** Sum of squared predicted values. */
private double m_SumSqrPredicted;
/** Sum of predicted * class values. */
private double m_SumClassPredicted;
/** Sum of absolute errors of the prior */
private double m_SumPriorAbsErr;
/** Sum of squared errors of the prior */
private double m_SumPriorSqrErr;
/** Total Kononenko &amp; Bratko information */
private double m_SumKBInfo;
/**
 * Resolution of the margin histogram; the constructor allocates
 * m_MarginCounts with k_MarginResolution + 1 buckets.
 */
private static int k_MarginResolution = 500;
/** Cumulative margin distribution (k_MarginResolution + 1 buckets) */
private double m_MarginCounts [];
/** Number of non-missing class training instances seen */
private int m_NumTrainClassVals;
/** Array containing all numeric training class values seen */
private double [] m_TrainClassVals;
/** Array containing all numeric training class weights */
private double [] m_TrainClassWeights;
/** Numeric class error estimator for prior */
private Estimator m_PriorErrorEstimator;
/** Numeric class error estimator for scheme */
private Estimator m_ErrorEstimator;
/**
 * The minimum probability accepted from an estimator to avoid
 * taking log(0) in Sf calculations.
 */
private static final double MIN_SF_PROB = Double.MIN_VALUE;
/** Total entropy of prior predictions */
private double m_SumPriorEntropy;
/** Total entropy of scheme predictions */
private double m_SumSchemeEntropy;
/**
 * Creates an evaluation object for the given training data without a
 * cost matrix, so default costs are used.
 *
 * <p>Equivalent to {@code new Evaluation(data, null)}.
 *
 * @param data set of training instances, used for header information
 * and the prior class distribution
 * @exception Exception if the class is not defined
 */
public Evaluation(Instances data) throws Exception {
  // Delegate to the full constructor; the cast just makes the "no cost
  // matrix" intent explicit at the call site.
  this(data, (CostMatrix) null);
}
/**
* Initializes all the counters for the evaluation and also takes a
* cost matrix as parameter.
*
* @param data set of instances, to get some header information
* @param costMatrix the cost matrix---if null, default costs will be used
* @exception Exception if cost matrix is not compatible with
* data, the class is not defined or the class is numeric
*/
public Evaluation(Instances data, CostMatrix costMatrix)
throws Exception {
m_NumClasses = data.numClasses();
m_NumFolds = 1;
m_ClassIsNominal = data.classAttribute().isNominal();
if (m_ClassIsNominal) {
m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
m_ClassNames = new String [m_NumClasses];
for(int i = 0; i < m_NumClasses; i++) {
m_ClassNames[i] = data.classAttribute().value(i);
}
}
m_CostMatrix = costMatrix;
if (m_CostMatrix != null) {
if (!m_ClassIsNominal) {
throw new Exception("Class has to be nominal if cost matrix " +
"given!");
}
if (m_CostMatrix.size() != m_NumClasses) {
throw new Exception("Cost matrix not compatible with data!");
}
}
m_ClassPriors = new double [m_NumClasses];
setPriors(data);
m_MarginCounts = new double [k_MarginResolution + 1];
}
/**
 * Returns a copy of the confusion matrix.
 *
 * <p>The result is a deep copy: every row is a fresh array, so callers may
 * modify it without affecting this evaluation.
 *
 * <p>The constructor only allocates a confusion matrix when the class is
 * nominal. For a numeric class there is no matrix, and an empty
 * ({@code 0 x 0}) array is returned. Previously this case threw a
 * {@code NullPointerException}.
 *
 * @return a copy of the confusion matrix as a two-dimensional array
 * (numClasses x numClasses), or an empty array if the class is numeric
 */
public double[][] confusionMatrix() {
  if (m_ConfusionMatrix == null) {
    // Numeric class: nothing was allocated, so there is nothing to copy.
    return new double[0][0];
  }
  double[][] newMatrix = new double[m_ConfusionMatrix.length][];
  for (int i = 0; i < m_ConfusionMatrix.length; i++) {
    double[] row = m_ConfusionMatrix[i];
    newMatrix[i] = new double[row.length];
    System.arraycopy(row, 0, newMatrix[i], 0, row.length);
  }
  return newMatrix;
}
/**
 * Cross-validates a classifier on a set of instances. The folds are
 * stratified when the class is nominal.
 *
 * <p>The instances are copied first, so the caller's {@code data} is not
 * reordered. For each fold, {@code setPriors} is called with that fold's
 * training split. The classifier is then rebuilt on that split, and
 * {@code evaluateModel} is run on the held-out split. Only after all folds
 * is the recorded fold count updated.
 *
 * <p>NOTE(review): {@code numFolds} is not validated here; values below 2
 * are presumably rejected by {@code Instances.stratify}/{@code trainCV}
 * -- TODO confirm.
 *
 * @param classifier the classifier with any options set.
 * @param data the data on which the cross-validation is to be
 * performed
 * @param numFolds the number of folds for the cross-validation
 * @exception Exception if a classifier could not be generated
 * successfully or the class is not defined
 */
public void crossValidateModel(Classifier classifier,
                               Instances data, int numFolds)
  throws Exception {
  // Work on a private copy that can be reordered freely.
  Instances shuffled = new Instances(data);
  if (shuffled.classAttribute().isNominal()) {
    shuffled.stratify(numFolds);
  }

  for (int fold = 0; fold < numFolds; fold++) {
    Instances trainSplit = shuffled.trainCV(numFolds, fold);
    setPriors(trainSplit);
    classifier.buildClassifier(trainSplit);
    evaluateModel(classifier, shuffled.testCV(numFolds, fold));
  }

  m_NumFolds = numFolds;
}
/**
* Performs a (stratified if class is nominal) cross-validation
* for a classifier on a set of instances.
*
* @param classifier a string naming the class of the classifier
* @param data the data on which the cross-validation is to be
* performed
* @param numFolds the number of folds for the cross-validation
* @param options the options to the classifier. Any options
* accepted by the classifier will be removed from this array.
* @exception Exception if a classifier could not be generated
* successfully or the class is not defined
*/
public void crossValidateModel(String classifierString,
Instances data, int numFolds,
String[] options)
?? 快捷鍵說(shuō)明
復(fù)制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號(hào)
Ctrl + =
減小字號(hào)
Ctrl + -