train_test.m
%----------------------------------------
% Input data format
% trainx -> input data matrix (number of data x Dimension)
% Ytrain -> input label (number of data x total classes)
% crossx -> cross-validation data (number of data x Dimension)
% Ycross -> cross-validation label (number of data x total classes)
% The training and cross-validation labels should be binary inputs
% or prior probabilities.
% An example of a three-class label is [0 0 1], indicating
% that the training point belongs to class 3. It could also
% be [0.1 0.3 0.6], indicating prior confidence.
% KTYPE and KSCALE are the kernel parameters. For example,
% KTYPE = 6 is the Gaussian kernel and KSCALE is its sigma
% parameter.
% Cgini is the C parameter of the SVM.
% B is the gamma parameter (see the JMLR paper).
% Niter is the number of iterations of the randomized SMO
% algorithm.
% The plotting function is useful only if the data is two-dimensional.
% To disable the plotting routines, set plotflag = 0.
%----------------------------------------
[N,D] = size(trainx);
[Ny,M] = size(Ytrain);
if Ny ~= N, error('Training Data size neq labels'); end;
[Ncross,D] = size(crossx);
[Nycross,M] = size(Ycross);
if Nycross ~= Ncross, error('Cross-validation Data size neq labels'); end;

global KTYPE;
global KSCALE;

%--------------------------------
% Parameters to tune
%--------------------------------
B = 0.1;          % gamma parameter
Niter = 20000;    % For larger datasets, increase the number of
                  % randomized iterations.
Cgini = 1;        % To prevent overfitting
%Cgini = 2.471/B; % To prevent overfitting
KTYPE = 6;        % Kernel type
KSCALE = 1;       % Kernel parameter
%--------------------------------
plotflag = 1;

% Convert the label matrices into class indices.
for i = 1:N,
  yindex(i) = find(Ytrain(i,:) > 0.5);
end;
for i = 1:Ncross,
  ycrossindex(i) = find(Ycross(i,:) > 0.5);
end;

%---------------------------------------
% Train the ginisvm
%---------------------------------------
fprintf('Starting SVM Training....');
[testcoeff,testbias] = ginitrain(trainx,Ytrain,Cgini*ones(N,1),Niter,B*ones(N,1));
fprintf('....Done\n');

%---------------------------------------
% Compute the sparsity index: the percentage of training points
% whose coefficients are numerically zero, i.e. non-support vectors.
%---------------------------------------
spind = find(sum(abs(testcoeff),2) < 1e-5);
nsv = length(spind);
for k = 1:nsv,
  testcoeff(spind(k),:) = zeros(1,M);
end;
fprintf('Sparsity Index = %g\n',nsv/N*100);

%---------------------------------------
% Performance on the training set
%---------------------------------------
fprintf('Evaluating Performance on Training set\n');
errordist = zeros(1,M);
error = 0;
eflag = zeros(N,1);
for k = 1:N,
  mvalue = kernel(trainx(k,:),trainx)*testcoeff + testbias;
  [result, resultmargin] = ginitest(mvalue,B);
  [maxval,ind] = max(result);
  if ind ~= yindex(k),
    error = error + 1;
    eflag(k) = 1;
    errordist(yindex(k)) = errordist(yindex(k)) + 1;
  end;
end;
fprintf('Multi-class Train Error = %g percent\n',(error/N)*100);
for i = 1:M,
  fprintf('Class %d Error = %g percent\n',i,(errordist(i)/N)*100);
end;
clear result resultmargin;

%---------------------------------------
% Performance on the cross-validation set
%---------------------------------------
fprintf('Evaluating Performance on Cross-validation set\n');
errordist = zeros(1,M);
error = 0;
eflagcross = zeros(Ncross,1);
for k = 1:Ncross,
  mvalue = kernel(crossx(k,:),trainx)*testcoeff + testbias;
  [result, resultmargin] = ginitest(mvalue,B);
  [maxval,ind] = max(result);
  if ind ~= ycrossindex(k),
    error = error + 1;
    eflagcross(k) = 1;
    errordist(ycrossindex(k)) = errordist(ycrossindex(k)) + 1;
  end;
end;
fprintf('Multi-class Cross-validation Error = %g percent\n',(error/Ncross)*100);
for i = 1:M,
  fprintf('Class %d Error = %g percent\n',i,(errordist(i)/Ncross)*100);
end;
clear result resultmargin;

%---------------------------------------
% Plot the probability contour
%---------------------------------------
if plotflag == 1,
  fprintf('Plotting Contour ....');
  figure;
  giniplot(trainx,Ytrain,testcoeff,testbias',B);
  fprintf('....done\n');
end;
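
The script assumes that ginitrain, ginitest, giniplot, and a kernel helper are already on the MATLAB path; none of them are listed on this page. For the margin computation above, kernel(trainx(k,:),trainx) must return a 1 x N row vector so that multiplying by the N x M testcoeff yields one value per class. A minimal sketch of such a helper, covering only the Gaussian case (KTYPE = 6, with KSCALE as sigma, as the header comments describe), might look like the following; the exact parameterization used by the real toolkit's kernel.m is an assumption here.

function K = kernel(x, X)
% Sketch of the kernel helper assumed by train_test.m (Gaussian case only).
% x : 1 x D query point;  X : N x D training matrix.
% Returns K, a 1 x N row vector of kernel evaluations.
global KTYPE;
global KSCALE;
if KTYPE ~= 6,
  error('Only the Gaussian kernel (KTYPE = 6) is sketched here.');
end;
d2 = sum((X - repmat(x,size(X,1),1)).^2, 2);  % squared distances, N x 1
K  = exp(-d2/(2*KSCALE^2))';                  % assumed sigma parameterization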
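
As a usage illustration of the input format described in the header comments, the snippet below builds a toy three-class, two-dimensional problem and then runs the script. The sample counts and class means are arbitrary values chosen for the example, not part of the toolkit.

Npc = 50;                % samples per class (arbitrary)
mu  = [0 0; 3 0; 0 3];   % class means (arbitrary)
trainx = []; Ytrain = []; crossx = []; Ycross = [];
for c = 1:3,
  lab = zeros(1,3); lab(c) = 1;   % binary label, e.g. [0 0 1] for class 3
  trainx = [trainx; randn(Npc,2) + repmat(mu(c,:),Npc,1)];
  Ytrain = [Ytrain; repmat(lab,Npc,1)];
  crossx = [crossx; randn(Npc,2) + repmat(mu(c,:),Npc,1)];
  Ycross = [Ycross; repmat(lab,Npc,1)];
end;
train_test;              % run the script above on the toy data

Since this toy data is two-dimensional, leaving plotflag = 1 will also draw the probability contours via giniplot.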