?? rjnn.m
字號:
function [k,mu,alpha,sigma,nabla,delta,ypred,ypredv,post] = rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv)
%RJNN Reversible-jump MCMC sampler for a radial basis function (RBF) network.
%
%   [k,mu,alpha,sigma,nabla,delta,ypred,ypredv,post] = ...
%       rjnn(x,y,chainLength,Ndata,bFunction)
%   ... = rjnn(x,y,chainLength,Ndata,bFunction,par)
%   ... = rjnn(x,y,chainLength,Ndata,bFunction,xv,yv)
%   ... = rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv)
%
%   Runs a Markov chain of length chainLength. At every step one move is
%   proposed for the number of basis functions and their centres (birth,
%   death, split, merge, random walk or update, delegated to radialBirth,
%   radialDeath, radialSplit, radialMerge, radialRW and radialUpdate);
%   sigma, alpha, delta and nabla are then redrawn by Gibbs sampling.
%
%   The regression matrix for k centres is
%       M = [ones(N,1), x, phi(mu(1,:),x), ..., phi(mu(k,:),x)]
%   with phi(centre,X) = feval(bFunction,centre,X), which must return one
%   column per row of X.
%
%   Inputs
%     x           N-by-d training inputs.
%     y           N-by-c training targets.
%     chainLength Number of chain states, including the initial one.
%     Ndata       Number of training points; must equal size(x,1) and size(y,1).
%     bFunction   Basis function (name or handle) for feval(bFunction,centre,X).
%     par         (optional) settings struct. Recognised fields: a, b, e1, e2,
%                 v, gamma, kMax, arbC, doPlot, merge, sRW, Lambda, walkPer,
%                 kInit. Missing fields (or par = []) take the defaults in
%                 rjnnSettings below.
%     xv, yv      (optional) Nv-by-d validation inputs and Nv-by-c targets.
%
%   Outputs (one entry per chain state t = 1..chainLength)
%     k       chainLength-by-1 number of basis functions.
%     mu      chainLength-by-1 cell of centres; mu{1} is k(1)-by-d.
%     alpha   chainLength-by-c cell of coefficient vectors (k(t)+d+1 long).
%     sigma   chainLength-by-c output noise variances.
%     nabla   chainLength-by-1 Poisson parameter for k.
%     delta   chainLength-by-c regularisation parameters.
%     ypred   N-by-c-by-chainLength training-set fits.
%     ypredv  Nv-by-c-by-chainLength validation fits, or [] when no
%             validation data is given.
%     post    chainLength-by-1 log of the monitored posterior expression
%             (labelled p(centres,k|y)); post(1) keeps its initial value 1.

% PARSE THE INPUT ARGUMENTS:
% ==========================
if nargin < 5, error('Not enough input arguments.'); end
switch nargin
    case 5
        par = [];
        Validation = 0;
    case 6
        Validation = 0;
    case 7
        % rjnn(x,y,chainLength,Ndata,bFunction,xv,yv): default settings plus
        % validation data. The validation set arrives in the par/xv slots, so
        % shift it into place (previously yv was left undefined here).
        yv = xv;
        xv = par;
        par = [];
        Validation = 1;
    case 8
        Validation = 1;
    otherwise
        error('Wrong Number of input arguments.');
end

opts    = rjnnSettings(par);
hyper   = opts.hyper;
kMax    = opts.kMax;
arbC    = opts.arbC;
doPlot  = opts.doPlot;
sigStar = opts.sigStar;
sWalk   = opts.sWalk;
Lambda  = opts.Lambda;
walkPer = opts.walkPer;

[N,d]  = size(x);  % N = number of training points, d = input dimension.
[Ny,c] = size(y);  % c = number of outputs.
if Ndata ~= N || Ny ~= N
    error('input must be N by d and output N by c.');
end
if Validation
    [Nv,dv] = size(xv);  % Nv = number of validation points.
    if dv ~= d || size(yv,1) ~= Nv || size(yv,2) ~= c
        error('validation input must be Nv by d and validation output Nv by c.');
    end
end

% INITIALISATION:
% ==============
post = ones(chainLength,1);            % Monitored log posterior.
if Validation
    ypredv = zeros(Nv,c,chainLength);  % Output fit (validation set).
else
    ypredv = [];                       % Keep the output argument assigned.
end
ypred = zeros(N,c,chainLength);        % Output fit (train set).
nabla = zeros(chainLength,1);          % Poisson parameter.
delta = zeros(chainLength,c);          % Regularisation parameter.
k     = ones(chainLength,1);           % Model order - number of basis.
sigma = ones(chainLength,c);           % Output noise variance.
mu    = cell(chainLength,1);           % Radial basis centres.
alpha = cell(chainLength,c);           % Radial basis coefficients.

% DEFINE WALK INTERVAL FOR MU:
% ===========================
% Centres live in [min(x)-walk, max(x)+walk] per input dimension. The explicit
% dimension argument keeps this per-column even when N == 1.
xMin = min(x,[],1);
xMax = max(x,[],1);
walk = walkPer*(xMax-xMin);            % 1-by-d
walkInt = (xMax - xMin + 2*walk)';     % d-by-1 width of each interval.

% SAMPLE INITIAL CONDITIONS FROM THEIR PRIORS:
% ===========================================
nabla(1) = gengamma(0.5 + hyper.e1,hyper.e2);
kPrior = poissrnd(nabla(1));  % Always drawn so the random stream is unchanged.
if isempty(opts.kInit)
    k(1) = kPrior;
else
    k(1) = opts.kInit;        % Default 40 (previously hard-coded for demo1).
end
k(1) = min(max(k(1),1),kMax);
for i = 1:c
    delta(1,i) = inv(gengamma(hyper.a,hyper.b));
    sigma(1,i) = inv(gengamma(hyper.v/2,hyper.gamma/2));
    alpha{1,i} = mvnrnd(zeros(1,k(1)+d+1),sigma(1,i)*delta(1,i)*eye(k(1)+d+1),1)';
end

% DRAW THE INITIAL RADIAL CENTRES (uniform on the walk interval):
% ==============================================================
mu{1} = zeros(k(1),d);
for i = 1:d
    mu{1}(:,i) = (xMin(i)-walk(i)) + walkInt(i)*rand(k(1),1);
end

% FILL THE REGRESSION MATRIX:
% ==========================
M = rjnnDesignMatrix(bFunction,mu{1},x,k(1));
for i = 1:c
    ypred(:,i,1) = M*alpha{1,i};
end
if Validation
    Mv = rjnnDesignMatrix(bFunction,mu{1},xv,k(1));
    for i = 1:c
        ypredv(:,i,1) = Mv*alpha{1,i};
    end
end

% INITIALISE COUNTERS (a* = accepted, r* = rejected moves):
% ========================================================
aUpdate=0; rUpdate=0; aBirth=0; rBirth=0; aDeath=0; rDeath=0;
aMerge=0; rMerge=0; aSplit=0; rSplit=0; aRW=0; rRW=0; match=0;
if doPlot
    figure(3); clf;
end

% ITERATE THE MARKOV CHAIN:
% ========================
for t = 1:chainLength-1
    iteration = t  % Progress echo (intentionally not suppressed).

    % COMPUTE THE CENTRES AND DIMENSION WITH METROPOLIS, BIRTH AND DEATH MOVES:
    % ========================================================================
    decision = rand(1);
    birth = arbC*min(1,(nabla(t)/(k(t)+1)));
    death = arbC*min(1,((k(t)+1)/nabla(t)));
    if (decision <= birth) && (k(t) < kMax)
        [k,mu,M,match,aBirth,rBirth] = radialBirth(match,aBirth,rBirth,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
    elseif (decision <= birth+death) && (k(t) > 0)
        [k,mu,M,aDeath,rDeath] = radialDeath(aDeath,rDeath,k,mu,M,delta,x,y,hyper,t,nabla);
    elseif (decision <= 2*birth+death) && (k(t) < kMax) && (k(t) > 1)
        [k,mu,M,aSplit,rSplit] = radialSplit(aSplit,rSplit,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt,walk);
    elseif (decision <= 2*birth+2*death) && (k(t) > 1)
        [k,mu,M,aMerge,rMerge] = radialMerge(aMerge,rMerge,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt);
    else
        % Fixed-dimension move: update with probability Lambda, else random walk.
        uLambda = rand(1);
        if (uLambda > Lambda) && (k(t) > 0)
            [k,mu,M,match,aRW,rRW] = radialRW(match,aRW,rRW,k,mu,M,delta,x,y,hyper,t,bFunction,sWalk,walk);
        else
            [k,mu,M,match,aUpdate,rUpdate] = radialUpdate(match,aUpdate,rUpdate,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
        end
    end

    % UPDATE OTHER PARAMETERS WITH GIBBS:
    % ==================================
    K = k(t+1)+d+1;      % Number of regression coefficients per output.
    MtM = M'*M;          % Shared by every output.
    H = zeros(K,K,c);
    F = zeros(K,c);
    yPy = zeros(1,c);
    logDetH = zeros(1,c);
    for i = 1:c
        A = MtM + (1/delta(t,i))*eye(K);
        H(:,:,i) = inv(A);
        % log(det(H)) = -log|det(A)|, from the LU factors to avoid det() over/underflow.
        [Lfac,Ufac] = lu(A); %#ok<ASGLU> only U is needed.
        logDetH(i) = -sum(log(abs(diag(Ufac))));
        F(:,i) = H(:,:,i)*(M'*y(:,i));
        % y'*(I - M*H*M')*y equals r'*r + F'*F/delta with r = y - M*F. This
        % avoids forming an N-by-N matrix and is a sum of non-negative terms.
        r = y(:,i) - M*F(:,i);
        yPy(i) = r'*r + (F(:,i)'*F(:,i))/delta(t,i);
        sigma(t+1,i) = inv(gengamma((hyper.v+N)/2,(hyper.gamma+yPy(i))/2));
        alpha{t+1,i} = mvnrnd(F(:,i),sigma(t+1,i)*H(:,:,i),1)';
        delta(t+1,i) = inv(gengamma(hyper.a+K/2,hyper.b+inv(2*sigma(t+1,i))*(alpha{t+1,i}'*alpha{t+1,i})));
    end
    nabla(t+1) = gengamma(0.5+hyper.e1+k(t+1),1+hyper.e2);

    % COMPUTE THE LOG POSTERIOR FOR MONITORING:
    % ========================================
    % Summed term by term in log space; the former product form under/overflowed
    % (e.g. yPy^(-(v+N)/2) for large N), which turned post into -Inf or +Inf.
    kNew = k(t+1);
    logPost = -nabla(t+1) - gammaln(kNew+1);
    if kNew > 0
        % log(nabla^k / prod(walkInt)^k); skipped at k = 0 so a zero-width
        % input range cannot produce 0*(-Inf) = NaN.
        logPost = logPost + kNew*(log(nabla(t+1)) - sum(log(walkInt)));
    end
    for i = 1:c
        logPost = logPost - (d+kNew+1)/2*log(delta(t+1,i)) + 0.5*logDetH(i) ...
                  - (hyper.v+N)/2*log(hyper.gamma+yPy(i));
    end
    post(t+1) = logPost;

    % FITS AND ERRORS:
    % ===============
    for i = 1:c
        ypred(:,i,t+1) = M*alpha{t+1,i};
    end
    resid = y - ypred(:,:,t+1);
    msError = sum(resid(:).^2)/N;
    if Validation
        Mv = rjnnDesignMatrix(bFunction,mu{t+1},xv,k(t+1));
        for i = 1:c
            ypredv(:,i,t+1) = Mv*alpha{t+1,i};
        end
        residv = yv - ypredv(:,:,t+1);
        msErrorv = sum(residv(:).^2)/Nv;
    end

    % PLOT FOR FUN AND MONITORING:
    % ============================
    if doPlot
        figure(1); clf;
        if c == 2
            plot(x(:,1),y(:,1),'b+',x(:,2),y(:,2),'r+',x(:,1),ypred(:,1,t+1),'bo',x(:,2),ypred(:,2,t+1),'ro');
        elseif c == 1
            plot(x,y,'b+',x,ypred(:,:,t+1),'ro');
        end
        ylabel('Output','fontsize',15);
        xlabel('Input','fontsize',15);
        figure(3);
        subplot(511); hold on; plot(t,k(t),'*'); ylabel('k','fontsize',15);
        subplot(512); hold on; plot(t,post(t+1),'*'); ylabel('p(k,mu|y)','fontsize',15);
        subplot(513); hold on; plot(t,msError,'r*'); ylabel('Train error','fontsize',15);
        if Validation
            % Only available when validation data was supplied.
            subplot(514); hold on; plot(t,msErrorv,'r*'); ylabel('Test error','fontsize',15);
        end
        subplot(515); hold on;
        bar(1:13,[match aUpdate rUpdate aBirth rBirth aDeath rDeath aMerge rMerge aSplit rSplit aRW rRW]);
        ylabel('Acceptance','fontsize',15);
        xlabel('match aU rU aB rB aD rD aM rM aS rS aRW rRW','fontsize',15);
    end
end
end


function opts = rjnnSettings(par)
% Return the sampler settings: the defaults, overridden by any field present in par.
opts.hyper = struct('a',2, ...         % Hyperparameter for delta.
                    'b',10, ...        % Hyperparameter for delta.
                    'e1',0.0001, ...   % Hyperparameter for nabla.
                    'e2',0.0001, ...   % Hyperparameter for nabla.
                    'v',0, ...         % Hyperparameter for sigma.
                    'gamma',0);        % Hyperparameter for sigma.
opts.kMax    = 50;    % Maximum number of basis functions.
opts.arbC    = 0.5;   % Constant for birth and death move probabilities.
opts.doPlot  = 1;     % Plot while sampling.
opts.sigStar = .1;    % Merge-split parameter (par.merge).
opts.sWalk   = .001;  % Random-walk parameter passed to radialRW (par.sRW).
opts.Lambda  = .5;    % Probability of radialUpdate vs radialRW for fixed-k moves.
opts.walkPer = 0.1;   % Fraction of the input range added on each side for centres.
opts.kInit   = 40;    % Initial k; [] draws it from the Poisson(nabla) prior.
if isempty(par)
    return;
end
hyperNames = {'a','b','e1','e2','v','gamma'};
for n = 1:numel(hyperNames)
    if isfield(par,hyperNames{n})
        opts.hyper.(hyperNames{n}) = par.(hyperNames{n});
    end
end
% Column 1: field name in par; column 2: internal setting name.
names = {'kMax','kMax'; 'arbC','arbC'; 'doPlot','doPlot'; 'merge','sigStar'; ...
         'sRW','sWalk'; 'Lambda','Lambda'; 'walkPer','walkPer'; 'kInit','kInit'};
for n = 1:size(names,1)
    if isfield(par,names{n,1})
        opts.(names{n,2}) = par.(names{n,1});
    end
end
end


function Mat = rjnnDesignMatrix(bFunction,centres,X,nBasis)
% Regression matrix [1, X, phi(centres(1,:),X), ..., phi(centres(nBasis,:),X)].
[n,dim] = size(X);
Mat = zeros(n,nBasis+dim+1);
Mat(:,1) = 1;
Mat(:,2:dim+1) = X;
for j = 1:nBasis
    Mat(:,dim+1+j) = feval(bFunction,centres(j,:),X);
end
end
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -