neuralnetwork.java
字號:
/* Copyright 2006, 2007 Brian Greer This file is part of the Java NN Trainer. Java NN Trainer is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. Java NN Trainer is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Java NN Trainer; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA*/package algorithms;import java.util.Random;import java.io.FileOutputStream;import java.io.FileInputStream;import java.io.ObjectOutputStream;import java.io.ObjectInputStream;public class NeuralNetwork{ private double fitness = 0.0; private int numInput = 4; private int numHidden = 20; private int numOutput = 5; private double inWeights[][] = null; private double outWeights[][] = null; private static Random random = new Random(System.currentTimeMillis()); public NeuralNetwork(int numInput, int numHidden, int numOutput){ this.numInput = numInput; this.numHidden = numHidden; this.numOutput = numOutput; reset(); } private void reset(){ inWeights = new double[numInput][numHidden]; outWeights = new double[numHidden][numOutput]; randomNet(); } public NeuralNetwork copy(){ NeuralNetwork nn = new NeuralNetwork(numInput, numHidden, numOutput); double [][] weights = nn.getInWeights(); for(int i = 0; i < numInput; i++) for(int j = 0; j < numHidden; j++) weights[i][j] = inWeights[i][j]; weights = nn.getOutWeights(); for(int i = 0; i < numHidden; i++) for(int j = 0; j < numOutput; j++) weights[i][j] = outWeights[i][j]; return nn; } public double[][] getInWeights(){ return inWeights; } public double[][] getOutWeights(){ return 
outWeights; } public int getNumInput(){ return numInput; } public int getNumHidden(){ return numHidden; } public int getNumOutput(){ return numOutput; } private void randomNet(){ for(int i = 0; i < numInput; i++) for(int j = 0; j < numHidden; j++) inWeights[i][j] = random.nextGaussian(); for(int i = 0; i < numHidden; i++) for(int j = 0; j < numOutput; j++) outWeights[i][j] = random.nextGaussian(); } public void save(String fileName){ try{ ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName)); oos.writeObject(inWeights); oos.writeObject(outWeights); oos.close(); } catch(Exception e){ e.printStackTrace(); } } public void load(String fileName){ try{ ObjectInputStream ois = new ObjectInputStream(new FileInputStream(fileName)); inWeights = (double[][])ois.readObject(); outWeights = (double[][])ois.readObject(); numInput = inWeights.length; numHidden = outWeights.length; numOutput = outWeights[0].length; ois.close(); } catch(Exception e){ e.printStackTrace(); } } private double sigmoid(double x){ if(x > 15.0) return 1.0; else if(x < -15.0) return 0.0; else return (1.0 / (1.0 + Math.exp(-x))); } private double symmetricSigmoid(double x){ if(x > 15.0) return 0.5; else if(x < -15.0) return -0.5; else return (1.0 / (1.0 + Math.exp(-x)) - 0.5); } public void activate(double[] inputs, double[] outputs){ double hidden[] = new double[numHidden]; activate(inputs, hidden, outputs); } public void activate(double[] inputs, double [] hidden, double[] outputs){ for(int j = 0; j < numHidden; j++){ hidden[j] = 0; for(int i = 0; i < numInput; i++) hidden[j] += inputs[i] * inWeights[i][j]; hidden[j] = sigmoid(hidden[j]); } for(int j = 0; j < numOutput; j++){ outputs[j] = 0; for(int i = 0; i < numHidden; i++) outputs[j] += hidden[i] * outWeights[i][j]; outputs[j] = sigmoid(outputs[j]); } } public double getFitness(){ return fitness; } public void setFitness(double fitness){ this.fitness = fitness; } public static double sumSquaredError(double [] outputs, double [] 
targets){ double error = 0.0; int numOutput = outputs.length; for(int i = 0; i < numOutput; i++){ double diff = outputs[i] - targets[i]; error += diff * diff; } return Math.sqrt(error); } public double evaluate(double [] inputs, double [] targets){ double [] outputs = new double[numOutput]; activate(inputs, outputs); fitness = sumSquaredError(outputs, targets); return fitness; } public double evaluate(double [][] inputs, double [][] targets){ int numPatterns = inputs.length; fitness = 0; if(numPatterns > 0){ double [] outputs = new double[numOutput]; for(int i = 0; i < numPatterns; i++){ activate(inputs[i], outputs); fitness += sumSquaredError(outputs, targets[i]); } fitness /= numPatterns; } return fitness; }}// vim:noet:ts=3:sw=3
快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -