glmtrain.m
字號:
function [net, options] = glmtrain(net, options, x, t)
%GLMTRAIN Specialised training of generalized linear model
%
%	Description
%	NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
%	least squares (IRLS) algorithm to set the weights in the generalized
%	linear model structure NET.  This is a more efficient alternative to
%	using GLMERR and GLMGRAD and a non-linear optimisation routine
%	through NETOPT.  Note that for linear outputs, a single pass through
%	the algorithm is all that is required, since the error function is
%	quadratic in the weights.  The error function value at the final set
%	of weights is returned in OPTIONS(8).  Each row of X corresponds to
%	one input vector and each row of T corresponds to one target vector.
%
%	The optional parameters have the following interpretations.
%
%	OPTIONS(1) is set to 1 to display error values during training.  If
%	OPTIONS(1) is set to 0, then only warning messages are displayed.  If
%	OPTIONS(1) is -1, then nothing is displayed.
%
%	OPTIONS(2) is a measure of the precision required for the value of
%	the weights W at the solution.
%
%	OPTIONS(3) is a measure of the precision required of the objective
%	function at the solution.  Both this and the previous condition must
%	be satisfied for termination.
%
%	OPTIONS(5) is set to 1 if an approximation to the Hessian (which
%	assumes that all outputs are independent) is used for softmax
%	outputs.  With the default value of 0 the exact Hessian (which is
%	more expensive to compute) is used.
%
%	OPTIONS(14) is the maximum number of iterations for the IRLS
%	algorithm;  default 100.
%
%	If the error increases from one iteration to the next, training
%	stops and NET is returned with the weights of the previous iteration
%	(whose error is stored in OPTIONS(8)).  If the IRLS weights become
%	ill-conditioned, training stops with the current weights.
%
%	NET.ACTFN must be 'linear', 'logistic' or 'softmax'; any other value
%	raises an error.
%
%	See also
%	GLM, GLMERR, GLMGRAD
%

%	Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997)

% Check arguments for consistency.  CONSIST is expected to return an
% empty string when the arguments agree; anything else is an error message.
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if (~options(14))
  options(14) = 100;
end

% Verbosity: > 0 prints the error every cycle, 0 prints warnings only,
% -1 prints nothing (see OPTIONS(1) above).
verbosity = options(1);
% Do we need to test for convergence?
test = (options(2) | options(3));

ndata = size(x, 1);
% Add a column of ones for the bias
inputs = [x ones(ndata, 1)];

% Linear outputs are a special case: the error function is quadratic in
% the weights, so a single least-squares solve gives the solution.
if strcmp(net.actfn, 'linear')
  % Solve for the weights and biases using left matrix divide
  temp = inputs\t;
  net.w1 = temp(1:net.nin, :);
  net.b1 = temp(net.nin+1, :);
  % Store error value in options vector
  options(8) = glmerr(net, x, t);
  return;
end

% Otherwise need to use iterative reweighted least squares (or, for
% softmax outputs with OPTIONS(5) ~= 1, Newton steps with the exact Hessian
% after the first cycle).
for n = 1:options(14)

  ok = 1;
  switch net.actfn
    case 'logistic'
      if n == 1
        % Initialise model: pull targets away from 0 and 1 so that the
        % logit below is finite (for targets in [0, 1])
        p = (t+0.5)/2;
        act = log(p./(1-p));
      end
      [net, ok] = glmtrain_irls_step(net, inputs, t, p, act);

    case 'softmax'
      if n == 1
        % Initialise model: average the targets with a uniform distribution
        % so that each row of p sums to one whatever the number of classes
        % (assuming each row of t sums to one, i.e. 1-of-c coding)
        p = (t + 1/size(t, 2))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 | n == 1
        % Approximate Hessian: each output treated independently.  Also
        % used on the first cycle to obtain initial weights.
        [net, ok] = glmtrain_irls_step(net, inputs, t, p, act);
      else
        % Exact Newton step from the current weights WOLD.
        % NOTE(review): uses the second output of GLMHESS, presumably the
        % data-term Hessian -- confirm against glmhess.m
        [junk, Hessian] = glmhess(net, x, t);
        temp = p-t;
        gw1 = x'*(temp);
        gb1 = sum(temp, 1);
        % Same ordering as GLMPAK: weights column-wise, then biases
        gradient = [gw1(:)', gb1];
        % Now compute modification to weights
        deltaw = -gradient*pinv(Hessian);
        net = glmunpak(net, wold + deltaw);
      end

    otherwise
      error(['Unknown activation function ', net.actfn]);
  end

  if ~ok
    % The IRLS weights have (nearly) vanished, so the weighted
    % least-squares problem cannot be solved reliably: stop here.
    if verbosity >= 0
      fprintf(1, 'Warning: ill-conditioned weights in glmtrain\n');
    end
    options(8) = glmerr(net, x, t);
    return;
  end

  [err, p, act] = glmerr(net, x, t);
  w = glmpak(net);

  if verbosity > 0
    fprintf(1, 'Cycle %4d Error %11.6f\n', n, err);
  end

  if n > 1
    % Terminate if error increases, restoring the previous weights
    if err > errold
      net = glmunpak(net, wold);
      options(8) = errold;
      if verbosity >= 0
        fprintf(1, 'Error has increased: terminating\n');
      end
      return;
    end
    % Test for convergence
    if test
      if (max(abs(w - wold)) < options(2) & abs(err - errold) < options(3))
        options(8) = err;
        return;
      end
    end
  end

  % Remember this iterate every cycle: the exact Newton step and both
  % termination tests are relative to the previous cycle.
  errold = err;
  wold = w;
end

options(8) = err;
if (verbosity >= 0)
  disp('Warning: Maximum number of iterations has been exceeded');
end


function [net, ok] = glmtrain_irls_step(net, inputs, t, p, act)
%GLMTRAIN_IRLS_STEP One IRLS update, treating each output independently.
%
%	[NET, OK] = GLMTRAIN_IRLS_STEP(NET, INPUTS, T, P, ACT) solves, for
%	each output column j, the weighted least-squares problem with
%	weights P(:,j).*(1-P(:,j)) and working response
%	ACT + (T-P)./(P.*(1-P)), storing the solution in NET.W1(:,j) and
%	NET.B1(j).  INPUTS is the data matrix with a trailing column of ones
%	for the bias.  If the square root of any weight is below EPS, OK is
%	returned as 0 and NET is returned unchanged; otherwise OK is 1.

link_deriv = p.*(1-p);
weights = sqrt(link_deriv);   % square roots of the IRLS weights
if (min(min(weights)) < eps)
  ok = 0;
  return
end
ok = 1;
% Working response of the linearised model
z = act + (t-p)./link_deriv;
e = ones(1, net.nin+1);
for j = 1:net.nout
  % Scale the rows of the design matrix and of the response by the
  % square-root weights, then solve by left matrix divide
  indep = inputs.*(weights(:,j)*e);
  dep = z(:,j).*weights(:,j);
  temp = indep\dep;
  net.w1(:,j) = temp(1:net.nin);
  net.b1(j) = temp(net.nin+1);
end
快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -