?? multilayerperceptron.java
字號:
///////////////////////////// //this sets up the gui for usage if (m_gui) { m_win = new JFrame(); m_win.addWindowListener(new WindowAdapter() { public void windowClosing(WindowEvent e) { boolean k = m_stopIt; m_stopIt = true; int well =JOptionPane.showConfirmDialog(m_win, "Are You Sure...\n" + "Click Yes To Accept" + " The Neural Network" + "\n Click No To Return", "Accept Neural Network", JOptionPane.YES_NO_OPTION); if (well == 0) { m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); m_accepted = true; blocker(false); } else { m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE); } m_stopIt = k; } }); m_win.getContentPane().setLayout(new BorderLayout()); m_win.setTitle("Neural Network"); m_nodePanel = new NodePanel(); // without the following two lines, the NodePanel.paintComponents(Graphics) // method will go berserk if the network doesn't fit completely: it will // get called on a constant basis, using 100% of the CPU // see the following forum thread: // http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011 m_nodePanel.setPreferredSize(new Dimension(640, 480)); m_nodePanel.revalidate(); JScrollPane sp = new JScrollPane(m_nodePanel, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); m_controlPanel = new ControlPanel(); m_win.getContentPane().add(sp, BorderLayout.CENTER); m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH); m_win.setSize(640, 480); m_win.setVisible(true); } //This sets up the initial state of the gui if (m_gui) { blocker(true); m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); } //For silly situations in which the network gets accepted before training //commenses if (m_numeric) { setEndsToLinear(); } if (m_accepted) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; m_instances = new Instances(m_instances, 0); return; } //connections done. double right = 0; double driftOff = 0; double lastRight = Double.POSITIVE_INFINITY; double tempRate; double totalWeight = 0; double totalValWeight = 0; double origRate = m_learningRate; //only used for when reset //ensure that at least 1 instance is trained through. if (numInVal == m_instances.numInstances()) { numInVal--; } if (numInVal < 0) { numInVal = 0; } for (int noa = numInVal; noa < m_instances.numInstances(); noa++) { if (!m_instances.instance(noa).classIsMissing()) { totalWeight += m_instances.instance(noa).weight(); } } if (m_valSize != 0) { for (int noa = 0; noa < valSet.numInstances(); noa++) { if (!valSet.instance(noa).classIsMissing()) { totalValWeight += valSet.instance(noa).weight(); } } } m_stopped = false; for (int noa = 1; noa < m_numEpochs + 1; noa++) { right = 0; for (int nob = numInVal; nob < m_instances.numInstances(); nob++) { m_currentInstance = m_instances.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating (and training occurs, for the //training set resetNetwork(); calculateOutputs(); tempRate = m_learningRate * m_currentInstance.weight(); if (m_decay) { tempRate /= noa; } right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight(); updateNetworkWeights(tempRate, m_momentum); } } right /= totalWeight; if (Double.isInfinite(right) || Double.isNaN(right)) { if (!m_reset) { m_instances = null; throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate."); } else { //reset the network if possible if (m_learningRate <= Utils.SMALL) throw new IllegalStateException( "Learning rate got too small (" + m_learningRate + " <= " + Utils.SMALL + ")!"); m_learningRate /= 2; buildClassifier(i); m_learningRate = origRate; m_instances = new Instances(m_instances, 0); return; } } ////////////////////////do validation testing if applicable if (m_valSize != 0) { right = 0; for (int nob = 0; nob < valSet.numInstances(); nob++) { m_currentInstance = valSet.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating occurs, for the validation set resetNetwork(); calculateOutputs(); right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight(); //note 'right' could be calculated here just using //the calculate output values. This would be faster. //be less modular } } if (right < lastRight) { driftOff = 0; } else { driftOff++; } lastRight = right; if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) { m_accepted = true; } right /= totalValWeight; } m_epoch = noa; m_error = right; //shows what the neuralnet is upto if a gui exists. updateDisplay(); //This junction controls what state the gui is in at the end of each //epoch, Such as if it is paused, if it is resumable etc... if (m_gui) { while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && !m_accepted) { m_stopIt = true; m_stopped = true; if (m_epoch >= m_numEpochs && m_valSize == 0) { m_controlPanel.m_startStop.setEnabled(false); } else { m_controlPanel.m_startStop.setEnabled(true); } m_controlPanel.m_startStop.setText("Start"); m_controlPanel.m_startStop.setActionCommand("Start"); m_controlPanel.m_changeEpochs.setEnabled(true); m_controlPanel.m_changeLearning.setEnabled(true); m_controlPanel.m_changeMomentum.setEnabled(true); blocker(true); if (m_numeric) { setEndsToLinear(); } } m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); m_stopped = false; //if the network has been accepted stop the training loop if (m_accepted) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; m_instances = new Instances(m_instances, 0); return; } } if (m_accepted) { m_instances = new Instances(m_instances, 0); return; } } if (m_gui) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; } m_instances = new Instances(m_instances, 0); } /** * Call this function to predict the class of an instance once a * classification model has been built with the buildClassifier call. * @param i The instance to classify. * @return A double array filled with the probabilities of each class type. * @throws Exception if can't classify instance. */ public double[] distributionForInstance(Instance i) throws Exception { // default model? if (m_ZeroR != null) { return m_ZeroR.distributionForInstance(i); } if (m_useNomToBin) { m_nominalToBinaryFilter.input(i); m_currentInstance = m_nominalToBinaryFilter.output(); } else { m_currentInstance = i; } if (m_normalizeAttributes) { for (int noa = 0; noa < m_instances.numAttributes(); noa++) { if (noa != m_instances.classIndex()) { if (m_attributeRanges[noa] != 0) { m_currentInstance.setValue(noa, (m_currentInstance.value(noa) - m_attributeBases[noa]) / m_attributeRanges[noa]); } else { m_currentInstance.setValue(noa, m_currentInstance.value(noa) - m_attributeBases[noa]); } } } } resetNetwork(); //since all the output values are needed. //They are calculated manually here and the values collected. double[] theArray = new double[m_numClasses]; for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] = m_outputs[noa].outputValue(true); } if (m_instances.classAttribute().isNumeric()) { return theArray; } //now normalize the array double count = 0; for (int noa = 0; noa < m_numClasses; noa++) { count += theArray[noa]; } if (count <= 0) { return null; } for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] /= count; } return theArray; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(14); newVector.addElement(new Option( "\tLearning Rate for the backpropagation algorithm.\n" +"\t(Value should be between 0 - 1, Default = 0.3).", "L", 1, "-L <learning rate>")); newVector.addElement(new Option( "\tMomentum Rate for the backpropagation algorithm.\n" +"\t(Value should be between 0 - 1, Default = 0.2).", "M", 1, "-M <momentum>")); newVector.addElement(new Option( "\tNumber of epochs to train through.\n" +"\t(Default = 500).", "N", 1,"-N <number of epochs>")); newVector.addElement(new Option( "\tPercentage size of validation set to use to terminate\n" + "\ttraining (if this is non zero it can pre-empt num of epochs.\n" +"\t(Value should be between 0 - 100, Default = 0).", "V", 1, "-V <percentage size of validation set>")); newVector.addElement(new Option( "\tThe value used to seed the random number generator\n" + "\t(Value should be >= 0 and and a long, Default = 0).", "S", 1, "-S <seed>")); newVector.addElement(new Option( "\tThe consequetive number of errors allowed for validation\n" + "\ttesting before the netwrok terminates.\n" + "\t(Value should be > 0, Default = 20).", "E", 1, "-E <threshold for number of consequetive errors>")); newVector.addElement(new Option( "\tGUI will be opened.\n" +"\t(Use this to bring up a GUI).", "G", 0,"-G")); newVector.addElement(new Option( "\tAutocreation of the network connections will NOT be done.\n" +"\t(This will be ignored if -G is NOT set)", "A", 0,"-A")); newVector.addElement(new Option( "\tA NominalToBinary filter will NOT automatically be used.\n" +"\t(Set this to not use a NominalToBinary filter).", "B", 0,"-B")); newVector.addElement(new Option( "\tThe hidden layers to be created for the network.\n" + "\t(Value should be a list of comma separated Natural \n" + "\tnumbers or the letters 'a' = (attribs + classes) / 2, \n" + "\t'i' = attribs, 'o' = classes, 't' = attribs .+ classes)\n" + "\tfor wildcard values, Default = a).", "H", 1, "-H <comma seperated numbers for nodes on each layer>")); newVector.addElement(new Option( "\tNormalizing a numeric class will NOT be done.\n" +"\t(Set this to not normalize the class if it's numeric).", "C", 0,"-C")); newVector.addElement(new Option( "\tNormalizing the attributes will NOT be done.\n" +"\t(Set this to not normalize the attributes).", "I", 0,"-I")); newVector.addElement(new Option( "\tReseting the network will NOT be allowed.\n" +"\t(Set this to not allow the network to reset).", "R", 0,"-R")); newVector.addElement(new Option( "\tLearning rate decay will occur.\n" +"\t(Set this to cause the learning rate to decay).", "D", 0,"-D")); return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -L <learning rate> * Learning Rate for the backpropagation algorithm. * (Value should be between 0 - 1, Default = 0.3).</pre> * * <pre> -M <momentum> * Momentum Rate for the backpropagation algorithm. * (Value should be between 0 - 1, Default = 0.2).</pre> * * <pre> -N <number of epochs> * Number of epochs to train through. * (Default = 500).</pre> * * <pre> -V <percentage size of validation set> * Percentage size of validation set to use to terminate * training (if this is non zero it can pre-empt num of epochs. * (Value should be between 0 - 100, Default = 0).</pre> * * <pre> -S <seed> * The value used to seed the random number generator * (Value should be >= 0 and and a long, Default = 0).</pre> * * <pre> -E <threshold for number of consequetive errors> * The consequetive number of errors allowed for validation * testing before the netwrok terminates. * (Value should be > 0, Default = 20).</pre> * * <pre> -G * GUI will be opened. * (Use this to bring up a GUI).</pre> * * <pre> -A * Autocreation of the network connections will NOT be done. * (This will be ignored if -G is NOT set)</pre> * * <pre> -B * A NominalToBinary filter will NOT automatically be used. * (Set this to not use a NominalToBinary filter).</pre> * * <pre> -H <comma seperated numbers for nodes on each layer> * The hidden layers to be created for the network. * (Value should be a list of comma separated Natural * numbers or the letters 'a' = (attribs + classes) / 2, * 'i' = attribs, 'o' = classes, 't' = attribs .+ classes) * for wildcard values, Default = a).</pre> * * <pre> -C * Normalizing a numeric class will NOT be done. * (Set this to not normalize the class if it's numeric).</pre> * * <pre> -I * Normalizing the attributes will NOT be done. * (Set this to not normalize the attributes).</pre> * * <pre> -R * Reseting the network will NOT be allowed. * (Set this to not allow the network to reset).</pre> * * <pre> -D * Learning rate decay will occur. * (Set this to cause the learning rate to decay).</pre> * <!-- options-end --> * * @param options the list of options as an ar
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -