?? bpa.java
字號:
package org.scut.DataMining.Algorithm.NeuralNetwork.BP;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.Random;
import org.scut.DataMining.Core.MiningData;
import org.scut.DataMining.Core.MiningException;
import org.scut.DataMining.Core.MiningMetaData;
import org.scut.DataMining.Input.File.MiningArffStream;
/**
 * Three-layer (input, hidden, output) feed-forward neural network trained with
 * error back-propagation.
 *
 * <p>Both the hidden and the output layer use the logistic sigmoid
 * {@code 1 / (1 + e^-x)} as activation. Training is batch mode: within one epoch the
 * weight and bias deltas of every sample are accumulated, then applied once at the end
 * of the epoch.
 *
 * <p>Weight layout: {@code wih[i][j]} connects input {@code i} to hidden neuron
 * {@code j}; {@code who[j][k]} connects hidden neuron {@code j} to output {@code k}.
 *
 * <p>Not thread-safe: training and {@link #work(double[])} write into per-instance
 * buffers without any synchronization.
 */
public class BPA
{
    // Layer sizes: input, hidden and output neuron counts.
    private int ni;
    private int nh;
    private int no;
    // Initial weights and biases are drawn uniformly from [-range, range).
    private double range = 0.01;
    // Learning rate.
    private double enta = 0.5;
    // Number of epochs run by train().
    private int maxEpoch = 2000;
    // Input->hidden weights and their per-epoch accumulated deltas.
    private double[][] wih;
    private double[][] dwih;
    // Hidden->output weights and their per-epoch accumulated deltas.
    private double[][] who;
    private double[][] dwho;
    // Hidden-layer biases and their accumulated deltas.
    private double[] bh;
    private double[] dbh;
    // Output-layer biases and their accumulated deltas.
    private double[] bo;
    private double[] dbo;
    // Hidden / output activations of the most recent forward pass.
    private double[] oh;
    private double[] oo;
    // Output / hidden error terms of the most recent backward pass.
    private double[] so;
    private double[] sh;
    // Training samples: inputSet.get(i) is paired with targetSet.get(i).
    private ArrayList<double[]> inputSet;
    private ArrayList<double[]> targetSet;
    private final Random r = new Random();

    /** Creates an unconfigured network; call {@link #setParameter(Parameter)} before use. */
    public BPA()
    {
        super();
    }

    /**
     * Trains the network, starting from freshly randomized weights, on the configured
     * input and target sets.
     *
     * <p>Runs {@code maxEpoch} epochs. Each epoch does a forward and a backward pass for
     * every sample, then applies the accumulated weight updates once.
     *
     * @throws MiningException if the training sets were not set, the network parameters
     *         are invalid, or the stored samples do not match the current layer sizes
     */
    public void train() throws MiningException
    {
        if(this.inputSet == null || this.targetSet == null)
            throw new MiningException("input or target set not set yet!");
        this.initialize();
        // setParameter() may have changed the layer sizes after the sets were assigned.
        this.checkTrainingData();
        int size = this.inputSet.size();
        for(int epoch = 0; epoch < this.maxEpoch; epoch++)
        {
            for(int i = 0; i < size; i++)
            {
                double[] input = this.inputSet.get(i);
                this.forward(input);
                this.backward(input, this.targetSet.get(i));
            }
            this.updateWeights();
        }
    }

    /**
     * Runs the trained (or loaded) network on one input vector.
     *
     * @param input input vector with exactly {@code ni} elements
     * @return a new array holding the {@code no} output-layer activations
     * @throws IllegalStateException if the network was neither trained nor loaded
     * @throws IllegalArgumentException if {@code input} is null or has the wrong length
     */
    public double[] work(double[] input)
    {
        if(this.wih == null)
            throw new IllegalStateException("BP, network not initialized; call train() or load() first");
        if(input == null || input.length != this.ni)
            throw new IllegalArgumentException("BP, input size must be " + this.ni);
        this.forward(input);
        // oo is an internal buffer overwritten by every forward pass; hand out a copy.
        return this.oo.clone();
    }

    /** Validates the layer sizes and the learning rate. */
    private void check() throws MiningException
    {
        if(this.ni <= 0)
            throw new MiningException("BP, input layer count<=0");
        if(this.nh <= 0)
            throw new MiningException("BP, hidden layer count<=0");
        if(this.no <= 0)
            throw new MiningException("BP, output layer count<=0");
        if(this.enta <= 0)
            throw new MiningException("BP, learning rate<=0");
    }

    /** Verifies that every stored training sample matches the current layer sizes. */
    private void checkTrainingData() throws MiningException
    {
        if(this.inputSet.size() != this.targetSet.size())
            throw new MiningException("Training data input and target count not equal");
        for(double[] input : this.inputSet)
            if(input.length != this.ni)
                throw new MiningException("Training data input section data size not equal to the network input size");
        for(double[] target : this.targetSet)
            if(target.length != this.no)
                throw new MiningException("Training data target section data size not equal to the network output size");
    }

    /** Returns a uniformly distributed value in {@code [-range, range)}. */
    private double random()
    {
        return (r.nextDouble() - 0.5) * 2 * this.range;
    }

    /**
     * Sets the parameters for the BP algorithm: layer sizes, epoch count, learning rate
     * and initial weight range. Takes effect on the next {@link #train()}.
     *
     * @param params the parameters to copy
     */
    public void setParameter(Parameter params)
    {
        this.ni = params.input;
        this.nh = params.hidden;
        this.no = params.output;
        this.maxEpoch = params.maxEpoch;
        this.enta = params.learningRate;
        this.range = params.randomRange;
    }

    /**
     * Sets the input training set. Each vector must have {@code ni} elements. The list
     * is copied, but the vectors themselves are shared with the caller. On failure the
     * previously stored set is left unchanged.
     *
     * @param inputSet the training inputs
     * @throws MiningException if the set is null, contains a null or wrongly sized
     *         vector, the parameters are invalid, or its size differs from the target set
     */
    public void setInputSet(ArrayList<double[]> inputSet) throws MiningException
    {
        if(inputSet == null)
            throw new MiningException("inputSet passed into is null");
        this.check();
        ArrayList<double[]> copy = new ArrayList<double[]>(inputSet.size());
        for(double[] input : inputSet)
        {
            if(input == null)
                throw new MiningException("inputSet contains a null sample");
            if(input.length != this.ni)
                throw new MiningException("Training data input section data size not equal to the network input size");
            copy.add(input);
        }
        if(this.targetSet != null && this.targetSet.size() != copy.size())
            throw new MiningException("Traning data input and target count not equal");
        // Commit only after every check has passed.
        this.inputSet = copy;
    }

    /**
     * Sets the target training set. Each vector must have {@code no} elements. The list
     * is copied, but the vectors themselves are shared with the caller. On failure the
     * previously stored set is left unchanged.
     *
     * @param targetSet the training targets
     * @throws MiningException if the set is null, contains a null or wrongly sized
     *         vector, the parameters are invalid, or its size differs from the input set
     */
    public void setTargetSet(ArrayList<double[]> targetSet) throws MiningException
    {
        if(targetSet == null)
            throw new MiningException("targetSet passed into is null");
        this.check();
        ArrayList<double[]> copy = new ArrayList<double[]>(targetSet.size());
        for(double[] target : targetSet)
        {
            if(target == null)
                throw new MiningException("targetSet contains a null sample");
            if(target.length != this.no)
                throw new MiningException("Training data target section data size not equal to the network output size");
            copy.add(target);
        }
        if(this.inputSet != null && copy.size() != this.inputSet.size())
            throw new MiningException("Traning data target and input count not equal");
        // Commit only after every check has passed.
        this.targetSet = copy;
    }

    /**
     * Allocates all per-layer buffers for the current layer sizes and fills the weights
     * and biases with random values from {@link #random()}.
     */
    private void initialize() throws MiningException
    {
        this.check();
        this.oh = new double[this.nh];
        this.oo = new double[this.no];
        this.sh = new double[this.nh];
        this.so = new double[this.no];
        this.dwih = new double[this.ni][this.nh];
        this.dwho = new double[this.nh][this.no];
        this.dbh = new double[this.nh];
        this.dbo = new double[this.no];
        this.wih = new double[this.ni][this.nh];
        this.who = new double[this.nh][this.no];
        this.bh = new double[this.nh];
        this.bo = new double[this.no];
        for(int i = 0; i < this.ni; i++)
            for(int j = 0; j < this.nh; j++)
                this.wih[i][j] = this.random();
        for(int j = 0; j < this.nh; j++)
            for(int k = 0; k < this.no; k++)
                this.who[j][k] = this.random();
        for(int j = 0; j < this.nh; j++)
            this.bh[j] = this.random();
        for(int k = 0; k < this.no; k++)
            this.bo[k] = this.random();
    }

    /** Forward pass: fills the hidden activations {@code oh}, then the output activations {@code oo}. */
    private void forward(double[] input)
    {
        for(int j = 0; j < this.nh; j++)
        {
            double netj = this.bh[j];
            for(int i = 0; i < this.ni; i++)
                netj += this.wih[i][j] * input[i];
            this.oh[j] = this.squash(netj);
        }
        for(int k = 0; k < this.no; k++)
        {
            double netk = this.bo[k];
            for(int j = 0; j < this.nh; j++)
                netk += this.who[j][k] * this.oh[j];
            this.oo[k] = this.squash(netk);
        }
    }

    /**
     * Backward pass for one sample. Computes the error terms and adds this sample's
     * learning-rate-scaled weight and bias changes to the delta accumulators. The
     * weights themselves are only changed by {@link #updateWeights()}.
     */
    private void backward(double[] input, double[] target)
    {
        // Output error term: (t - o) times the sigmoid derivative o * (1 - o).
        for(int k = 0; k < this.no; k++)
            this.so[k] = (target[k] - this.oo[k]) * this.oo[k] * (1 - this.oo[k]);
        // Hidden error term: back-propagated output error times the sigmoid derivative.
        for(int j = 0; j < this.nh; j++)
        {
            double ss = 0;
            for(int k = 0; k < this.no; k++)
                ss += this.who[j][k] * this.so[k];
            this.sh[j] = this.oh[j] * (1 - this.oh[j]) * ss;
        }
        for(int j = 0; j < this.nh; j++)
            for(int k = 0; k < this.no; k++)
                this.dwho[j][k] += this.enta * this.so[k] * this.oh[j];
        for(int i = 0; i < this.ni; i++)
            for(int j = 0; j < this.nh; j++)
                this.dwih[i][j] += this.enta * this.sh[j] * input[i];
        for(int k = 0; k < this.no; k++)
            this.dbo[k] += this.enta * this.so[k];
        for(int j = 0; j < this.nh; j++)
            this.dbh[j] += this.enta * this.sh[j];
    }

    /** Applies the accumulated deltas to weights and biases, then resets the accumulators. */
    private void updateWeights()
    {
        for(int j = 0; j < this.nh; j++)
            for(int k = 0; k < this.no; k++)
            {
                this.who[j][k] += this.dwho[j][k];
                this.dwho[j][k] = 0;
            }
        for(int i = 0; i < this.ni; i++)
            for(int j = 0; j < this.nh; j++)
            {
                this.wih[i][j] += this.dwih[i][j];
                this.dwih[i][j] = 0;
            }
        for(int k = 0; k < this.no; k++)
        {
            this.bo[k] += this.dbo[k];
            this.dbo[k] = 0;
        }
        for(int j = 0; j < this.nh; j++)
        {
            this.bh[j] += this.dbh[j];
            this.dbh[j] = 0;
        }
    }

    /** Logistic sigmoid {@code 1 / (1 + e^-value)}. */
    private double squash(double value)
    {
        return 1.0 / (1.0 + Math.exp(-value));
    }

    /** Joins all matrix values, row-major, into one comma-separated, newline-terminated line. */
    private static String joinLine(double[][] matrix)
    {
        StringBuilder sb = new StringBuilder();
        for(double[] row : matrix)
            for(double v : row)
            {
                if(sb.length() > 0)
                    sb.append(',');
                sb.append(v);
            }
        return sb.append('\n').toString();
    }

    /** Joins the vector values into one comma-separated, newline-terminated line. */
    private static String joinLine(double[] vector)
    {
        StringBuilder sb = new StringBuilder();
        for(int i = 0; i < vector.length; i++)
        {
            if(i > 0)
                sb.append(',');
            sb.append(vector[i]);
        }
        return sb.append('\n').toString();
    }

    /**
     * Saves the network as five text lines: {@code ni,nh,no}, then all {@code wih}
     * values (row-major), all {@code who} values (row-major), {@code bh} and {@code bo},
     * each comma-separated. I/O errors are printed, not thrown.
     *
     * @param fileName destination file, overwritten if it exists
     * @throws IllegalStateException if the network was neither trained nor loaded
     */
    public void save(String fileName)
    {
        if(this.wih == null)
            throw new IllegalStateException("BP, network not initialized; nothing to save");
        BufferedWriter bw = null;
        try
        {
            bw = new BufferedWriter(new FileWriter(fileName));
            bw.write(this.ni + "," + this.nh + "," + this.no + "\n");
            bw.write(joinLine(this.wih));
            bw.write(joinLine(this.who));
            bw.write(joinLine(this.bh));
            bw.write(joinLine(this.bo));
            // Close inside the try so a failing flush is reported like any write error.
            bw.close();
            bw = null;
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            if(bw != null)
            {
                try
                {
                    bw.close();
                }
                catch (IOException ignored)
                {
                    // The original write error has already been reported.
                }
            }
        }
    }

    /** Reads the next model line and splits it on commas; fails at end of file. */
    private static String[] readFields(BufferedReader br) throws IOException
    {
        String line = br.readLine();
        if(line == null)
            throw new IOException("BP, unexpected end of model file");
        return line.split("[,]");
    }

    /** Parses {@code fields} row-major into {@code target}, requiring an exact element count. */
    private static void parseInto(String[] fields, double[][] target) throws IOException
    {
        int cols = target[0].length;
        if(fields.length != target.length * cols)
            throw new IOException("BP, expected " + (target.length * cols) + " values but found " + fields.length);
        for(int i = 0; i < fields.length; i++)
            target[i / cols][i % cols] = Double.parseDouble(fields[i].trim());
    }

    /** Parses {@code fields} into {@code target}, requiring an exact element count. */
    private static void parseInto(String[] fields, double[] target) throws IOException
    {
        if(fields.length != target.length)
            throw new IOException("BP, expected " + target.length + " values but found " + fields.length);
        for(int i = 0; i < fields.length; i++)
            target[i] = Double.parseDouble(fields[i].trim());
    }

    /**
     * Loads a network written by {@link #save(String)}.
     *
     * <p>Any failure (I/O, malformed numbers, wrong value counts, invalid layer sizes)
     * is printed rather than thrown. In that case the returned network may be
     * uninitialized or only partially loaded.
     *
     * @param fileName the model file
     * @return the loaded network
     */
    public static BPA load(String fileName)
    {
        BPA bp = new BPA();
        BufferedReader br = null;
        try
        {
            br = new BufferedReader(new FileReader(fileName));
            String[] siho = readFields(br);
            if(siho.length != 3)
                throw new IOException("BP, malformed layer-size header");
            bp.ni = Integer.parseInt(siho[0].trim());
            bp.nh = Integer.parseInt(siho[1].trim());
            bp.no = Integer.parseInt(siho[2].trim());
            bp.initialize();
            // wih is ni x nh and who is nh x no, both stored row-major on one line each.
            parseInto(readFields(br), bp.wih);
            parseInto(readFields(br), bp.who);
            parseInto(readFields(br), bp.bh);
            parseInto(readFields(br), bp.bo);
        }
        catch (Exception e)
        {
            e.printStackTrace();
        }
        finally
        {
            if(br != null)
            {
                try
                {
                    br.close();
                }
                catch (IOException ignored)
                {
                    // Closing a read-only stream: nothing left to recover.
                }
            }
        }
        return bp;
    }

    /**
     * Demo: reads {@code arff//pm.arff} and normalizes each record. Indices 6..17 are
     * registered as inputs and 0..5 as targets. It then trains a network for 200 epochs
     * and saves it to {@code tmp.txt}.
     */
    public static void main(String[] args)
    {
        long start = System.currentTimeMillis();
        try
        {
            ArrayList<MiningData> data = new ArrayList<MiningData>();
            MiningArffStream arff = new MiningArffStream("arff//pm.arff");
            while(arff.next())
            {
                MiningData d = new MiningData(arff.getData());
                data.add(d);
                d.normalize();
            }
            MiningMetaData meta = arff.getMetaData();
            for(int i = 6; i < 18; i++) meta.addInput(i);
            for(int i = 0; i < 6; i++) meta.addTarget(i);
            ArrayList<double[]> inputSet = new ArrayList<double[]>();
            ArrayList<double[]> targetSet = new ArrayList<double[]>();
            for(MiningData d : data)
            {
                inputSet.add(d.getInput());
                targetSet.add(d.getTarget());
            }
            Parameter param = new Parameter();
            param.input = meta.getInputCount();
            param.output = meta.getTargetCount();
            param.hidden = (param.input + param.output) / 2;
            param.maxEpoch = 200;
            param.learningRate = 0.1;
            param.randomRange = 0.01;
            BPA bp = new BPA();
            bp.setParameter(param);
            bp.setInputSet(inputSet);
            bp.setTargetSet(targetSet);
            bp.train();
            bp.save("tmp.txt");
        }
        catch (MiningException e)
        {
            e.printStackTrace();
        }
        long end = System.currentTimeMillis();
        System.out.println("Time eclipsed[s]: " + (end-start)/1000.0);
    }
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -