% bp_mlpno1.m
function BP_MLP
% BP_MLP  Fit a 2-input / 1-output multilayer perceptron to the 2-D sinc
% surface
%   z(x,y) = (sin(x)/x) * (sin(y)/y),   x, y in [xmin, xmax],
% using per-row backpropagation with momentum.
%
% Network: 2 inputs -> h_unitnum logistic hidden units -> 1 linear output.
% Biases are folded into the weight matrices: a constant 1 is appended to
% the input and hidden vectors, and the matching last weight column holds
% the negated bias.
%
% Each epoch visits the sample_num rows of the training grid; row j is
% presented as one batch of the sample_num points (x(j), y(1..sample_num)).
% Training stops when the epoch's sum-of-squares error falls below
% `threshold`, or after `epochmax` epochs. At the epochs listed in
% snap_epochs, meshout plots the current network surface; figure 1 shows
% the training-error curve at the end.
clear;
clc;
sample_num = 20;   % training grid points per axis
test_num = 360;    % test grid points per axis (used by meshout)
h_unitnum = 10;    % hidden units
i_unitnum = 2;     % network inputs (x, y)
o_unitnum = 1;     % network outputs (the error code below assumes 1)
rand('state', sum(100*clock)); % reseed so every run starts from new random weights
xmin = -10;
xmax = 10;

% Training targets. sin(t)./t is NaN at t == 0 (hit when sample_num is
% odd), so patch in the limit value 1 there.
x = linspace(xmin,xmax,sample_num);
y = linspace(xmin,xmax,sample_num);
sx = sin(x)./x;  sx(x == 0) = 1;
sy = sin(y)./y;  sy(y == 0) = 1;
z = sx'*sy;        % z(j,k) = target at (x(j), y(k))

% Bias-augmented input batches never change, so build them once:
% P(:,:,j) = [x(j) ... x(j); y(1) ... y(sample_num); 1 ... 1].
P = zeros(i_unitnum+1, sample_num, sample_num);
for j = 1:sample_num
    P(:,:,j) = [repmat(x(j),1,sample_num); y; ones(1,sample_num)];
end

alpha = 0.016;     % learning rate
eita = 0.01;       % momentum coefficient
threshold = 0.8;   % stop once the epoch's sum-of-squares error is below this
a = 0.8;           % hidden-layer initial weights drawn from (-a/2, a/2)
b = 1.0;           % output-layer initial weights drawn from (-b/2, b/2)
w_h = a*rand(h_unitnum,i_unitnum)-0.5*a;   % h_unitnum x i_unitnum
w_o = b*rand(o_unitnum,h_unitnum)-0.5*b;   % o_unitnum x h_unitnum
b_h = a*rand(h_unitnum,1)-0.5*a;           % h_unitnum x 1
b_o = b*rand(o_unitnum,1)-0.5*b;           % o_unitnum x 1
w_h_exp = [w_h (-1)*b_h];  % h_unitnum x (i_unitnum+1): input->hidden, bias column last
w_o_exp = [w_o (-1)*b_o];  % o_unitnum x (h_unitnum+1): hidden->output, bias column last

epochmax = 20000;
snap_epochs = [80 500 1000 5000 20000];  % reaching snap_epochs(k) calls meshout with sw = k
err_hist = zeros(1,epochmax);            % per-epoch sum-of-squares error
dw_o_prev = zeros(size(w_o_exp));        % previous weight steps, for momentum
dw_h_prev = zeros(size(w_h_exp));
e = zeros(sample_num);                   % e(j,:) = residuals for grid row j
n = 0;                                   % epochs completed
for epoch = 1:epochmax
    for j = 1:sample_num
        % Forward pass.
        p_exp = P(:,:,j);                              % (i_unitnum+1) x sample_num
        h_out = 1./(1+exp(-(w_h_exp*p_exp)));          % logistic sigmoid (= logsig); h_unitnum x sample_num
        h_out_exp = [h_out; ones(1,sample_num)];       % (h_unitnum+1) x sample_num
        o_out = w_o_exp*h_out_exp;                     % linear output, o_unitnum x sample_num
        e(j,:) = z(j,:)-o_out;

        % Backward pass. The hidden delta is propagated through the
        % CURRENT output weights (bias column excluded), taken before
        % they are updated, not through the initial w_o.
        delta_o = e(j,:);                                                % o_unitnum x sample_num
        delta_h = (w_o_exp(:,1:h_unitnum)'*delta_o).*h_out.*(1-h_out);   % h_unitnum x sample_num

        % Momentum: dw(t) = alpha*gradient_step + eita*dw(t-1).
        % (Equivalent to w + alpha*g + eita*(w(t) - w(t-1)), but keeping
        % the previous step explicitly avoids the term collapsing to 0.)
        dw_o = alpha*(delta_o*h_out_exp') + eita*dw_o_prev;
        dw_h = alpha*(delta_h*p_exp') + eita*dw_h_prev;
        w_o_exp = w_o_exp + dw_o;
        w_h_exp = w_h_exp + dw_h;
        dw_o_prev = dw_o;
        dw_h_prev = dw_h;
    end
    n = epoch;
    err_hist(n) = sum(e(:).^2);   % identical to sumsqr(e)
    if err_hist(n) < threshold
        break;
    end
    sw = find(snap_epochs == n, 1);
    if ~isempty(sw)
        meshout(sw,x,y,z,test_num,xmin,xmax,w_h_exp,w_o_exp);
    end
end
err_hist = err_hist(1:n);

% Training-error curve.
figure(1);
plot(1:n,err_hist);
xlabel('epoch');
ylabel('sum-squared error');
function meshout(sw,x,y,z,test_num,xmin,xmax,wh,wo)
% MESHOUT  Plot the network's current output surface over the target surface.
%   sw         - snapshot index; the plot goes to figure(sw+1)
%   x, y       - training grid vectors; z(j,k) is the target at (x(j), y(k))
%   test_num   - test grid points per axis
%   xmin, xmax - range of the square test grid
%   wh         - bias-augmented input->hidden weights; last column multiplies
%                the row of ones appended to the inputs below
%   wo         - bias-augmented hidden->output weights (single output row);
%                last column multiplies the appended hidden-layer 1
% The network surface is drawn shifted up by +1 (kept from the original
% code, presumably so it does not overlap the target surface).
xt = linspace(xmin,xmax,test_num);
yt = linspace(xmin,xmax,test_num);
zt = zeros(test_num);   % zt(j,k) = network output at (xt(j), yt(k)) + 1
for j = 1:test_num
    pt_exp = [repmat(xt(j),1,test_num); yt; ones(1,test_num)];   % inputs for row j, bias row last
    ht_out = 1./(1+exp(-(wh*pt_exp)));                           % logistic sigmoid (= logsig)
    ht_out_exp = [ht_out; ones(1,test_num)];
    zt(j,:) = wo*ht_out_exp + 1;
end
% mesh(X,Y,Z) with vector X, Y plots Z(k,j) at (X(j), Y(k)), i.e. rows of Z
% follow Y. z and zt are indexed (x, y), so pass their transposes.
figure(sw+1);
mesh(x,y,z');
xlabel('input 1');
ylabel('input 2');
zlabel('output');
hold on;
mesh(xt,yt,zt');
hold off;
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -