// Listing: multilayerperceptron.java
// (web-viewer control label: "Font size:")
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* MultilayerPerceptron.java
* Copyright (C) 2000 Malcolm Ware
*/
package weka.classifiers.functions;
import java.util.*;
import java.awt.*;
import java.awt.event.*;
import javax.swing.*;
import weka.classifiers.functions.neural.*;
import weka.classifiers.*;
import weka.core.*;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.Filter;
/**
* A Classifier that uses backpropagation to classify instances.
* This network can be built by hand, created by an algorithm or both.
* The network can also be monitored and modified during training time.
* The nodes in this network are all sigmoid (except when the class
* is numeric, in which case the output nodes become unthresholded linear
* units).
*
* @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
* @version $Revision: 1.1 $
*/
public class MultilayerPerceptron extends Classifier
implements OptionHandler, WeightedInstancesHandler {
/**
 * Main method for testing this class.
 * Evaluates a MultilayerPerceptron with the given command-line options
 * and prints the evaluation output to standard out.
 *
 * @param argv should contain command line options (see setOptions)
 */
public static void main(String [] argv) {
  int status = 0;
  try {
    System.out.println(Evaluation.evaluateModel(new MultilayerPerceptron(), argv));
  } catch (Exception e) {
    System.err.println(e.getMessage());
    e.printStackTrace();
    // Report the failure to the calling shell/script instead of exiting 0.
    status = 1;
  }
  // Exit explicitly. NOTE(review): presumably this is also what stops any
  // Swing GUI threads this class may have started — confirm.
  System.exit(status);
}
/**
 * This inner class is used to connect the nodes in the network up to
 * the data that they are classifying. Note that objects of this class are
 * only suitable to go on the attribute side or class side of the network
 * and not both.
 */
protected class NeuralEnd extends NeuralConnection {

  /**
   * The index of the instance value this node represents.
   * For an input it is the attribute number; for an output with a nominal
   * class it is the index of the class value.
   */
  private int m_link;

  /** True if the node is an input, false if it is an output. */
  private boolean m_input;

  /**
   * Creates an end unit that starts out as an input linked to attribute 0.
   * Use {@link #setLink(boolean, int)} to configure what it represents.
   *
   * @param id The identifier of this unit (handed to NeuralConnection).
   */
  public NeuralEnd(String id) {
    super(id);
    m_link = 0;
    m_input = true;
  }

  /**
   * Returns the x coordinate of the left edge of this unit's label box,
   * which is centred horizontally on the unit's relative x position.
   *
   * @param fm The font metrics used to measure the id label.
   * @param w The width of the drawing area.
   * @return The left edge, in pixels.
   */
  private int endLabelLeft(FontMetrics fm, int w) {
    return (int)(m_x * w) - fm.stringWidth(m_id) / 2;
  }

  /**
   * Returns the y coordinate of the top edge of this unit's label box,
   * which is centred vertically on the unit's relative y position.
   *
   * @param fm The font metrics used to measure the id label.
   * @param h The height of the drawing area.
   * @return The top edge, in pixels.
   */
  private int endLabelTop(FontMetrics fm, int h) {
    return (int)(m_y * h) - fm.getHeight() / 2;
  }

  /**
   * @param fm The font metrics used to measure the id label.
   * @return The width of the label box: the id text plus 4 pixels padding.
   */
  private int endLabelWidth(FontMetrics fm) {
    return fm.stringWidth(m_id) + 4;
  }

  /**
   * @param fm The font metrics used to measure the id label.
   * @return The height of the label box: line height, descent and
   * 4 pixels padding.
   */
  private int endLabelHeight(FontMetrics fm) {
    return fm.getHeight() + fm.getDescent() + 4;
  }

  /**
   * Call this function to determine if the point at x,y is on the unit,
   * i.e. inside (or on the border of) the label box drawn by drawNode.
   *
   * @param g The graphics context for font size info.
   * @param x The x coord.
   * @param y The y coord.
   * @param w The width of the display.
   * @param h The height of the display.
   * @return True if the point is on the unit, false otherwise.
   */
  public boolean onUnit(Graphics g, int x, int y, int w, int h) {
    FontMetrics fm = g.getFontMetrics();
    int l = endLabelLeft(fm, w);
    int t = endLabelTop(fm, h);
    return x >= l && x <= l + endLabelWidth(fm)
        && y >= t && y <= t + endLabelHeight(fm);
  }

  /**
   * Draws the node id inside a raised box: green for pure input units,
   * orange otherwise.
   *
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  public void drawNode(Graphics g, int w, int h) {
    if ((m_type & PURE_INPUT) == PURE_INPUT) {
      g.setColor(Color.green);
    }
    else {
      g.setColor(Color.orange);
    }
    FontMetrics fm = g.getFontMetrics();
    int l = endLabelLeft(fm, w);
    int t = endLabelTop(fm, h);
    g.fill3DRect(l, t, endLabelWidth(fm), endLabelHeight(fm), true);
    g.setColor(Color.black);
    g.drawString(m_id, l + 2, t + fm.getHeight() + 2);
  }

  /**
   * Draws the node highlighted: a black frame 2 pixels wider than the
   * label box on every side, with the normal node drawn on top.
   *
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  public void drawHighlight(Graphics g, int w, int h) {
    g.setColor(Color.black);
    FontMetrics fm = g.getFontMetrics();
    int l = endLabelLeft(fm, w);
    int t = endLabelTop(fm, h);
    g.fillRect(l - 2, t - 2, endLabelWidth(fm) + 4, endLabelHeight(fm) + 4);
    drawNode(g, w, h);
  }

  /**
   * Call this to get the output value of this unit. The value is cached
   * until {@link #reset()} is called.
   * For an input unit it is the linked attribute value of the current
   * instance (0 if that value is missing). For an output unit it is the
   * sum of the outputs of its input units, rescaled by the class range and
   * base when the class is numeric and class normalization is enabled.
   *
   * @param calculate True if the value should be calculated if it hasn't
   * been already.
   * @return The output value, or NaN, if the value has not been calculated.
   */
  public double outputValue(boolean calculate) {
    if (Double.isNaN(m_unitValue) && calculate) {
      if (m_input) {
        if (m_currentInstance.isMissing(m_link)) {
          m_unitValue = 0;
        }
        else {
          m_unitValue = m_currentInstance.value(m_link);
        }
      }
      else {
        //node is an output.
        m_unitValue = 0;
        for (int noa = 0; noa < m_numInputs; noa++) {
          m_unitValue += m_inputList[noa].outputValue(true);
        }
        if (m_numeric && m_normalizeClass) {
          //then scale the value;
          //this scales linearly from between -1 and 1
          m_unitValue = m_unitValue *
            m_attributeRanges[m_instances.classIndex()] +
            m_attributeBases[m_instances.classIndex()];
        }
      }
    }
    return m_unitValue;
  }

  /**
   * Call this to get the error value of this unit. For an output unit this
   * is the difference between the actual class and the predicted value. For
   * an input unit it is the sum of the errors of the units it feeds. The
   * error is only computed once the output value is known, and it is cached
   * until {@link #reset()} is called.
   *
   * @param calculate True if the value should be calculated if it hasn't
   * been already.
   * @return The error value, or NaN, if the value has not been calculated.
   */
  public double errorValue(boolean calculate) {
    if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError)
        && calculate) {
      if (m_input) {
        m_unitError = 0;
        for (int noa = 0; noa < m_numOutputs; noa++) {
          m_unitError += m_outputList[noa].errorValue(true);
        }
      }
      else {
        if (m_currentInstance.classIsMissing()) {
          // NOTE(review): a fixed error of .1 is used for a missing class;
          // the rationale is not visible here — confirm.
          m_unitError = .1;
        }
        else if (m_instances.classAttribute().isNominal()) {
          // Target is 1 for the output linked to the actual class, 0 otherwise.
          if (m_currentInstance.classValue() == m_link) {
            m_unitError = 1 - m_unitValue;
          }
          else {
            m_unitError = 0 - m_unitValue;
          }
        }
        else if (m_numeric) {
          if (m_normalizeClass) {
            // Divide by the class range; a zero range means no usable error.
            if (m_attributeRanges[m_instances.classIndex()] == 0) {
              m_unitError = 0;
            }
            else {
              m_unitError = (m_currentInstance.classValue() - m_unitValue ) /
                m_attributeRanges[m_instances.classIndex()];
            }
          }
          else {
            m_unitError = m_currentInstance.classValue() - m_unitValue;
          }
        }
      }
    }
    return m_unitError;
  }

  /**
   * Call this to reset the value and error for this unit, ready for the next
   * run. It also clears the weights-updated flag and calls reset on every
   * unit connected as an input to this one. Nothing happens if the value
   * and error are both already NaN, which stops the recursion at units that
   * have already been reset.
   * NOTE(review): an earlier comment said listener updates happen here, but
   * this body performs none.
   */
  public void reset() {
    if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
      m_unitValue = Double.NaN;
      m_unitError = Double.NaN;
      m_weightsUpdated = false;
      for (int noa = 0; noa < m_numInputs; noa++) {
        m_inputList[noa].reset();
      }
    }
  }

  /**
   * Call this function to set what this end unit represents. An index that
   * is out of range falls back to 0.
   *
   * @param input True if this unit is used for entering an attribute,
   * false if it's used for determining a class value.
   * @param val The attribute index (for inputs) or the class value index
   * (for outputs with a nominal class) that this unit represents.
   * @throws Exception declared by the signature; not thrown explicitly here.
   */
  public void setLink(boolean input, int val) throws Exception {
    m_input = input;
    if (input) {
      m_type = PURE_INPUT;
    }
    else {
      m_type = PURE_OUTPUT;
    }
    // Indices are zero-based, so numAttributes() / numValues() themselves
    // are already out of range and must be rejected too.
    if (val < 0 || (input && val >= m_instances.numAttributes())
        || (!input && m_instances.classAttribute().isNominal()
            && val >= m_instances.classAttribute().numValues())) {
      m_link = 0;
    }
    else {
      m_link = val;
    }
  }

  /**
   * @return link for this node (attribute index or class value index).
   */
  public int getLink() {
    return m_link;
  }
}
/** Inner class on which the nodes are drawn (it uses the node lists).
* It also handles the user input. */
private class NodePanel extends JPanel {
/**
 * Creates the panel and registers the mouse listener through which the
 * user edits the network. Presses are ignored while {@code m_stopped} is
 * false.
 * <ul>
 * <li>A button-1 press without Alt on a unit selects that unit. On empty
 * space it creates a new NeuralNode at the press position, adds it to the
 * network and selects it.</li>
 * <li>Any other press on a unit selects it with {@code left == false}.
 * On empty space it clears the selection.</li>
 * </ul>
 * Whether Ctrl was held is passed on to {@code selection}.
 */
public NodePanel() {
  addMouseListener(new MouseAdapter() {
    public void mousePressed(MouseEvent e) {
      if (!m_stopped) {
        return;
      }
      boolean left =
        (e.getModifiers() & InputEvent.BUTTON1_MASK) == InputEvent.BUTTON1_MASK
        && !e.isAltDown();
      boolean ctrl =
        (e.getModifiers() & InputEvent.CTRL_MASK) == InputEvent.CTRL_MASK;
      int x = e.getX();
      int y = e.getY();
      int w = NodePanel.this.getWidth();
      int h = NodePanel.this.getHeight();

      // getGraphics() creates a new context on every call; it is only
      // needed for font metrics during the hit test, so release it after.
      Graphics g = NodePanel.this.getGraphics();
      NeuralConnection hit;
      try {
        hit = hitTestUnit(g, x, y, w, h);
      }
      finally {
        if (g != null) {
          g.dispose();
        }
      }

      FastVector tmp = new FastVector(4);
      if (hit != null) {
        tmp.addElement(hit);
        selection(tmp, ctrl, left);
      }
      else if (left) {
        // Left press on empty space: add a new node where the user clicked
        // (coordinates stored relative to the panel size).
        NeuralNode temp = new NeuralNode(String.valueOf(m_nextId),
                                         m_random, m_sigmoidUnit);
        m_nextId++;
        temp.setX((double)x / w);
        temp.setY((double)y / h);
        tmp.addElement(temp);
        addNode(temp);
        selection(tmp, ctrl, true);
      }
      else {
        // Other press on empty space: clear the current selection.
        selection(null, ctrl, false);
      }
    }
  });
}

/**
 * Returns the first unit for which onUnit reports a hit at (x, y).
 * Input units are checked first, then output units, then the neural nodes.
 *
 * @param g The graphics context used for font size info.
 * @param x The x coordinate of the point.
 * @param y The y coordinate of the point.
 * @param w The width of the panel.
 * @param h The height of the panel.
 * @return The unit under the point, or null if there is none.
 */
private NeuralConnection hitTestUnit(Graphics g, int x, int y, int w, int h) {
  for (int noa = 0; noa < m_numAttributes; noa++) {
    if (m_inputs[noa].onUnit(g, x, y, w, h)) {
      return m_inputs[noa];
    }
  }
  for (int noa = 0; noa < m_numClasses; noa++) {
    if (m_outputs[noa].onUnit(g, x, y, w, h)) {
      return m_outputs[noa];
    }
  }
  for (int noa = 0; noa < m_neuralNodes.length; noa++) {
    if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
      return m_neuralNodes[noa];
    }
  }
  return null;
}
/**
* This function gets called when the user has clicked something
* It will amend the current selection or connect the current selection
* to the new selection.
* Or if nothing was selected and the right button was used it will
* delete the node.
* @param v The units that were selected.
* @param ctrl True if ctrl was held down.
* @param left True if it was the left mouse button.
*/
private void selection(FastVector v, boolean ctrl, boolean left) {
if (v == null) {
//then unselect all.
m_selected.removeAllElements();
repaint();
return;
}
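// [web-viewer residue] Keyboard shortcuts:
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full screen:        F11
//   Switch theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -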