// evaluation.java
// (Web code-viewer residue removed here: a "font size" control label.)
// Classifier was trained incrementally, so we have to
// reopen the training data in order to test on it.
// Incremental testing
train = new Instances (data,1);
if (classIndex != -1) {
train.setClassIndex(classIndex - 1);
}
else {
train.setClassIndex(train.numAttributes() - 1);
}
testTimeStart = System.currentTimeMillis();
for (int k=0; k<train.numInstances(); k++) {
trainingEvaluation.
evaluateModelOnce((Classifier)classifier,train.instance(0));
train.delete(0);
}
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
}
else {
testTimeStart = System.currentTimeMillis();
trainingEvaluation.evaluateModel(classifier,train);
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
}
// Print the results of the training evaluation
if (printMargins) {
return trainingEvaluation.toCumulativeMarginDistributionString();
}
else {
text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0,2) +" seconds");
text.append("\nTime taken to test model on training data: " + Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds");
text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics));
if (template.classAttribute().isNominal()) {
if (classStatistics) {
text.append("\n\n" + trainingEvaluation.toClassDetailsString());
}
text.append("\n\n" + trainingEvaluation.toMatrixString());
}
}
}
// Compute proper error estimates
if (testFileName.length() != 0) {
// Testing is on the supplied test data
test = new Instances (testDocument);
for (int k=0; k < test.numInstances(); k++) {
testingEvaluation.evaluateModelOnce((Classifier)classifier, test.instance(0));
test.delete(0);
}
text.append("\n\n" + testingEvaluation.
toSummaryString("=== Error on test data --- The new version!!!! ===\n",
printComplexityStatistics));
}
else {
// Testing is via cross-validation on training data
random = new Random(seed);
random.setSeed(seed);
train.randomize(random);
testingEvaluation.
crossValidateModel(classifier, train, folds);
if (template.classAttribute().isNumeric()) {
text.append("\n\n\n" + testingEvaluation.
toSummaryString("=== Cross-validation ===\n",
printComplexityStatistics));
}
else {
text.append("\n\n\n" + testingEvaluation.
toSummaryString("=== Stratified " +
"cross-validation ===\n",
printComplexityStatistics));
}
}
if (template.classAttribute().isNominal()) {
if (classStatistics) {
text.append("\n\n" + testingEvaluation.toClassDetailsString());
}
text.append("\n\n" + testingEvaluation.toMatrixString());
}
return text.toString();
}
/**
 * Evaluates a classifier on a percentage split of a single dataset.
 *
 * The first {@code percentage} percent of {@code data} (in its current
 * order; no shuffling is done here) is used for training and the remainder
 * for testing. The returned text contains, depending on the options, the
 * model description, the cost matrix, the error on the training data and
 * the error on the held-out test data.
 *
 * General options read from {@code options}:
 * -c (1-based class index, default: last attribute), -l (model file to
 * load), -d (model file to write), -m (cost option, passed to
 * handleCostOption), -i (per-class statistics), -o (suppress model output),
 * -v (skip training-set statistics), -k (complexity statistics),
 * -r (return only the cumulative margin distribution), -g (return the graph
 * of a Drawable classifier), -z (return the classifier as source, using the
 * given class name), -p (attribute range; parsed so that it is consumed, but
 * prediction printing is not wired up in this method).
 * Model file names ending in ".gz" are read/written gzip-compressed.
 * Any remaining options are handed to the classifier if it is an
 * OptionHandler and no model file is loaded.
 *
 * @param classifier the classifier to train (or to replace by a loaded model)
 * @param options the command-line style options described above
 * @param data the full dataset that is split into training and test parts
 * @param percentage share of {@code data} (0-100) used for training
 * @return the textual evaluation report (or graph / source / margin string
 *         when the corresponding option is given)
 * @throws Exception if an option is invalid, a model file cannot be
 *         read or written, or training/evaluation fails
 */
public static String evaluateModelInstancePercentage(Classifier classifier,
    String[] options, Instances data, int percentage) throws Exception {

  Instances train, test;
  int classIndex = -1;
  String sourceClass, classIndexString, objectInputFileName,
    objectOutputFileName, attributeRangeString;
  boolean noOutput, trainStatistics, printMargins,
    printComplexityStatistics, printGraph, classStatistics, printSource;
  StringBuffer text = new StringBuffer();
  ObjectInputStream objectInputStream = null;
  CostMatrix costMatrix;
  StringBuffer schemeOptionsText = null;
  long trainTimeStart = 0, trainTimeElapsed = 0,
    testTimeStart = 0, testTimeElapsed = 0;

  try {
    if (percentage < 0 || percentage > 100) {
      throw new Exception("Percentage must be between 0 and 100, got "
                          + percentage + ".");
    }

    // Get basic options (options the same for all schemes)
    classIndexString = Utils.getOption('c', options);
    if (classIndexString.length() != 0) {
      classIndex = Integer.parseInt(classIndexString);
    }
    objectInputFileName = Utils.getOption('l', options);
    objectOutputFileName = Utils.getOption('d', options);

    // Open the model file early so a bad file name is reported together
    // with the other option errors.
    if (objectInputFileName.length() != 0) {
      InputStream is = null;
      try {
        is = new FileInputStream(objectInputFileName);
        if (objectInputFileName.endsWith(".gz")) {
          is = new GZIPInputStream(is);
        }
        objectInputStream = new ObjectInputStream(is);
      } catch (Exception e) {
        if (is != null) {
          try {
            is.close();
          } catch (Exception ignored) {
            // best effort; the original failure is the one worth reporting
          }
        }
        throw new Exception("Can't open file " + e.getMessage() + '.', e);
      }
    }

    // Split the data: leading part for training, the rest for testing.
    // long arithmetic avoids int overflow for very large datasets.
    int sizeOfTrainFile = (int) ((long) data.numInstances() * percentage / 100);
    int sizeOfTestFile = data.numInstances() - sizeOfTrainFile;
    train = new Instances(data, 0, sizeOfTrainFile);
    test = new Instances(data, sizeOfTrainFile, sizeOfTestFile);

    // Validate before setting, so the caller gets this message rather than
    // whatever setClassIndex reports for an out-of-range index.
    if (classIndex > train.numAttributes()) {
      throw new Exception("Index of class attribute too large.");
    }
    int resolvedClassIndex = (classIndex != -1)
      ? classIndex - 1
      : train.numAttributes() - 1;
    train.setClassIndex(resolvedClassIndex);
    // The test split must use the same class attribute as the training split.
    test.setClassIndex(resolvedClassIndex);

    costMatrix = handleCostOption(Utils.getOption('m', options),
                                  train.numClasses());

    classStatistics = Utils.getFlag('i', options);
    noOutput = Utils.getFlag('o', options);
    trainStatistics = !Utils.getFlag('v', options);
    printComplexityStatistics = Utils.getFlag('k', options);
    printMargins = Utils.getFlag('r', options);
    printGraph = Utils.getFlag('g', options);
    sourceClass = Utils.getOption('z', options);
    printSource = (sourceClass.length() != 0);

    // Check -p option. It has to be consumed here, otherwise it would be
    // passed on to the classifier / flagged as a remaining option.
    try {
      attributeRangeString = Utils.getOption('p', options);
    } catch (Exception e) {
      throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
                          "It now expects a parameter specifying a range of attributes " +
                          "to list with the predictions. Use '-p 0' for none.");
    }
    if (attributeRangeString.length() != 0
        && !attributeRangeString.equals("0")) {
      // The range itself is not used by this method; it is still constructed
      // so that a malformed range is handled exactly as before.
      new Range(attributeRangeString);
    }

    // If a model file is given, we can't process scheme-specific options
    if (objectInputFileName.length() == 0
        && classifier instanceof OptionHandler) {
      // Remember the scheme options for the report before they are consumed
      for (int i = 0; i < options.length; i++) {
        if (options[i].length() != 0) {
          if (schemeOptionsText == null) {
            schemeOptionsText = new StringBuffer();
          }
          if (options[i].indexOf(' ') != -1) {
            schemeOptionsText.append('"' + options[i] + "\" ");
          } else {
            schemeOptionsText.append(options[i] + " ");
          }
        }
      }
      ((OptionHandler) classifier).setOptions(options);
    }
    Utils.checkForRemainingOptions(options);
  } catch (Exception e) {
    if (objectInputStream != null) {
      try {
        objectInputStream.close();
      } catch (Exception ignored) {
        // best effort; the option error is the one worth reporting
      }
    }
    throw new Exception("\nWeka exception: " + e.getMessage()
                        + makeOptionString(classifier), e);
  }

  // Set up evaluation objects and, if requested, load the classifier.
  Evaluation trainingEvaluation;
  Evaluation testingEvaluation;
  try {
    trainingEvaluation = new Evaluation(new Instances(train, 0), costMatrix);
    testingEvaluation = new Evaluation(new Instances(train, 0), costMatrix);
    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      classifier = (Classifier) objectInputStream.readObject();
    }
  } finally {
    if (objectInputStream != null) {
      objectInputStream.close();
    }
  }

  // Build the classifier if no object file provided
  if ((classifier instanceof UpdateableClassifier) && (costMatrix == null)) {
    // Build classifier incrementally: initialise priors and model from the
    // empty header, then feed every training instance exactly once.
    Instances header = new Instances(train, 0);
    trainingEvaluation.setPriors(header);
    testingEvaluation.setPriors(header);
    trainTimeStart = System.currentTimeMillis();
    if (objectInputFileName.length() == 0) {
      classifier.buildClassifier(header);
    }
    // Iterate by index; deleting while counting up would skip instances.
    for (int i = 0; i < train.numInstances(); i++) {
      trainingEvaluation.updatePriors(train.instance(i));
      testingEvaluation.updatePriors(train.instance(i));
      ((UpdateableClassifier) classifier).updateClassifier(train.instance(i));
    }
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
  } else if (objectInputFileName.length() == 0) {
    // Build classifier in one go, on a copy so that train stays untouched
    Instances tempTrain = new Instances(train);
    trainingEvaluation.setPriors(tempTrain);
    testingEvaluation.setPriors(tempTrain);
    trainTimeStart = System.currentTimeMillis();
    classifier.buildClassifier(tempTrain);
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
  }

  // Save the classifier if an object output file is provided
  if (objectOutputFileName.length() != 0) {
    OutputStream os = new FileOutputStream(objectOutputFileName);
    ObjectOutputStream objectOutputStream = null;
    try {
      if (objectOutputFileName.endsWith(".gz")) {
        os = new GZIPOutputStream(os);
      }
      objectOutputStream = new ObjectOutputStream(os);
      objectOutputStream.writeObject(classifier);
      objectOutputStream.flush();
    } finally {
      // Closing the outermost stream that was created closes the whole chain
      if (objectOutputStream != null) {
        objectOutputStream.close();
      } else {
        os.close();
      }
    }
  }

  // If classifier is drawable output string describing graph
  if ((classifier instanceof Drawable) && printGraph) {
    return ((Drawable) classifier).graph();
  }

  // Output the classifier as equivalent source
  if ((classifier instanceof Sourcable) && printSource) {
    return wekaStaticWrapper((Sourcable) classifier, sourceClass);
  }

  // Output model
  if (!(noOutput || printMargins)) {
    if (classifier instanceof OptionHandler) {
      if (schemeOptionsText != null) {
        text.append("\nOptions: " + schemeOptionsText);
        text.append("\n");
      }
    }
    text.append("\n" + classifier.toString() + "\n");
  }

  if (!printMargins && (costMatrix != null)) {
    text.append("\n=== Evaluation Cost Matrix ===\n\n")
      .append(costMatrix.toString());
  }

  // Compute error estimate from training data
  if (trainStatistics) {
    testTimeStart = System.currentTimeMillis();
    if ((classifier instanceof UpdateableClassifier) && (costMatrix == null)) {
      // Incremental testing: one instance at a time, all of them
      for (int k = 0; k < train.numInstances(); k++) {
        trainingEvaluation.evaluateModelOnce(classifier, train.instance(k));
      }
    } else {
      trainingEvaluation.evaluateModel(classifier, train);
    }
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;

    // Print the results of the training evaluation
    if (printMargins) {
      return trainingEvaluation.toCumulativeMarginDistributionString();
    }
    text.append("\nTime taken to build model: " +
                Utils.doubleToString(trainTimeElapsed / 1000.0, 2) +
                " seconds");
    text.append("\nTime taken to test model on training data: " +
                Utils.doubleToString(testTimeElapsed / 1000.0, 2) +
                " seconds");
    text.append(trainingEvaluation.
                toSummaryString("\n\n=== Error on training" +
                                " data ===\n", printComplexityStatistics));
    if (train.classAttribute().isNominal()) {
      if (classStatistics) {
        text.append("\n\n" + trainingEvaluation.toClassDetailsString());
      }
      text.append("\n\n" + trainingEvaluation.toMatrixString());
    }
  }

  // Compute proper error estimates on the held-out part of the data.
  // Every test instance is evaluated (index loop, no deletion).
  for (int j = 0; j < test.numInstances(); j++) {
    testingEvaluation.evaluateModelOnce(classifier, test.instance(j));
  }
  text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n",
                                                          printComplexityStatistics));
  if (train.classAttribute().isNominal()) {
    if (classStatistics) {
      text.append("\n\n" + testingEvaluation.toClassDetailsString());
    }
    text.append("\n\n" + testingEvaluation.toMatrixString());
  }
  return text.toString();
}
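// ---------------------------------------------------------------------------
// Web code-viewer residue (not part of the program): keyboard shortcut help.
//   Copy code             Ctrl + C
//   Search code           Ctrl + F
//   Full-screen mode      F11
//   Switch theme          Ctrl + Shift + D
//   Show shortcuts        ?
//   Increase font size    Ctrl + =
//   Decrease font size    Ctrl + -
// ---------------------------------------------------------------------------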