?? nlfa_iter.m
字號:
function [sources, net, params, status, fs] = ...
    nlfa_iter(data, sources, net, params, status)
% NLFA_ITER  Perform the NLFA iteration
%
%   [sources, net, params, status, fs] = ...
%       NLFA_ITER(data, sources, net, params, status)
%
%   Runs at most status.iters update rounds.  Every round
%     1. feeds the sources through the network (feedfw) and stores the
%        result x{4} in fs,
%     2. evaluates the cost (kl_static + kl_batch), prints it and appends
%        it to status.kls (with the cpu time appended to status.cputime),
%     3. collects the gradients from feedback and from the source and
%        network priors,
%     4. updates sources and network, either with the 'old' algorithm
%        (updatesources / updatenetwork) or with update_everything,
%     5. re-estimates the parameters and hyperparameters unless
%        status.updateparams is still negative.
%
%   The loop ends early when the cost becomes NaN (the function returns
%   immediately) or when the convergence test on status.kls succeeds.
%
%   Fields of status read here: iters, updatealg, cgreset, approximation,
%   kls, cputime, epsilon, updatesrcs, updatesrcvars, updatenet,
%   updateparams and (optionally) oldgrads.  Negative values of the
%   update* counters are incremented by one on every round; the
%   corresponding update is only done once the counter is non-negative.
%
%   On a normal exit fs holds the feedforward result for the final
%   sources and network, and for updatealg other than 'old' the gradient
%   state is saved in status.oldgrads.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

nsampl = size(data, 2);

% Batch processing is disabled; the whole data set is handled as a
% single batch.  The original batching code is kept for reference:
%nlfa_batches = 1:status.batch_size:nsampl;
%%nlfa_batch = [nlfa_batches', [nlfa_batches(2:end)-1, nsampl]'];

iters_left = status.iters;

% The new update algorithms carry gradient state between calls.  Reuse
% the stored state unless there is none or cgreset == -1.
if ~strcmp(status.updatealg, 'old'),
  if isfield(status, 'oldgrads') && status.cgreset ~= -1,
    oldgrads = status.oldgrads;
  else
    fprintf('Resetting CG\n');
    oldgrads.net = netgrad_zeros(net);
    oldgrads.s = zeros(size(sources));
    oldgrads.norm = 0;
  end
end

while iters_left > 0
  dcp_dnetm = netgrad_zeros(net);
  dcp_dnetv = netgrad_zeros(net);
  fs = probdist(zeros(size(data)), ones(size(data)));
  newkls = kl_static(net, params);

  % for k = 1:size(nlfa_batch, 1),
  %curbatch = nlfa_batch(k,1):nlfa_batch(k,2);
  curbatch = 1:nsampl;

  % Do feedforward calculations
  x = feedfw(sources(:, curbatch), net, status.approximation);
  fs(:, curbatch) = probdist(x{4}.e, x{4}.var);

  % Calculate and possibly display current value of the cost function
  newkls = newkls + kl_batch(fs(:, curbatch), sources(:, curbatch), ...
                             data(:, curbatch), params);
  %if k == size(nlfa_batch, 1)
  fprintf('Iteration #%d: %f\n', size(status.kls, 2), newkls);
  if isnan(newkls),
    iters_left = 0;
    %if size(nlfa_batch, 1) == 1,
    fprintf('Cost is NaN, bailing out...\n');
    return
    %end
  end

  % Convergence test, only after more than 400 recorded costs: stop if
  % the cost has risen on each of the last 10 steps, or if no step among
  % the last 200 decreased the cost by status.epsilon or more.  The
  % current round is still completed; only further rounds are cancelled.
  if (size(status.kls, 2) > 400 && ...
      ((min(diff(status.kls(end-10:end))) > 0) || ...
       (min(diff(status.kls(end-200:end))) > -status.epsilon))),
    fprintf('The iteration appears to have converged, bailing out...\n');
    iters_left = 0;
  end
  status.kls = [status.kls newkls];
  status.cputime = [status.cputime cputime];
  %end

  % Calculate partial derivatives for parameters to adapt
  [dcp_dsm, dcp_dsv, newdcp_dnetm, newdcp_dnetv] = ...
      feedback(x, net, sources(:, curbatch), data(:, curbatch), ...
               params.noise, status);
  [newdcp_dsm, newdcp_dsv] = ...
      feedback_srcpriors(sources(:, curbatch), params.src);
  dcp_dsm = dcp_dsm + newdcp_dsm;
  dcp_dsv = dcp_dsv + newdcp_dsv;
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  % Add the contribution of the network priors
  [newdcp_dnetm, newdcp_dnetv] = ...
      feedback_netpriors(net, params.net, params.hyper.net);
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  if strcmp(status.updatealg, 'old'),
    % Get new values for sources and alphas if appropriate
    if max([status.updatesrcs, status.updatesrcvars]) >= 0
      sources = probdist_alpha(sources);
      newsources = ...
          updatesources(sources(:, curbatch), dcp_dsm, dcp_dsv, x{4}, ...
                        params.src, params.noise);
      if status.updatesrcs < 0
        % Source means are still frozen: keep the old means (e, malpha,
        % msign) and take only the variance-related parts of the update
        sources = ...
            probdist_alpha(sources.e(:, curbatch), newsources.var, ...
                           sources.malpha(:, curbatch), newsources.valpha, ...
                           sources.msign(:, curbatch), newsources.vsign);
      else
        sources = newsources;
      end
    end

    if status.updatenet >= 0
      net = updatenetwork(net, dcp_dnetm, dcp_dnetv);
    end
  else  % new updatealg
    [sources, net, oldgrads, status] = update_everything(...
        sources, net, dcp_dsm, dcp_dsv, x{4}, params, dcp_dnetm, dcp_dnetv, ...
        data, newkls, status, oldgrads);
    % Periodic reset every status.cgreset recorded iterations
    if (status.cgreset > 0) && (mod(length(status.kls), status.cgreset) == 0),
      fprintf('Resetting CG\n');
      oldgrads.net = netgrad_zeros(net);
      oldgrads.s = zeros(size(sources));
      oldgrads.norm = 0;
    end
  end  % updatealg

  % Advance the delay counters.  The gradient state is reset on the
  % round where source updates become enabled.
  if status.updatesrcs < 0
    status.updatesrcs = status.updatesrcs + 1;
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      fprintf('Resetting CG\n');
      oldgrads.net = netgrad_zeros(net);
      oldgrads.s = zeros(size(sources));
      oldgrads.norm = 0;
    end
  end
  if status.updatesrcvars < 0
    status.updatesrcvars = status.updatesrcvars + 1;
  end
  if status.updatenet < 0
    status.updatenet = status.updatenet + 1;
  end

  % Update estimates for different parameters if appropriate
  if status.updateparams < 0
    status.updateparams = status.updateparams + 1;
    if (status.updateparams == 0) && (~strcmp(status.updatealg, 'old')),
      fprintf('Resetting CG\n');
      oldgrads.net = netgrad_zeros(net);
      oldgrads.s = zeros(size(sources));
      oldgrads.norm = 0;
    end
  else
    params.noise = estimatevars(probdist(fs.e-data, fs.var), ...
                                params.hyper.noise, params.noise);
    params.src = estimatevars(sources, params.hyper.src, params.src);
    params.net.w2var = estimatevars(net.w2, params.hyper.net.w2var, ...
                                    params.net.w2var, 1);

    [params.hyper.net.w2var.mean, params.hyper.net.w2var.var] = ...
        estimatemeanvars(params.net.w2var, params.prior.net.w2var.mean, ...
                         params.prior.net.w2var.var, params.hyper.net.w2var.var);
    [params.hyper.noise.mean, params.hyper.noise.var] = ...
        estimatemeanvars(params.noise, params.prior.noise.mean, ...
                         params.prior.noise.var, params.hyper.noise.var, 1);
    [params.hyper.net.b1.mean, params.hyper.net.b1.var] = ...
        estimatemeanvars(net.b1, params.prior.net.b1.mean, ...
                         params.prior.net.b1.var, params.hyper.net.b1.var, 1);
    [params.hyper.net.b2.mean, params.hyper.net.b2.var] = ...
        estimatemeanvars(net.b2, params.prior.net.b2.mean, ...
                         params.prior.net.b2.var, params.hyper.net.b2.var, 1);
    [params.hyper.src.mean, params.hyper.src.var] = ...
        estimatemeanvars(params.src, params.prior.src.mean, ...
                         params.prior.src.var, params.hyper.src.var, 1);
  end

  % With the old algorithm the sources are rescaled after every round,
  % provided there is more than one source row
  if strcmp(status.updatealg, 'old'),
    if (size(sources, 1) > 1),
      [sources, net, params] = ...
          scalesources(sources, net, params);
    end
  end
  iters_left = iters_left - 1;
end

% Keep the gradient state for the next call
if ~strcmp(status.updatealg, 'old'),
  status.oldgrads = oldgrads;
end

% Recompute fs and the cost for the final sources and network
fs = probdist(zeros(size(data)), ones(size(data)));
newkls = kl_static(net, params);
% Do feedforward calculations
curbatch = 1:nsampl;
x = feedfw(sources(:, curbatch), net, status.approximation);
fs(:, curbatch) = probdist(x{4}.e, x{4}.var);
newkls = newkls + kl_batch(fs(:, curbatch), sources(:, curbatch), ...
                           data(:, curbatch), params);
fprintf('Finally after %d iterations: %f\n', size(status.kls, 2), newkls);


function [dc_dsm, dc_dsv] = feedback_srcpriors(sources, srcparams)
% FEEDBACK_SRCPRIORS  Calculate the contribution of source priors
%   to the gradients of the cost function with respect to source values
%
%   dc_dsm is sources.e divided by the prior variance and dc_dsv is
%   0.5 divided by the prior variance, where the prior variance is
%   normalvar(srcparams) replicated over the columns of sources.

sourcevar = normalvar(srcparams);
nsampl = size(sources, 2);

% Replicate the variance for every sample (column)
temp = sourcevar * ones(1, nsampl);

dc_dsm = sources.e ./ temp;
dc_dsv = .5 ./ temp;


function [dc_dnetm, dc_dnetv] = feedback_netpriors(net, params, hypers)
% FEEDBACK_NETPRIORS  Calculate the contribution of network priors
%   to the gradients of the cost function with respect to network weights
%
%   Returns two structures with fields w2, b2, w1 and b1.  The first
%   layer weights use a fixed unit prior variance; the second layer
%   weights use normalvar(params.w2var).

w1var = ones(1, size(net.w1, 2));
w2var = normalvar(params.w2var);

[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgradsprior(net.w2, net.b2, w2var, hypers.b2);
[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradsprior(net.w1, net.b1, w1var, hypers.b1);


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = ...
    netgradsprior(w, b, wprior, bprior)
% NETGRADSPRIOR  Calculate the contribution of priors to partial
%   derivatives of kldiv with respect to network weights
%
%   wprior is replicated over the rows of w; the bias prior mean
%   (bprior.mean.e) and variance (normalvar(bprior.var)) are replicated
%   to the size of b.

wpvar = repmat(wprior, [size(w, 1) 1]);
bpexp = repmat(bprior.mean.e, size(b));
bpvar = repmat(normalvar(bprior.var), size(b));

dcp_dwm = w.e ./ wpvar;
dcp_dwv = .5 ./ wpvar;
dcp_dbm = (b.e - bpexp) ./ bpvar;
dcp_dbv = .5 ./ bpvar;


function grad = netgrad_zeros(net)
% NETGRAD_ZEROS  Structure of zero arrays with fields w2, b2, w1 and b1,
%   each the same size as the corresponding field of net

grad.w2 = zeros(size(net.w2));
grad.b2 = zeros(size(net.b2));
grad.w1 = zeros(size(net.w1));
grad.b1 = zeros(size(net.b1));


function s = sum_structs(s1, s2)
% SUM_STRUCTS  Add all the fields of two structures together
%
%   Fields are paired by position (the order given by struct2cell) and
%   the result carries the field names of s1.  An error is raised when
%   the two structures do not have the same number of fields.

f = fieldnames(s1);
c1 = struct2cell(s1);
c2 = struct2cell(s2);

% size() returns a vector and "if" on a vector is true only when every
% element is nonzero, so an elementwise ~= never fired for two N-by-1
% cells of different length.  isequal compares the size vectors as a whole.
if ~isequal(size(c1), size(c2))
  error('sum_structs: Structures must be of same type')
end

c = cell(size(c1));
for k=1:length(c1),
  c{k} = c1{k} + c2{k};
end
s = cell2struct(c, f, 1);
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -