?? mltrain1.m
字號:
function [ocinv,cinv, Cen, ppri, R] = mltrain1(Pr,Tr,setflag)
% MLTRAIN1 Fit one Gaussian per class for maximum-likelihood classification.
%
% Usage: [ocinv, cinv, Cen, ppri, R] = mltrain1(Pr, Tr, setflag);
%
% Each training feature vector (row of Pr) is first scaled to unit
% Euclidean norm; all-zero rows are left as zero. For every class the
% centroid and a second-moment matrix are computed, and the eigenvalues of
% that matrix are floored at their mean value before inversion so that
% rank-deficient classes still yield a finite inverse.
%
% Inputs
%   Pr      : training feature vectors, K by N (one sample per row)
%   Tr      : training target labels, K by S; an entry equal to 1 marks the
%             sample's class. A K by 1 0/1 vector is expanded to the two
%             columns [Tr, 1-Tr].
%   setflag : 0 (default) -> use the sample covariance of each class;
%             otherwise    -> use the correlation matrix R{cn} instead
%
% Outputs
%   ocinv : 1 by S cell; ocinv{cn} is the N by N matrix (covariance or
%           correlation, per setflag) BEFORE eigenvalue flooring. Despite
%           the name it is NOT inverted. Empty for classes with fewer than
%           two samples.
%   cinv  : 1 by S cell; cinv{cn} is the N by N regularized inverse of
%           ocinv{cn} (eye(N) for classes with fewer than two samples)
%   Cen   : class centroids, S by N (the origin for a class with no samples)
%   ppri  : LOG prior probability of each class, S by 1 (-Inf for a class
%           with no samples)
%   R     : 1 by S cell of N by N correlation matrices x'*x/nos(cn); empty
%           for classes with fewer than two samples
%
% (C) 2001 by Yu Hen Hu
% created: 6/25/2001
% modified: 9/26/2001 to add Gaussian mixture model
% modified: 11/4/2001

if nargin < 3 || isempty(setflag)
    setflag = 0;
end

[K, N] = size(Pr);

% Per-row norms, summed explicitly along dimension 2 so that N == 1 still
% gives one norm per sample. Zero rows get norm 1 to avoid 0 * Inf = NaN.
rownorm = sqrt(sum(Pr .* Pr, 2));       % K X 1 vector
rownorm(rownorm == 0) = 1;
prnorm = ones(K, 1) ./ rownorm;
Pr = (prnorm * ones(1, N)) .* Pr;       % make each nonzero row unit norm

S = size(Tr, 2);
if S == 1  % one 0/1 output column: expand to two complementary classes
    Tr = [Tr, ones(K, 1) - Tr];
    S = 2;
end
nos  = sum(Tr, 1);       % samples per class, 1 X S (dim 1 even when K == 1)
ppri = nos' / sum(nos);  % S X 1 prior probabilities

% Preallocate so every output is defined even when no class has 2+ samples.
ocinv = cell(1, S);
cinv  = cell(1, S);
R     = cell(1, S);
Cen   = zeros(S, N);

for cn = 1:S
    members = (Tr(:, cn) == 1);
    if nos(cn) > 1  % at least two points in this class
        X = Pr(members, :);                    % nos(cn) X N
        R{cn} = X' * X / nos(cn);              % correlation matrix, N X N
        Cen(cn, :) = mean(X, 1);               % class centroid, 1 X N
        if setflag == 0
            Xc = X - ones(size(X, 1), 1) * Cen(cn, :);  % subtract the mean
            ctmp = Xc' * Xc / nos(cn);                  % sample covariance
        else
            ctmp = R{cn};
        end
        ocinv{cn} = ctmp;
        % ctmp may be rank deficient (nos(cn) < N, or linearly dependent
        % samples): raise every eigenvalue to at least the mean eigenvalue
        % trace/N before inverting.
        [v, D] = eig(ctmp);
        [d, order] = sort(diag(D), 'descend');
        v = v(:, order);
        d = max(d, (1/N) * trace(D));
        cinv{cn} = v * pinv(diag(d)) * v';
    elseif nos(cn) == 1
        Cen(cn, :) = Pr(members, :);           % lone sample is the centroid
        cinv{cn} = eye(N);
    else  % no training sample in this class
        Cen(cn, :) = zeros(1, N);              % origin as an arbitrary center
        cinv{cn} = eye(N);
    end
end
ppri = log(ppri);  % callers receive LOG priors
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -