% File: bpmxor.m
% back propagation algorithm for XOR problem
% batch mode & momentum term
function bpmxor()
% BPMXOR  Train a small sigmoid network on the XOR problem.
%
%   bpmxor() builds the four XOR samples, trains a network with
%   nSampDim inputs, nHidden logistic hidden units and nOut logistic
%   output units using batch-mode back-propagation with a momentum term,
%   prints the sum-squared error of every epoch, displays the final
%   weights / biases / network outputs, and plots the error curve.
%
%   The function takes no arguments and returns nothing; the initial
%   weights are drawn from rand, so results differ from run to run.
%
%   Only core MATLAB functions are used: the logistic function, its
%   derivative and the sum of squares are written out explicitly instead
%   of calling the toolbox routines logsig / dlogsig / sumsqr.

clc
nSampNum = 4;
nSampDim = 2;
% change the hidden unit number at will
nHidden = 3;
nOut = 1;

%-----------------------------------------------
% Samples and expected outputs, one sample per column, in the order
% (0,0), (0,1), (1,0), (1,1).
SampIn  = [0 0 1 1; ...
           0 1 0 1];
SampOut = double(xor(SampIn(1,:), SampIn(2,:)));

% Extended samples: a constant +1 row feeds the hidden-layer bias.
SampInEx = [SampIn; ones(1, nSampNum)];

%-----------------------------------------------
% Initialise all weights and biases uniformly in (-1, 1).
w   = 2*(rand(nHidden, nSampDim) - 1/2);
b   = 2*(rand(nHidden, 1) - 1/2);
wex = [w, b];                     % hidden layer: [weights, bias]
W   = 2*(rand(nOut, nHidden) - 1/2);
B   = 2*(rand(nOut, 1) - 1/2);
WEX = [W, B];                     % output layer: [weights, bias]

eb      = 0.01;    % error bound (stop when sse <= eb)
eta     = 0.6;     % learning rate
mc      = 0.8;     % momentum coefficient
maxiter = 10000;   % maximum number of epochs

% Logistic activation, written out so no toolbox is required.
sigmoid = @(z) 1 ./ (1 + exp(-z));

iteration = 0;
errRec    = zeros(1, maxiter);    % preallocated; trimmed after the loop

for i = 1 : maxiter
    sampex   = SampInEx;          % batch mode: every sample, every epoch
    expected = SampOut;

    % ---- forward pass ----
    hp    = wex * sampex;                    % hidden-layer net input
    tau   = sigmoid(hp);                     % hidden-layer output
    tauex = [tau; -ones(1, nSampNum)];       % constant -1 row feeds the output bias
    HM    = WEX * tauex;                     % output-layer net input
    out   = sigmoid(HM);                     % network output

    err = expected - out;
    sse = sum(err(:).^2);
    iteration = iteration + 1;    % counted before the break so it includes this epoch
    errRec(iteration) = sse;
    fprintf('sse = %10.8f \n', sse)
    if sse <= eb
        break
    end

    % ---- backward pass ----
    % Derivative of the logistic function expressed via its output: a.*(1-a).
    W     = WEX(:, 1:nHidden);               % output weights without the bias column
    DELTA = err .* (out .* (1 - out));
    delta = (W' * DELTA) .* (tau .* (1 - tau));

    % Batch gradients (summed over all samples).
    dWEX = DELTA * tauex';
    dwex = delta * sampex';

    % ---- weight update with momentum ----
    % The momentum term must reuse the previously APPLIED weight change,
    % not the previous raw gradient, so the applied step is what is saved.
    if i == 1
        stepWEX = eta * dWEX;
        stepwex = eta * dwex;
    else
        stepWEX = (1 - mc) * eta * dWEX + mc * stepWEXOld;
        stepwex = (1 - mc) * eta * dwex + mc * stepwexOld;
    end
    WEX = WEX + stepWEX;
    wex = wex + stepwex;

    stepWEXOld = stepWEX;
    stepwexOld = stepwex;
end % end for iteration

errRec = errRec(1:iteration);     % drop the unused preallocated tail

% simple display of the results (no semicolons: values are echoed on purpose)
disp(['iteration = ', num2str(iteration)])
W = WEX(:, 1:nHidden)
B = WEX(:, 1+nHidden)
w = wex(:, 1:nSampDim)
b = wex(:, 1+nSampDim)
disp(out); % real output of the last forward pass

% draw the error figure
figure
axis on
hold on
grid on
plot(1:numel(errRec), errRec, 'b-', 'LineWidth', 1.5);
legend('SumSqr Errors');
xlabel('iteration times', 'FontName', 'Times', 'FontSize', 10);
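% Keyboard-shortcut help from the web code viewer (not part of the program):
%   Copy code:          Ctrl + C
%   Search code:        Ctrl + F
%   Full-screen mode:   F11
%   Switch theme:       Ctrl + Shift + D
%   Show shortcuts:     ?
%   Increase font size: Ctrl + =
%   Decrease font size: Ctrl + -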