// File: bp.java
// (Web code-viewer residue, originally the label "Font size:")
package org.scut.DataMining.Algorithm.NeuralNetwork.BP;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import org.scut.DataMining.Algorithm.NeuralNetwork.Core.*;
import org.scut.DataMining.Core.MiningData;
import org.scut.DataMining.Core.MiningException;
import org.scut.DataMining.Core.MiningMetaData;
import org.scut.DataMining.Input.File.MiningArffStream;
/**
 * A network made of one input, one hidden and one output {@link Layer},
 * trained by forward/backward propagation over a set of input/target pairs.
 *
 * <p>Typical use: {@link #setParameter(Parameter)}, {@link #setInputSet(ArrayList)},
 * {@link #setTargetSet(ArrayList)}, {@link #train()}, then {@link #work(double[])}
 * and/or {@link #save(String)}. A saved network is restored with {@link #load(String)}.
 *
 * <p>File format written by {@link #save(String)} and read by {@link #load(String)}
 * (three comma-separated text lines):
 * <ol>
 *   <li>{@code inputCount,hiddenCount,outputCount} (input and hidden counts include the bias neuron)</li>
 *   <li>weights of the input layer's outgoing synapses, in synapse-index order</li>
 *   <li>weights of the hidden layer's outgoing synapses, in synapse-index order</li>
 * </ol>
 */
public final class BP
{
    /** Number of input-layer neurons, including the bias neuron at index 0. */
    private int inputCount;
    /** Number of hidden-layer neurons, including the bias neuron at index 0. */
    private int hiddenCount;
    /** Number of output-layer neurons (no bias neuron). */
    private int outputCount;
    /** Number of passes over the whole training set performed by {@link #train()}. */
    private int maxEpoch;
    /** Learning rate handed to the outgoing synapses of the input and hidden layers. */
    private double learningRate;
    /** Initial random range handed to the outgoing synapses of the input and hidden layers. */
    private double randomRange;

    /** Input layer of the BP algorithm */
    private Layer inputLayer;
    /** Hidden layer of the BP algorithm */
    private Layer hiddenLayer;
    /** Output layer of the BP algorithm */
    private Layer outputLayer;

    /** Training inputs; every row already carries the bias slot at index 0. */
    private ArrayList<double[]> inputSet;
    /** Training targets; rows are the caller's arrays (stored by reference, not copied). */
    private ArrayList<double[]> targetSet;

    /** Creates an unconfigured network; call {@link #setParameter(Parameter)} or use {@link #load(String)}. */
    public BP()
    {
    }

    /**
     * Runs one forward pass and returns the output layer's activity.
     *
     * <p>If {@code input} has exactly {@code inputCount - 1} elements, a bias slot is
     * prepended; an array of any other length is handed to the input layer unchanged.
     * A {@link MiningException} raised by the forward pass is printed to standard error
     * and the output layer's current activity is returned anyway.
     *
     * @param input the raw input values (with or without the leading bias slot)
     * @return the output layer's activity after the forward pass
     */
    public double[] work(double[] input)
    {
        double[] pass = input;
        if (input.length == this.inputCount - 1)
        {
            pass = withBias(input);
        }
        try
        {
            this.forward(pass);
        }
        catch (MiningException e)
        {
            e.printStackTrace();
        }
        return this.outputLayer.getOutActivity();
    }

    /**
     * Sets the parameter for the BP algorithm
     *
     * @param params the network sizes and training settings; one bias neuron is added
     *               to both the input and the hidden size
     */
    public void setParameter(Parameter params)
    {
        this.inputCount = params.input + 1;   //: 1 neuron for bias
        this.hiddenCount = params.hidden + 1; //: 1 neuron for bias
        this.outputCount = params.output;
        this.maxEpoch = params.maxEpoch;
        this.learningRate = params.learningRate;
        this.randomRange = params.randomRange;
    }

    /**
     * Sets the input training set. Every row is copied into a new array that has the
     * bias slot prepended, so later changes to the caller's rows are not seen.
     *
     * <p>The stored set is replaced only after every row has been validated. It is
     * replaced before the size comparison with the target set, so that a caller can
     * switch to a data set of a different size by setting inputs and targets in turn.
     *
     * @param inputSet rows of {@code inputCount - 1} values each
     * @throws MiningException if {@code inputSet} is null, the parameters are invalid,
     *                         a row is null or has the wrong width, or the row count
     *                         differs from an already stored target set
     */
    public void setInputSet(ArrayList<double[]> inputSet) throws MiningException
    {
        if (inputSet == null)
        {
            throw new MiningException("inputSet passed into is null");
        }
        this.checkParameter();

        ArrayList<double[]> prepared = new ArrayList<double[]>(inputSet.size());
        for (double[] input : inputSet)
        {
            if (input == null)
            {
                throw new MiningException("Training data input section contains a null row");
            }
            if (input.length != this.inputCount - 1)
            {
                throw new MiningException("Training data input section data size not equal to the network input size");
            }
            prepared.add(withBias(input));
        }
        this.inputSet = prepared;

        if (this.targetSet != null && this.targetSet.size() != this.inputSet.size())
        {
            throw new MiningException("Traning data input and target count not equal");
        }
    }

    /**
     * Sets the target training set. Rows are stored by reference (not copied).
     *
     * <p>The stored set is replaced only after every row has been validated, and before
     * the size comparison with the input set (see {@link #setInputSet(ArrayList)}).
     *
     * @param targetSet rows of {@code outputCount} values each
     * @throws MiningException if {@code targetSet} is null, the parameters are invalid,
     *                         a row is null or has the wrong width, or the row count
     *                         differs from an already stored input set
     */
    public void setTargetSet(ArrayList<double[]> targetSet) throws MiningException
    {
        if (targetSet == null)
        {
            throw new MiningException("targetSet passed into is null");
        }
        this.checkParameter();

        ArrayList<double[]> prepared = new ArrayList<double[]>(targetSet.size());
        for (double[] target : targetSet)
        {
            if (target == null)
            {
                throw new MiningException("Training data target section contains a null row");
            }
            if (target.length != this.outputCount)
            {
                throw new MiningException("Training data target section data size not equal to the network output size");
            }
            prepared.add(target);
        }
        this.targetSet = prepared;

        if (this.inputSet != null && this.targetSet.size() != this.inputSet.size())
        {
            throw new MiningException("Traning data target and input count not equal");
        }
    }

    /**
     * Checks the validation of the parameter
     *
     * <p>Because {@link #setParameter(Parameter)} adds one bias neuron to the input and
     * hidden sizes, the first two checks only fail when the parameters were never set
     * or the configured size was negative.
     *
     * @throws MiningException if a layer size or the learning rate is not positive
     */
    private void checkParameter() throws MiningException
    {
        if (this.inputCount <= 0)
        {
            throw new MiningException("BP, input layer count<=0");
        }
        if (this.hiddenCount <= 0)
        {
            throw new MiningException("BP, hidden layer count<=0");
        }
        if (this.outputCount <= 0)
        {
            throw new MiningException("BP, output layer count<=0");
        }
        if (this.learningRate <= 0)
        {
            throw new MiningException("BP, learning rate<=0");
        }
    }

    /**
     * Initializes the BP network: creates the three layers, marks neuron 0 of the input
     * and hidden layers as bias, links input to hidden and hidden to output, and passes
     * the random range and learning rate to the outgoing synapses of both linked layers.
     *
     * @throws MiningException if the parameters are invalid or layer construction fails
     */
    private void initialize() throws MiningException
    {
        this.checkParameter();

        this.inputLayer = new Layer(this.inputCount, Layer.LayerType.Input);
        this.inputLayer.setBiasNeuron(0);
        this.hiddenLayer = new Layer(this.hiddenCount, Layer.LayerType.Hidden);
        this.hiddenLayer.setBiasNeuron(0);
        this.outputLayer = new Layer(this.outputCount, Layer.LayerType.Output);

        Layer.linkLayer(this.inputLayer, this.hiddenLayer);
        Layer.linkLayer(this.hiddenLayer, this.outputLayer);

        this.inputLayer.setOutSynapseInitRandomRange(this.randomRange);
        this.inputLayer.setOutSynapseLearningRate(this.learningRate);
        this.hiddenLayer.setOutSynapseInitRandomRange(this.randomRange);
        this.hiddenLayer.setOutSynapseLearningRate(this.learningRate);
    }

    /**
     * Trains the BP network. The layers are rebuilt from scratch, then for
     * {@code maxEpoch} epochs every training pair is propagated forward and backward
     * and {@link #updateWeights()} is called once at the end of each epoch.
     *
     * <p>The training sets are checked before the layers are rebuilt, so a failed call
     * leaves existing layers (for example from {@link #load(String)}) untouched.
     *
     * @throws MiningException if the input or target set is missing, their sizes differ,
     *                         the parameters are invalid, or propagation fails
     */
    public void train() throws MiningException
    {
        if (this.inputSet == null || this.targetSet == null)
        {
            throw new MiningException("Input and target set not already set");
        }
        // The setters report a size mismatch but still store the new set, so a caller
        // that ignored that exception could get here with misaligned data.
        if (this.inputSet.size() != this.targetSet.size())
        {
            throw new MiningException("Training data input and target count not equal");
        }
        this.initialize();

        int size = this.inputSet.size();
        for (int epoch = 0; epoch < this.maxEpoch; epoch++)
        {
            for (int i = 0; i < size; i++)
            {
                this.forward(this.inputSet.get(i));
                this.backward(this.targetSet.get(i));
            }
            this.updateWeights();
        }
    }

    /**
     * Forward propagates the input through the network: locks the input layer's
     * activity to {@code input}, then activates the hidden and the output layer.
     *
     * @param input the input-layer activity, bias slot included
     * @throws MiningException if a layer rejects the input or activation fails
     */
    private void forward(double[] input) throws MiningException
    {
        this.inputLayer.lockOutActivity(input);
        this.hiddenLayer.activate();
        this.outputLayer.activate();
    }

    /**
     * Backward propagates the target through the network. The per-output difference
     * {@code target[i] - out[i]} is locked into the output layer as feedback, after
     * which the output and hidden layers each run their feedback step and update
     * their backward delta weights.
     *
     * @param target the expected output values for the most recent forward pass
     * @throws MiningException if a layer rejects the feedback or the update fails
     */
    private void backward(double[] target) throws MiningException
    {
        double[] out = this.outputLayer.getOutActivity();
        double[] fd = new double[target.length];
        for (int i = 0; i < fd.length; i++)
        {
            fd[i] = target[i] - out[i];
        }
        this.outputLayer.lockInFeedback(fd);

        this.outputLayer.feedback();
        this.outputLayer.updateBackwardDeltaWeights();
        this.hiddenLayer.feedback();
        this.hiddenLayer.updateBackwardDeltaWeights();
    }

    /**
     * Updates the weights of all the synapses (those feeding the output layer first,
     * then those feeding the hidden layer).
     */
    private void updateWeights()
    {
        this.outputLayer.updateBackwardWeights();
        this.hiddenLayer.updateBackwardWeights();
    }

    /**
     * Writes the layer sizes and all synapse weights to a text file (format described
     * in the class comment). The content is built before the file is opened, so an
     * existing file is not truncated when the network has nothing to save.
     * I/O errors are printed to standard error and not rethrown.
     *
     * @param fileName path of the file to create or overwrite
     * @throws IllegalStateException if the network has no layers yet
     *                               (neither {@link #train()} nor {@link #load(String)} has run)
     */
    public void save(String fileName)
    {
        if (this.inputLayer == null || this.hiddenLayer == null)
        {
            throw new IllegalStateException("BP network has no layers to save; call train() or load() first");
        }
        String iho = this.inputCount + "," + this.hiddenCount + "," + this.outputCount + "\n";
        String wih = weightLine(this.inputLayer);
        String who = weightLine(this.hiddenLayer);

        BufferedWriter bw = null;
        try
        {
            bw = new BufferedWriter(new FileWriter(fileName));
            bw.write(iho);
            bw.write(wih);
            bw.write(who);
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            // Always release the file handle; close() also flushes the buffered text.
            if (bw != null)
            {
                try
                {
                    bw.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Restores a network from a file written by {@link #save(String)}.
     *
     * <p>Only the layer sizes and synapse weights are stored in the file;
     * {@code maxEpoch}, {@code learningRate} and {@code randomRange} keep their default
     * values in the returned object. Any failure (missing file, truncated or malformed
     * content) is printed to standard error and the object is returned in whatever
     * state had been reached, so it may be only partially initialized.
     *
     * @param fileName path of the file to read
     * @return a new {@code BP} instance, never {@code null}
     */
    public static BP load(String fileName)
    {
        BP bp = new BP();
        BufferedReader br = null;
        try
        {
            br = new BufferedReader(new FileReader(fileName));
            String iho = br.readLine();
            String wih = br.readLine();
            String who = br.readLine();
            if (iho == null || wih == null || who == null)
            {
                throw new IOException("Network file '" + fileName + "' is truncated: expected 3 lines");
            }
            String[] siho = iho.split("[,]");
            String[] swih = wih.split("[,]");
            String[] swho = who.split("[,]");

            bp.inputCount = Integer.parseInt(siho[0]);
            bp.hiddenCount = Integer.parseInt(siho[1]);
            bp.outputCount = Integer.parseInt(siho[2]);

            bp.inputLayer = new Layer(bp.inputCount, Layer.LayerType.Input);
            bp.inputLayer.setBiasNeuron(0);
            bp.hiddenLayer = new Layer(bp.hiddenCount, Layer.LayerType.Hidden);
            bp.hiddenLayer.setBiasNeuron(0);
            bp.outputLayer = new Layer(bp.outputCount, Layer.LayerType.Output);
            Layer.linkLayer(bp.inputLayer, bp.hiddenLayer);
            Layer.linkLayer(bp.hiddenLayer, bp.outputLayer);

            restoreWeights(bp.inputLayer, swih, "input-to-hidden");
            restoreWeights(bp.hiddenLayer, swho, "hidden-to-output");
        }
        catch (Exception e)
        {
            // Best-effort contract kept from the original API: report and return bp as is.
            e.printStackTrace();
        }
        finally
        {
            if (br != null)
            {
                try
                {
                    br.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
        return bp;
    }

    /** Returns a copy of {@code input} with a bias slot (value 1) inserted at index 0. */
    private static double[] withBias(double[] input)
    {
        double[] pass = new double[input.length + 1];
        pass[0] = 1; //: bias slot, filled with 1 as a placeholder
        System.arraycopy(input, 0, pass, 1, input.length);
        return pass;
    }

    /** Formats the weights of {@code layer}'s outgoing synapses as one comma-separated, newline-terminated line. */
    private static String weightLine(Layer layer)
    {
        StringBuilder line = new StringBuilder();
        int count = layer.getOutSynapseCount();
        for (int i = 0; i < count; i++)
        {
            if (i > 0)
            {
                line.append(',');
            }
            line.append(layer.getOutSynapse(i).getWeight());
        }
        return line.append('\n').toString();
    }

    /** Parses {@code tokens} as doubles and assigns them, in order, to {@code layer}'s outgoing synapses. */
    private static void restoreWeights(Layer layer, String[] tokens, String what)
            throws IOException, MiningException
    {
        int count = layer.getOutSynapseCount();
        if (tokens.length < count)
        {
            throw new IOException("Network file has " + tokens.length + " " + what
                    + " weights, expected " + count);
        }
        for (int i = 0; i < count; i++)
        {
            Synapse syn = layer.getOutSynapse(i);
            syn.setWeight(Double.parseDouble(tokens[i]));
        }
    }

    /*********************************************************************/
    /**
     * Demo driver: reads {@code arff//vowel.arff}, trains a network on it, saves the
     * result to {@code tmp.txt} and prints the elapsed wall-clock time in seconds.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        long start = System.currentTimeMillis();
        try
        {
            ArrayList<MiningData> data = new ArrayList<MiningData>();
            MiningArffStream arff = new MiningArffStream("arff//vowel.arff");
            while (arff.next())
            {
                MiningData d = new MiningData(arff.getData());
                data.add(d);
                d.normalize();
            }

            // Declare which attributes are the target and which are inputs.
            MiningMetaData meta = arff.getMetaData();
            meta.addTarget("'class'");
            meta.addInput("'feld4'");
            meta.addInput("'feld5'");
            meta.addInput("'feld6'");
            meta.addInput("'feld7'");
            meta.addInput("'feld8'");
            meta.addInput("'feld9'");
            meta.addInput("'feld10'");
            meta.addInput("'feld11'");
            meta.addInput("'feld12'");
            meta.addInput("'feld13'");

            ArrayList<double[]> inputSet = new ArrayList<double[]>();
            ArrayList<double[]> targetSet = new ArrayList<double[]>();
            for (MiningData d : data)
            {
                inputSet.add(d.getInput());
                targetSet.add(d.getTarget());
            }

            Parameter param = new Parameter();
            param.input = meta.getInputCount();
            param.output = meta.getTargetCount();
            param.hidden = (param.input + param.output) / 2;
            param.maxEpoch = 200;
            param.learningRate = 0.5;
            param.randomRange = 0.05;

            BP bp = new BP();
            bp.setParameter(param);
            bp.setInputSet(inputSet);
            bp.setTargetSet(targetSet);
            bp.train();
            bp.save("tmp.txt");
        }
        catch (MiningException e)
        {
            e.printStackTrace();
        }
        long end = System.currentTimeMillis();
        System.out.println("Time eclipsed[s]: " + (end - start) / 1000.0);
    }
    /*********************************************************************/
}
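// Keyboard shortcuts (residue of the web code viewer this file was copied from):
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full-screen mode:   F11
//   Switch theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -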