% hdda_learn.m
function prms = hdda_learn(Xl,varargin)
% High Dimensional Discriminant Analysis (learning)
%
% Usage: (1) prms = hdda_learn(X,'model','best','seuil',s);
%        (2) prms = hdda_learn(X,'model','best','dim',d);
%        (3) prms = hdda_learn(X,'model','AijBiQiDi','seuil',s);
%        (4) prms = hdda_learn(X,'model','AijBiQiD','dim',d);
%
% Input:
% - X: training data made of X.data (one observation per row) and X.cls
%   (class labels, expected to be 1..k)
% - model: 'AijBiQiDi', 'AijBQiDi', 'AiBiQiDi', 'ABiQiDi', 'AiBQiDi', 'ABQiDi',
%   'AijBiQiD', 'AijBQiD', 'AiBiQiD', 'ABiQiD', 'AiBQiD', 'ABQiD',
%   'AjBQD', 'ABQD', 'best' (default: 'AiBiQiDi').
%   With 'best', the six models with free dimensions (when 'dim' is not given)
%   or the six models with a common dimension (when 'dim' is given) are all
%   learned, and the one with the smallest BIC value is returned.
% - seuil: threshold of Cattell's scree test, used to choose the intrinsic
%   dimensions of the free-dimension models; strictly within ]0,1[
%   (default: 0.2).
% - dim: common intrinsic dimension, an integer within [1,p-1]; required by
%   the models whose name ends with 'D' (common dimensions).
% Unknown option names are ignored with a warning.
% Output:
% - prms: learned model parameters
%   * prms.model: model used
%   * prms.bic: bic value
%   * prms.k: number of classes
%   * prms.p: original dimension of the data
%   * prms.prop: proportions of the classes
%   * prms.a: parameters a_ij of the classes
%   * prms.b: parameters b_i of the classes
%   * prms.d: intrinsic dimensions of the classes
%   * prms.m: means of the classes
%   * prms.Q: orientation matrices Q_i of the classes
%
% Authors: C. Bouveyron <charles.bouveyron@inrialpes.fr> - 2004-2006
%
% Reference: C. Bouveyron, S. Girard and C. Schmid, "High Dimensional Discriminant Analysis",
% Communications in Statistics, Theory and methods, in press, 2007.

%%%%%%%%%%%%%%%%%%%% Initialization %%%%%%%%%%%%%%%%%%%%
% Default values of the options
seuil = 0.2;
dim = [];
model = 'AiBiQiDi';
% Data management
data = Xl.data; cls = Xl.cls;
p = size(data,2);
free_models   = {'AijBiQiDi', 'AijBQiDi', 'AiBiQiDi', 'ABiQiDi', 'AiBQiDi', 'ABQiDi'};
common_models = {'AijBiQiD', 'AijBQiD', 'AiBiQiD', 'ABiQiD', 'AiBQiD', 'ABQiD'};
shared_models = {'AjBQD', 'ABQD'};   % common dimension and common orientation

% PARAMETERS MANAGEMENT: 'name',value pairs. Names are matched explicitly
% (rather than eval'd) so an option can never overwrite another local variable.
if mod(numel(varargin),2) ~= 0
    error('--> Options must be given as ''name'',value pairs.');
end
for i = 1:2:numel(varargin)
    if ~ischar(varargin{i})
        error('--> Option names must be strings.');
    end
    switch varargin{i}
        case 'seuil'
            seuil = varargin{i+1};
        case 'dim'
            dim = varargin{i+1};
        case 'model'
            model = varargin{i+1};
        otherwise
            warning('hdda_learn:unknownOption', ...
                '--> Unknown option ''%s'' is ignored.', varargin{i});
    end
end

% Test of parameter values
if seuil>=1 || seuil<=0, error('> The parameter ''seuil'' must be strictly within ]0,1[!'); end
if ~ischar(model) || ~any(strcmp(model, [free_models, common_models, shared_models, {'best'}]))
    error('--> The parameter ''model'' is not valide: see the help of hdda.');
end
if ~isempty(dim)
    % d = p would leave no noise dimension (division by p-d in learn)
    if ~isnumeric(dim) || ~isscalar(dim) || dim ~= fix(dim) || dim < 1 || dim >= p
        error('--> The parameter ''dim'' must be an integer within [1,%d]!', p-1);
    end
elseif any(strcmp(model, [common_models, shared_models]))
    error('--> The model ''%s'' requires the parameter ''dim''!', model);
end

% Calling the main function %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if strcmp(model,'best') % looking for the model with the smallest BIC value
    if isempty(dim)
        models = free_models;
        fprintf('--> Bic values of the HDDA models with free dimensions:\n');
    else
        models = common_models;
        fprintf('--> Bic values of the HDDA models with common dimensions:\n');
    end
    nm = numel(models);
    prms_t = cell(1,nm);
    bic = zeros(1,nm);
    for i = 1:nm
        prms_t{i} = learn(data,cls,models{i},seuil,dim);
        bic(i) = prms_t{i}.bic;
        fprintf(' - model %s: %g\n',models{i},bic(i));
    end
    % Searching the reversed list makes ties go to the later entry (the lists
    % run roughly from the most to the least flexible model). The index that
    % min returns refers to the REVERSED vector, so it is mapped back before
    % it is used to index prms_t.
    [bic_min,ind_rev] = min(fliplr(bic));
    ind = nm + 1 - ind_rev;
    prms = prms_t{ind};
    fprintf('--> Best model: %s\n',prms_t{ind}.model);
else
    prms = learn(data,cls,model,seuil,dim);
    fprintf('--> Used model: %s\n',prms.model);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SUB-FUNCTION %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function prms = learn(data,cls,model,seuil,dim)
% Learn the parameters of one HDDA model and compute its BIC value.
%
% Inputs:
% - data:  N x p matrix, one observation per row
% - cls:   N class labels; the classes are taken to be 1..max(cls)
% - model: name of the HDDA model (see hdda_learn)
% - seuil: threshold of Cattell's scree test (models with free dimensions)
% - dim:   common intrinsic dimension (models whose name ends with 'D')
% Output:
% - prms:  structure with fields model, bic, k, p, a, b, d, prop, m, Q
%          (same meaning as in hdda_learn)
%
% The returned bic is (-2*log-likelihood + q*log(N)) / N, where q is the
% number of free parameters of the model, so smaller is better.
% An unknown model name raises an error.

%%%%%%%%%%%%%%%%%%%% Initialization of parameters %%%%%%%%%%%%%%%%%%%%
k = max(cls);
[N,p] = size(data);
% Models sharing one intrinsic dimension 'dim' between all the classes
common_d = any(strcmp(model,{'AijBiQiD','AijBQiD','AjBiQiD','AjBQiD','AiBiQiD', ...
    'ABiQiD','AiBQiD','ABQiD','AjBQD','ABQD'}));
% Models with one a_ij per class and per dimension (a is a max(d) x k matrix);
% the other models have one scalar a per class (a is 1 x k)
per_dim_a = any(strcmp(model,{'AijBiQiDi','AijBiQiD','AijBQiDi','AijBQiD', ...
    'AjBiQiD','AjBQiD','AjBQD'}));

X = cell(1,k); V = cell(1,k);
m = zeros(k,p); n = zeros(1,k); prop = zeros(1,k);
Tr = zeros(1,k); d = zeros(1,k); L = zeros(k,p);

%%%%%%%%%%%%%%%%%%%% Intrinsic dimensions and class subspaces %%%%%%%%%%%%%%%%%%%%
for i = 1:k
    X{i} = data(cls==i,:);
    n(i) = size(X{i},1);
    prop(i) = n(i) / N;
    % mean along dim 1: a single-sample class must still give a 1 x p mean
    m(i,:) = mean(X{i},1);
    Y = X{i} - repmat(m(i,:),n(i),1);           % centred class data
    if n(i) < p
        % Fewer samples than dimensions: diagonalise the n_i x n_i Gram matrix;
        % its eigenvalues divided by n_i are the non-zero ones of the covariance.
        G = Y*Y';
        [U,D] = eig((G+G')/2);
        [lambda,ord] = sort(diag(D) ./ n(i),'descend');
        Tr(i) = sum(lambda);
    else
        Sigma = (Y'*Y) ./ n(i);                  % same as cov(X{i},1)
        [U,D] = eig((Sigma+Sigma')/2);
        [lambda,ord] = sort(diag(D),'descend');
        Tr(i) = trace(Sigma);
    end
    % Intrinsic dimension: given, or found by the scree test of Cattell
    if common_d
        d(i) = dim;
    else
        d(i) = scree_test(lambda,seuil);
    end
    if d(i) > numel(lambda)
        error('--> Class %d has only %d eigenvalues, cannot use dimension %d.', ...
            i,numel(lambda),d(i));
    end
    L(i,1:d(i)) = lambda(1:d(i))';
    % The d_i leading eigenvectors
    Vi = U(:,ord(1:d(i)));
    if n(i) < p
        % Map the Gram eigenvectors u back to R^p (Y'*u) and rescale them to
        % unit norm, since ||Y'*u||^2 = n_i*lambda; zero columns stay zero.
        Vi = Y' * Vi;
        nrm = sqrt(sum(Vi.^2,1));
        nrm(nrm == 0) = 1;
        Vi = Vi ./ repmat(nrm,p,1);
    end
    V{i} = Vi;
end
% L is zero beyond column d(i), so these are the per-class sums over the
% d_i largest eigenvalues and the variance left outside each class subspace
sL = sum(L,2)';
resid = Tr - sL;

%%%%%%%%%%%%%%%%%%%% Model-specific parameters %%%%%%%%%%%%%%%%%%%%
% Number of parameters shared by the models with class-specific Q_i
q0 = (k*p+k-1) + sum(d.*(p-(d+1)/2));
switch model
    case {'AijBiQiDi','AijBiQiD'} % Models [a_ij b_i Q_i d_i] and [a_ij b_i Q_i d]
        a = L(:,1:max(d))';
        b = resid ./ (p-d);
        if strcmp(model,'AijBiQiDi')
            q = q0 + sum(d) + 2*k;
        else
            q = q0 + sum(d) + k + 1;
        end
    case {'AijBQiDi','AijBQiD'} % Models [a_ij b Q_i d_i] and [a_ij b Q_i d]
        a = L(:,1:max(d))';
        b = repmat(sum(n.*resid) / sum(n.*(p-d)),1,k);
        if strcmp(model,'AijBQiDi')
            q = q0 + sum(d) + k + 1;
        else
            q = q0 + sum(d) + 2;
        end
    case 'AjBiQiD' % Model [a_j b_i Q_i d]
        lam = top_eig(pooled_cov(X,m,prop),dim);
        a = repmat(lam,1,k);
        b = resid ./ (p-d);
        q = q0 + k + dim + 1;
    case 'AjBQiD' % Model [a_j b Q_i d]
        lam = top_eig(pooled_cov(X,m,prop),dim);
        a = repmat(lam,1,k);
        b = repmat(sum(n.*resid) / sum(n.*(p-d)),1,k);
        q = q0 + dim + 2;
    case {'AiBiQiDi','AiBiQiD'} % Models [a_i b_i Q_i d_i] and [a_i b_i Q_i d]
        a = sL ./ d;
        b = resid ./ (p-d);
        if strcmp(model,'AiBiQiDi')
            q = q0 + 3*k;
        else
            q = q0 + 2*k + 1;
        end
    case {'ABiQiDi','ABiQiD'} % Models [a b_i Q_i d_i] and [a b_i Q_i d]
        a = repmat(sum(n.*sL) / sum(n.*d),1,k);
        b = resid ./ (p-d);
        if strcmp(model,'ABiQiDi')
            q = q0 + 2*k + 1;
        else
            q = q0 + k + 2;
        end
    case {'AiBQiDi','AiBQiD'} % Models [a_i b Q_i d_i] and [a_i b Q_i d]
        a = sL ./ d;
        b = repmat(sum(n.*resid) / sum(n.*(p-d)),1,k);
        if strcmp(model,'AiBQiDi')
            q = q0 + 2*k + 1;
        else
            q = q0 + k + 2;
        end
    case {'ABQiDi','ABQiD'} % Models [a b Q_i d_i] and [a b Q_i d]
        a = repmat(sum(n.*sL) / sum(n.*d),1,k);
        b = repmat(sum(n.*resid) / sum(n.*(p-d)),1,k);
        if strcmp(model,'ABQiDi')
            q = q0 + k + 2;
        else
            q = q0 + 3;
        end
    case 'AjBQD' % Model [a_j b Q d]: one orientation shared by all classes
        W = pooled_cov(X,m,prop);
        [lam,VV] = top_eig(W,dim);
        V(:) = {VV};
        a = repmat(lam,1,k);
        b = repmat((trace(W) - sum(lam)) / (p-dim),1,k);
        d(:) = dim;
        q = (k*p+k-1) + dim*(p-(dim+1)/2) + dim + 2;
    case 'ABQD' % Model [a b Q d]: one orientation shared by all classes
        W = pooled_cov(X,m,prop);
        [lam,VV] = top_eig(W,dim);
        V(:) = {VV};
        a = repmat(sum(lam) / dim,1,k);
        b = repmat((trace(W) - sum(lam)) / (p-dim),1,k);
        d(:) = dim;
        q = (k*p+k-1) + dim*(p-(dim+1)/2) + 3;
    otherwise
        error('--> The model ''%s'' is not yet implemented!',model);
end

%%%%%%%%%%%%%%%%%%%% BIC value %%%%%%%%%%%%%%%%%%%%
% ll accumulates -2*log-likelihood; each class term is evaluated on the
% training samples of that class only.
ll = 0;
for i = 1:k
    Y = X{i} - repmat(m(i,:),n(i),1);
    Pa = (Y * V{i}) * V{i}';             % projection onto span(V{i})
    dist_b = sum((Pa - Y).^2,2);         % squared distance to that subspace
    if per_dim_a
        ai = a(1:d(i),i)';
        Z = Pa * V{i};
        % row-wise Z*diag(1./ai)*Z', without forming an n_i x n_i matrix
        Ki = sum(Z.^2 ./ repmat(ai,n(i),1),2) + dist_b ./ b(i) + sum(log(ai));
    else
        Ki = sum(Pa.^2,2) ./ a(i) + dist_b ./ b(i) + d(i)*log(a(i));
    end
    Ki = Ki + (p-d(i))*log(b(i)) - 2*log(prop(i)) + p*log(2*pi);
    ll = ll + sum(Ki);
end
bic = (ll + q*log(N)) / N;

%%%%%%%%%%%%%%%%%%%% Return parameters %%%%%%%%%%%%%%%%%%%%
prms.model = model; prms.bic = bic;
prms.k = k; prms.p = p;
prms.a = a; prms.b = b;
prms.d = d; prms.prop = prop;
prms.m = m; prms.Q = V;

%%%%%%%%%%%%%%%%%%%% Helper sub-functions %%%%%%%%%%%%%%%%%%%%
function d = scree_test(lambda,seuil)
% Cattell's scree test on eigenvalues sorted in descending order: the
% smallest j such that every eigenvalue gap after the j-th one is below
% seuil times the largest gap (1 when fewer than two eigenvalues are given).
gaps = lambda(1:end-1) - lambda(2:end);
d = 1;
for j = 1:numel(gaps)
    if all(gaps(j+1:end) < seuil*max(gaps))
        d = j;
        return
    end
end

function W = pooled_cov(X,m,prop)
% Within-class covariance sum_i prop(i)*Sigma_i, where Sigma_i is the
% covariance of class i normalised by n_i (as cov(X{i},1)). Computed from the
% data, so it also works for classes with fewer samples than dimensions.
p = size(m,2);
W = zeros(p);
for i = 1:numel(X)
    ni = size(X{i},1);
    if ni > 0
        Y = X{i} - repmat(m(i,:),ni,1);
        W = W + prop(i) .* (Y'*Y) ./ ni;
    end
end

function [vals,vecs] = top_eig(S,nvec)
% The nvec largest eigenvalues of the symmetric matrix S (descending, as a
% column vector) and the matching unit-norm eigenvectors (as columns).
[U,D] = eig((S+S')/2);
[vals,ord] = sort(diag(D),'descend');
vals = vals(1:nvec);
vecs = U(:,ord(1:nvec));
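% [Residue of the web code viewer this file was copied from. Keyboard
%  shortcuts: copy code Ctrl+C, search code Ctrl+F, full screen F11,
%  toggle theme Ctrl+Shift+D, show shortcuts ?, increase font size Ctrl+=,
%  decrease font size Ctrl+-]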