function [theta_data,PI_vector,FPE_vector,PI_test_vec,deff_vec,pvec]=...
obsprune(NetDef,W1,W2,PHI,Y,trparms,prparms,PHI2,Y2)
% OBSPRUNE
% --------
% This function applies the Optimal Brain Surgeon (OBS) algorithm
% for pruning ordinary feedforward neural networks
%
% CALL:
% [theta_data,NSSEvec,FPEvec,NSSEtestvec,deff,pvec]=...
% obsprune(NetDef,W1,W2,PHI,Y,trparms,prparms,PHI2,Y2)
%
% INPUT:
% NetDef, W1, W2,
% PHI, Y, trparms : See for example the function MARQ
% PHI2,Y2 (optional) : Test data. Can be used to identify the
% optimal network architecture.
% prparms : Parameters associated with the pruning session
% prparms = [iter RePercent]
% iter : Max. number of retraining iterations
% RePercent : Prune 'RePercent' percent of the
% remaining weights (0 = prune one at a time)
% If passed as [], prparms=[50 0] will be used.
%
% OUTPUT:
% theta_data : Matrix containing all the parameter vectors
% NSSEvec : Vector containing the training error (SSE/2N) after each
% weight elimination
% FPEvec : Contains the FPE estimate of the average generalization error
% NSSEtestvec : Contains the test error
% deff : Contains the "effective" number of parameters
% pvec : Index to the above vectors
%
% Programmed by : Magnus Norgaard, IAU/IMM, Technical University of Denmark
% LastEditDate : July 17, 1996
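%
% EXAMPLE:
% A minimal sketch of a pruning session. The network definition, data
% matrices, and parameter settings below are hypothetical placeholders,
% not taken from this file:
%   NetDef = ['HHHHH';'L----'];    % 5 tanh hidden units, 1 linear output
%   [W1,W2] = marq(NetDef,W1,W2,PHI,Y,trparms);    % Train before pruning
%   [thd,PIvec,FPEvec,PItest,deff,pvec] = ...
%                obsprune(NetDef,W1,W2,PHI,Y,trparms,[50 0],PHI2,Y2);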
%----------------------------------------------------------------------------------
%-------------- NETWORK INITIALIZATIONS -------------
%----------------------------------------------------------------------------------
more off
if nargin>7, TestDataFlag = 1; % Check if test data was given as argument
else TestDataFlag = 0;end
iter = prparms(1); % Max. retraining iterations
RePercent = prparms(2); % Percentage of remaining weights to prune
[outputs,N] = size(Y); % # of outputs and # of data
[hidden,inputs] = size(W1); % # of hidden units
inputs=inputs-1; % # of inputs
L_hidden = find(NetDef(1,:)=='L')'; % Location of linear hidden neurons
H_hidden = find(NetDef(1,:)=='H')'; % Location of tanh hidden neurons
L_output = find(NetDef(2,:)=='L')'; % Location of linear output neurons
H_output = find(NetDef(2,:)=='H')'; % Location of tanh output neurons
y1 = zeros(hidden,N); % Hidden layer outputs
y2 = zeros(outputs,N); % Network output
index = outputs*(hidden+1) + 1 + [0:hidden-1]*(inputs+1); % Position in theta of each W1 row
index2 = (0:N-1)*outputs; % Offset of each data point's column block in PSI
PHI_aug = [PHI;ones(1,N)]; % Augment PHI with a row containing ones
parameters1= hidden*(inputs+1); % # of input-to-hidden weights
parameters2= outputs*(hidden+1); % # of hidden-to-output weights
parameters = parameters1 + parameters2; % Total # of weights
ones_h = ones(hidden+1,1); % A vector of ones
ones_i = ones(inputs+1,1); % Another vector of ones
% Parameter vector containing all weights
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
theta_index = find(theta); % Index to nonzero weights
theta_red = theta(theta_index); % Reduced parameter vector
reduced = length(theta_index); % The # of parameters in theta_red
reduced0 = reduced; % Copy of 'reduced'. Will be constant
theta_data=zeros(parameters,reduced); % Matrix used for collecting theta vectors
theta_data(:,reduced) = theta; % Insert 'initial' theta
PSI = zeros(parameters,outputs*N); % Deriv. of each output w.r.t. each weight
p0 = 1e6; % Diag. element of H_inv (no weight decay)
H_inv = p0*eye(reduced); % Initial inverse Hessian (no weight decay)
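% In Hassibi & Stork's OBS formulation the inverse of the Gauss-Newton
% Hessian, inv(PSI*PSI' + diag(D)), is built up recursively over the data
% with the matrix inversion lemma,
%   H_inv <- H_inv - (H_inv*psi)*(psi'*H_inv)/(1 + psi'*H_inv*psi),
% starting from a large diagonal matrix such as p0*eye(reduced) above
% when no weight decay is used.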
Ident = eye(outputs); % Identity matrix
PI_vector= zeros(1,reduced); % A vector containing the collected PI's
FPE_vector= zeros(1,reduced); % Vector used for collecting FPE estimates
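% Without weight decay, Akaike's final prediction error estimate is
% FPE = PI*(N+d)/(N-d) for a model with d weights; with weight decay a
% regularized variant replaces d by the "effective" number of parameters
% (cf. deff_vec below).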
if length(trparms)==4, % Scalar weight decay parameter
D0 = trparms(4*ones(1,reduced))';
elseif length(trparms)==5, % Two weight decay parameters
D0 = trparms([4*ones(1,parameters2) 5*ones(1,parameters1)])';
D0 = D0(theta_index);
else % No weight decay D = 0;
D0 = zeros(reduced,1);
end
D = D0;
if TestDataFlag, % Initializations if a test set exists
[tmp,N2] = size(Y2); % # of data in test set
ytest1 = zeros(hidden,N2); % Hidden layer outputs
ytest2 = zeros(outputs,N2); % Network output
PHI2_aug = [PHI2;ones(1,N2)]; % Augment PHI with a row containing ones
PI_test_vec = zeros(1,reduced); % Collected PI's for the test set
end
deff_vec = zeros(1,reduced); % The effective number of parameters
minweights = 2*outputs; % Prune until 'minweights'(>=2) weights remain
FirstTimeFlag=1; % Initialize flag
pr = 0; % Initialize counter
pvec=[]; % Initialize index vector
HiddenIndex = []; % Hidden unit associated with each weight
for k=1:outputs,
HiddenIndex = [HiddenIndex;(1:(hidden+1))'];
end
for k=1:hidden,
HiddenIndex = [HiddenIndex;k*ones(inputs+1,1)];
end
ConnectToHidden = (inputs+1)*ones(hidden,1); % Connections to each hidden unit
ConnectFromHidden = outputs*ones(hidden,1); % Connections from each hidden unit
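% The OBS statistics (Hassibi & Stork, 1993) used in the elimination step:
% the saliency of weight q, i.e., the estimated increase in the cost
% function when q is removed, is
%   zeta(q) = theta(q)^2/(2*H_inv(q,q)),
% and after deleting the weight with the smallest saliency the remaining
% weights are corrected by
%   delta_theta = -theta(q)/H_inv(q,q)*H_inv(:,q).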
%----------------------------------------------------------------------------------
%--------------- MAIN LOOP --------------
%----------------------------------------------------------------------------------
while reduced>=minweights,
% >>>>>>>>>>>>>>>>>>>>>>>>> Retrain Network <<<<<<<<<<<<<<<<<<<<<<<<<<<
% -- Don't retrain the first time --
if ~FirstTimeFlag,
[W1,W2,dummy1,dummy2,dummy3] = marq(NetDef,W1,W2,PHI,Y,[iter,0,1,D']);
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
theta_red = theta(theta_index); % Vector containing non-zero parameters
if ElimWeights==1, % Store parameter vector
theta_data(:,reduced) = theta;
else
theta_data(:,reduced:reduced+LEidx-1) = theta(:,ones(1,LEidx));
end
end
% >>>>>>>>>>>>> COMPUTE NETWORK OUTPUT FROM TEST DATA y2(theta) <<<<<<<<<<<<<<
% -- Compute only if a test set is present --
if TestDataFlag,
htest1 = W1*PHI2_aug;
ytest1(H_hidden,:) = pmntanh(htest1(H_hidden,:));
ytest1(L_hidden,:) = htest1(L_hidden,:);
ytest1_aug=[ytest1;ones(1,N2)];
htest2 = W2*ytest1_aug;
ytest2(H_output,:) = pmntanh(htest2(H_output,:));
ytest2(L_output,:) = htest2(L_output,:);
E = Y2 - ytest2; % Test error
E_vector = E(:); % Reshape E into a long vector
SSE = E_vector'*E_vector; % Sum of squared errors (SSE)
PI_test = SSE/(2*N2); % Cost function evaluated on test data
PI_test_vec(reduced) = PI_test; % Collect PI_test in vector
end
% >>>>>>>>>>> COMPUTE NETWORK OUTPUT FROM TRAINING DATA y2(theta) <<<<<<<<<<<<
h1 = W1*PHI_aug;
y1(H_hidden,:) = pmntanh(h1(H_hidden,:));
y1(L_hidden,:) = h1(L_hidden,:);
y1_aug=[y1; ones(1,N)];
h2 = W2*y1_aug;
y2(H_output,:) = pmntanh(h2(H_output,:));
y2(L_output,:) = h2(L_output,:);
E = Y - y2; % Training error
E_vector = E(:); % Reshape E into a long vector
SSE = E_vector'*E_vector; % Sum of squared errors (SSE)
PI = SSE/(2*N); % Value of cost function
PI_vector(reduced) = PI; % Collect PI in vector
% >>>>>>>>>>>>>>>>>>>>>>>>>> COMPUTE THE PSI MATRIX <<<<<<<<<<<<<<<<<<<<<<<<<
% (The derivative of each network output (y2) with respect to each weight)
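% (For a tanh unit y = tanh(h) the derivative is dy/dh = 1 - y^2, which is
%  where the factors (1-y1.*y1) and (1-y2.*y2) below come from; the chain
%  rule then carries these through W2 down to the input-to-hidden weights.)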
% ============ Elements corresponding to the linear output units ============
for i = L_output',
index1 = (i-1) * (hidden + 1) + 1;
% -- The part of PSI corresponding to hidden-to-output layer weights --
PSI(index1:index1+hidden,index2+i) = y1_aug;
% ---------------------------------------------------------------------
% -- The part of PSI corresponding to input-to-hidden layer weights ---
for j = L_hidden',
PSI(index(j):index(j)+inputs,index2+i) = W2(i,j)*PHI_aug;
end
for j = H_hidden',
tmp = W2(i,j)*(1-y1(j,:).*y1(j,:));
PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:).*PHI_aug;
end
% ---------------------------------------------------------------------
end
% ======= Elements corresponding to the hyperbolic tangent output units =======
for i = H_output',
index1 = (i-1) * (hidden + 1) + 1;
% -- The part of PSI corresponding to hidden-to-output layer weights --
tmp = 1 - y2(i,:).*y2(i,:);
PSI(index1:index1+hidden,index2+i) = y1_aug.*tmp(ones_h,:);
% ---------------------------------------------------------------------
% -- The part of PSI corresponding to input-to-hidden layer weights ---
for j = L_hidden',
tmp = W2(i,j)*(1-y2(i,:).*y2(i,:));
PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:).* PHI_aug;
end
for j = H_hidden',
tmp = W2(i,j)*(1-y1(j,:).*y1(j,:));
tmp2 = (1-y2(i,:).*y2(i,:));
PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:).*tmp2(ones_i,:).*PHI_aug;
end
% ---------------------------------------------------------------------
end