% gaussmix.m
% Full-covariance EM iterations.
% Each mixture's covariance is stored as one row of v in lower-triangular form;
% reshape(v(ik,lixi),p,p) gives the full p x p matrix.
% The Gaussian factor (2*pi)^(-p/2) is omitted throughout and is subtracted from
% the log likelihoods after the loop.
% NOTE(review): vi, vim and mtk are assumed to be preallocated before this point as
% (k*p) x p, (k*p) x 1 and (k*p) x 1 arrays respectively - confirm.
wpk=repmat((1:p)',k,1);   % index 1..p repeated for each of the k stacked mixtures
lvm=zeros(k,1);           % log(w/sqrt(det(V))) for each mixture (column so lvm(:,wnj) replicates)
for j=1:l
    g1=g;   % save previous log likelihood (2*pi factor omitted)
    m1=m;   % save previous means, variances and weights
    v1=v;
    w1=w;
    for ik=1:k
        vk=reshape(v(ik,lixi),p,p);            % full covariance matrix of mixture ik
        % condk(ik)=cond(vk);                  % debugging only
        vik=-0.5*inv(vk);                      % so (x-m)'*vik*(x-m) is the log of the Gaussian exponent
        vi((ik-1)*p+(1:p),:)=vik;              % vi stacks all mixture matrices vik on top of each other
        vim((ik-1)*p+(1:p))=vik*m(ik,:)';      % vim stacks vik*mean for all mixtures
        mtk((ik-1)*p+(1:p))=m(ik,:)';          % mtk stacks all mixture means
        lvm(ik)=log(w(ik))-0.5*log(det(vk));   % log(w/sqrt(det(V))); log domain avoids overflow of w/sqrt(det(V))
    end
    %
    % first do partial chunk
    %
    jx=jx0;
    ii=1:jx;
    xii=x(ii,:).';   % must be refreshed here: otherwise xii still holds the last chunk of the previous iteration
    py=reshape(sum(reshape((vi*xii-vim(:,wnj)).*(xii(wpk,:)-mtk(:,wnj)),p,jx*k),1),k,jx)+lvm(:,wnj);
    mx=max(py,[],1);       % normalizing factor for each data point to prevent underflow when using exp()
    px=exp(py-mx(wk,:));   % normalized probability of each mixture for each data point
    ps=sum(px,1);          % total normalized likelihood of each data point
    px=px./ps(wk,:);       % relative mixture probabilities for each data point (columns sum to 1)
    lpx(ii)=log(ps)+mx;    % log likelihood of each data point (2*pi factor omitted)
    pk=sum(px,2);          % effective number of data points for each mixture (could be zero due to underflow)
    sx=px*x(ii,:);                   % accumulator for mean calculation
    sx2=px*(x(ii,rix).*x(ii,cix));   % accumulator for variance calculation (lower tri cov matrix as a row)
    for il=2:nl
        ix=jx+1;
        jx=jx+nb;   % increment upper limit
        ii=ix:jx;
        xii=x(ii,:).';
        py=reshape(sum(reshape((vi*xii-vim(:,wnb)).*(xii(wpk,:)-mtk(:,wnb)),p,nb*k),1),k,nb)+lvm(:,wnb);
        mx=max(py,[],1);       % normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));   % normalized probability of each mixture for each data point
        ps=sum(px,1);          % total normalized likelihood of each data point
        px=px./ps(wk,:);       % relative mixture probabilities for each data point (columns sum to 1)
        lpx(ii)=log(ps)+mx;
        pk=pk+sum(px,2);       % effective number of data points for each mixture
        sx=sx+px*x(ii,:);                   % accumulator for mean calculation
        sx2=sx2+px*(x(ii,rix).*x(ii,cix));  % accumulator for variance calculation
    end
    g=sum(lpx);   % total log probability summed over all data points
    gg(j)=g;      % save convergence history
    w=pk/n;       % normalize to get the column of weights
    if pk         % true only if every element of pk is non-zero
        m=sx./pk(:,wp);     % means
        v=sx2./pk(:,wpl);   % mean squares (lower tri row form)
    else
        wm=pk==0;             % mask indicating mixtures with zero weights
        [vv,mk]=sort(lpx);    % data points in order of increasing log likelihood
        m=zeros(k,p);
        v=zeros(k,pl);
        m(wm,:)=x(mk(1:sum(wm)),:);     % set zero-weight mixture means to worst-fitted data points
        % give zero-weight mixtures a second moment equal to mean*mean' so that their
        % covariance is exactly zero after the subtraction below and is then floored
        % to c on the diagonal (otherwise the off-diagonals become -m_i*m_j)
        v(wm,:)=m(wm,cix).*m(wm,rix);
        wm=~wm;               % mask for non-zero weights
        m(wm,:)=sx(wm,:)./pk(wm,wp);    % recalculate means and variances for mixtures with a non-zero weight
        v(wm,:)=sx2(wm,:)./pk(wm,wpl);
    end
    v=v-m(:,cix).*m(:,rix);     % subtract off mean squared
    v(:,dix)=max(v(:,dix),c);   % force diagonal elements to be >= c
    if g-g1<=th && j>1          % increase in log likelihood is below the threshold
        if ~ss, break; end      % stop
        ss=ss-1;                % stop next time
    end
end
if sd
    % Final pass was requested: convert the accumulated log likelihoods, which omit
    % the Gaussian (2*pi)^(-p/2) factor, into per-point and per-iteration values.
    hl2pi=0.5*p*log(2*pi);
    pp=lpx'-hl2pi;          % log probability of each data point (column)
    gg=gg(1:j)/n-hl2pi;     % average log probability at each completed iteration
    g=gg(end);
    % gg' % *** DEBUG ONLY ***
    % lpx was evaluated with the parameters saved at the start of the last pass,
    % so hand those back rather than the freshly re-estimated ones.
    m=m1; v=v1; w=w1;
    % f: determinant of the scatter of the mixture means divided by the determinant
    % of the average mixture covariance.
    mbar=sum(m,1)/k;                     % mean of the mixture means
    mmsq=sum(m(:,rix).*m(:,cix),1)/k;    % mean outer product of the means (lower tri row form)
    vbar=sum(v,1)/k;                     % average covariance (lower tri row form)
    % NOTE(review): det() needs mmsq(lixi) and vbar(lixi) to be p x p, so lixi is
    % presumably a p x p index matrix built before this excerpt - confirm.
    f=det(mmsq(lixi)-mbar'*mbar)/det(vbar(lixi));
end
v=reshape(v(:,lixi)',p,p,k);   % expand each mixture's stored covariance row into a p x p page
% When l==0 the caller receives [g,f,pp] in place of [m,v,w].
% NOTE(review): the diagonal branch overwrites l with floor(l)+sd before its loop; if
% this branch does the same (setup not visible here), this test sees the modified l
% and will not fire when the caller passed l==0 with outputs requested - confirm.
if l==0
    [m,v,w]=deal(g,f,pp);
end
else
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Diagonal Covariance matrices  %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
v=max(v,c);   % enforce the variance floor on the starting values
% Work through the data in chunks so that the k-fold replicated data
% (nb*k rows of p doubles, 8 bytes each) stays within memsize.
nb=min(n,max(1,floor(memsize/(8*p*k))));   % data points per full chunk
nl=ceil(n/nb);                             % number of chunks
jx0=n-(nl-1)*nb;                           % data points in the first (partial) chunk
im=reshape(repmat(1:k,1,nb),[],1);         % mixture index for each replicated row of a full chunk
% vectors of ones used to replicate rows/columns by indexing
wk=ones(k,1);
wp=ones(1,p);
wnb=ones(1,nb);
wnj=ones(1,jx0);
% stopping control: the fractional part of l scaled by n is the threshold on
% the increase of the total log likelihood
th=(l-floor(l))*n;
sd=(nargout > 3*(l~=0));   % true if log likelihood values are being output
l=floor(l)+sd;             % one extra pass to compute the final g value
ss=sd;                     % stopping count (0 or 1)
% EM state
lpx=zeros(1,n);   % log probability of each data point
g=0;              % dummy initial value for the first comparison
gg=zeros(l+1,1);  % convergence history
% Diagonal-covariance EM iterations: at most l passes. Each pass computes the
% mixture responsibilities of every data point chunk by chunk (E-step), then
% re-estimates weights, means and variances (M-step). The (2*pi)^(-p/2) factor is
% omitted throughout.
for j=1:l
g1=g; % save previous log likelihood (2*pi factor omitted)
m1=m; % save previous means, variances and weights
v1=v;
w1=w;
lvm=log(w)-0.5*sum(log(v),2); % calculate log of mixture scale factor to avoid overflow problems
vi=-0.5*v.^(-1); % exponent scale factors
% first do partial chunk
jx=jx0;
ii=1:jx;
% kk repeats each data index k times and km cycles through the mixtures, so row r of
% x(kk(:),:) is paired with row r of m(km(:),:)
kk=repmat(ii,k,1);
km=repmat(1:k,1,jx);
% py(ik,i) = log(w_ik) - 0.5*sum(log(v_ik)) - 0.5*sum((x_i-m_ik).^2./v_ik)
py=reshape(sum((x(kk(:),:)-m(km(:),:)).^2.*vi(km(:),:),2),k,jx)+lvm(:,wnj);
mx=max(py,[],1); % find normalizing factor for each data point to prevent underflow when using exp()
px=exp(py-mx(wk,:)); % find normalized probability of each mixture for each datapoint
ps=sum(px,1); % total normalized likelihood of each data point
px=px./ps(wk,:); % relative mixture probabilities for each data point (columns sum to 1)
lpx(ii)=log(ps)+mx;
pk=sum(px,2); % effective number of data points for each mixture (could be zero due to underflow)
sx=px*x(ii,:);
% NOTE(review): x2 is defined before this excerpt; presumably x.^2, since sx2/pk has
% m.^2 subtracted below to give variances - confirm
sx2=px*x2(ii,:);
% remaining full chunks of nb points reuse the precomputed mixture index im
for il=2:nl
ix=jx+1;
jx=jx+nb; % increment upper limit
ii=ix:jx;
kk=repmat(ii,k,1);
py=reshape(sum((x(kk(:),:)-m(im,:)).^2.*vi(im,:),2),k,nb)+lvm(:,wnb);
mx=max(py,[],1); % find normalizing factor for each data point to prevent underflow when using exp()
px=exp(py-mx(wk,:)); % find normalized probability of each mixture for each datapoint
ps=sum(px,1); % total normalized likelihood of each data point
px=px./ps(wk,:); % relative mixture probabilities for each data point (columns sum to 1)
lpx(ii)=log(ps)+mx;
pk=pk+sum(px,2); % effective number of data points for each mixture (could be zero due to underflow)
sx=sx+px*x(ii,:);
sx2=sx2+px*x2(ii,:);
end
g=sum(lpx); % total log probability summed over all data points
gg(j)=g;
w=pk/n; % normalize to get the weights
% M-step; "if pk" on a vector is true only when every element is non-zero
if pk % if all elements of pk are non-zero
m=sx./pk(:,wp);
v=sx2./pk(:,wp);
else
wm=pk==0; % mask indicating mixtures with zero weights
[vv,mk]=sort(lpx); % find the lowest probability data points
m=zeros(k,p); % initialize means and variances to zero (variances are floored later)
v=m;
m(wm,:)=x(mk(1:sum(wm)),:); % set zero-weight mixture means to worst-fitted data points
wm=~wm; % mask for non-zero weights
m(wm,:)=sx(wm,:)./pk(wm,wp); % recalculate means and variances for mixtures with a non-zero weight
v(wm,:)=sx2(wm,:)./pk(wm,wp);
end
v=max(v-m.^2,c); % apply floor to variances
% convergence: when the likelihood gain is at most th, stop immediately if ss==0;
% otherwise decrement ss and stop the next time the condition holds
if g-g1<=th && j>1
if ~ss, break; end % stop
ss=ss-1; % stop next time
end
end
if sd
    % Final pass was requested: convert the accumulated log likelihoods, which omit
    % the Gaussian (2*pi)^(-p/2) factor, into per-point and per-iteration values.
    hl2pi=0.5*p*log(2*pi);
    pp=lpx'-hl2pi;          % log probability of each data point (column)
    gg=gg(1:j)/n-hl2pi;     % average log probability at each completed iteration
    g=gg(end);
    % gg' % *** DEBUG ***
    % lpx was evaluated with the parameters saved at the start of the last pass,
    % so hand those back rather than the freshly re-estimated ones.
    m=m1; v=v1; w=w1;
    % f: product over dimensions of the variance of the mixture means divided by
    % the product of the average mixture variances.
    mbar=sum(m,1)/k;                   % mean of the mixture means
    spread=sum(m.^2,1)/k-mbar.^2;      % per-dimension variance of the mixture means
    vbar=sum(v,1)/k;                   % per-dimension average mixture variance
    f=prod(spread)/prod(vbar);
end
% Suppress the first three outputs when the caller passed l==0: m, v and w are
% replaced by g, f and pp.
% l itself was overwritten earlier (l=floor(l)+sd), so a plain "l==0" test only
% fires when no outputs were requested, and f/pp are then undefined. The
% caller's l==0 is instead recovered from the derived state:
%   - sd is true (the final likelihood pass ran, so g, f and pp exist);
%   - l==1 (the integer part of the original l was 0);
%   - th==0 (the original l had no fractional part).
if sd && l==1 && th==0
    m=g;
    v=f;
    w=pp;
end
end