?? quickprop.c
字號:
/* Quickprop training algorithm (developed by Scott Fahlman)
- a modified back-propagation (BP) method, run here as a standard baseline for comparison
*/
#include <ctype.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
/* ---- Compile-time configuration ---- */
#define GA 1 /* NOTE(review): not referenced anywhere in this file as shown -- confirm it is still needed */
#define S 1 /* generalized BP with the S parameter: the activation derivative is raised to 1/S */
#define mu 1.75 /* maximum growth factor: mu * previous step replaces the quadratic quickprop step when that step would grow too fast */
#define threshold 0.0 /* mode switch threshold: a previous step with |step| <= threshold falls back to gradient + momentum */
#define ERRORLEVEL 0.001 /* stopping criterion: a run converges once the mean squared error drops below this */
#define NITERATIONS 3000 /* no. of iterations to be run */
#define P 4 /* no. of patterns to be trained */
#define I 2 /* no. of input nodes */
#define H 2 /* no. of hidden nodes */
#define J 1 /* no. of output nodes */
#define N 9 /* no. of weights = H*(I+1)+J*(H+1); must be kept in sync by hand */
#define weightfile 30 /* no. of initial weight files (w<k>.wts) trained in one invocation */
/* Per-pattern data: targets, inputs (out0), hidden activations (out1), network outputs (out2). */
double target[P][J], out0[P][I], out1[P][H], out2[P][J];
/* weights: flat vector. Hidden-unit weights come first at index h*(I+1)+i
 * (i == I is the bias), followed by output-unit weights at index
 * H*(I+1)+j*(H+1)+h (h == H is the bias).
 * delta1/delta2: per-pattern back-propagated error signals of hidden/output units. */
double weights[N],delta1[P][H],delta2[P][J];
/* Error signals summed over all patterns for each weight ([unit][source], last column = bias). */
double dE1[H][I+1],dE2[J][H+1];
/* dE1/dE2 from the previous iteration; quickprop needs the previous slope. */
double pre_dE1[H][I+1],pre_dE2[J][H+1];
/* Weight change applied in the previous iteration. */
double delw1[H][I+1], delw2[J][H+1];
/* File handles used by main. */
FILE *fpRun, *fpPattern, *fpWts;
FILE *fpWeightsOut, *fpResults, *fpError;
/*
 * Convert a small non-negative integer to its decimal string.
 *
 * n: value to convert. Only 0..99 is supported, because callers size the
 *    buffer for at most two digits plus the terminator. Values outside
 *    that range do not produce valid digits, but no more than 3 bytes are
 *    ever written.
 * s: output buffer of at least 3 chars; it is always NUL-terminated.
 *
 * Written as a prototype-style definition because K&R-style parameter
 * lists are obsolescent and were removed in C23.
 */
void itoa(int n, char s[])
{
    int i = 0;

    if (n / 10 == 0) {
        s[i++] = (char)(n + '0');
    } else {
        s[i++] = (char)(n / 10 + '0');
        s[i++] = (char)(n % 10 + '0');
    }
    s[i] = '\0';
}
/*
 * Return the smaller of a and b.
 * If the comparison is unordered (a NaN is involved), b is returned.
 * Prototype-style definition: K&R parameter lists were removed in C23.
 */
double minimum(double a, double b)
{
    return (a < b) ? a : b;
}
/*
 * Return the larger of a and b.
 * If the comparison is unordered (a NaN is involved), b is returned.
 * Prototype-style definition: K&R parameter lists were removed in C23.
 */
double maximum(double a, double b)
{
    return (a > b) ? a : b;
}
/*
 * Return -1.0 for strictly negative a, otherwise 1.0.
 * Both +0.0 and -0.0, and NaN, map to 1.0.
 * Prototype-style definition: K&R parameter lists were removed in C23.
 */
double sign(double a)
{
    return (a < 0.0) ? -1.0 : 1.0;
}
main (argc, argv)
int argc;
char *argv[];
{
double eta,alpha;
double error[20],derror,temp,temp1,sum,dw;
double converge=0.0;
register int h,i,j,p,q,r,l,x,min,tmp;
int nIterations=NITERATIONS;
unsigned steady=0,non_conv=0;
char szResults[66],szError[66],szPattern[66],szWeightsOut[66];
char charstr[12],tmpstr[3];
double optwts[N], minerr;
time_t t,start,end;
int offset; /* a value to set the weight files */
t = time(NULL); /* randomize the seed for each run */
/* tmp = srand(t);
*/
srand(t);
if (argc < 2)
{
fprintf(stderr, "Usage: %s runfilename\n", argv[0]);
exit(1);
}
if ((fpRun = fopen(*++argv,"r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", *argv);
exit(1);
}
offset = atoi(argv[1]);
fscanf(fpRun, "%s %s %s %s %lf %lf",
szResults, szError, szPattern, szWeightsOut, &eta, &alpha);
fclose(fpRun);
if ((fpPattern = fopen(szPattern, "r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", szPattern);
exit(1);
}
for (p=0; p<P; p++)
{
for (i=0; i<I; i++)
fscanf(fpPattern, "%lf", &out0[p][i]);
for (j=0; j<J; j++)
fscanf(fpPattern, "%lf", &target[p][j]);
}
fclose(fpPattern);
if ((fpError = fopen(szError, "w")) == NULL)
{
fprintf(stderr, "can't open file %s \n", szError);
exit(1);
}
start = time(NULL);
for (x=offset; x<weightfile+offset; x++)
{
minerr = 99999999.0; steady=0;
strcpy(charstr,"w");
itoa(x,tmpstr);
strcat(charstr,tmpstr);
strcat(charstr,".wts");
if ((fpWts = fopen(charstr,"r")) == NULL)
{
fprintf (stderr, "can't open wts file\n");
exit(1);
}
for (h=0; h<H; h++)
for (i=0; i<=I; i++)
{
fscanf (fpWts, "%lf", &weights[h*(I+1)+i]);
delw1[h][i] = 0.0;
pre_dE1[h][i] = 0.0;
}
for (j=0; j<J; j++)
for (h=0; h<=H; h++)
{
fscanf(fpWts, "%lf", &weights[H*(I+1)+j*(H+1)+h]);
delw2[j][h] = 0.0;
pre_dE2[j][h]= 0.0;
}
fclose(fpWts);
/* begin processing */
for (q=0; q <= nIterations; q++)
{
/* calculate feed-forward net = forward pass */
for (p=0; p<P; p++)
{
for (h=0; h<H; h++)
{
sum = weights[h*(I+1)+I];
for (i=0; i< I; i++)
sum += weights[h*(I+1)+i] * out0[p][i];
out1[p][h] = 1.0 / (1.0 + exp(-sum));
}
for (j=0; j<J; j++)
{
sum = weights[H*(I+1)+j*(H+1)+H];
for (h=0; h< H; h++)
sum += weights[H*(I+1)+j*(H+1)+h] * out1[p][h];
out2[p][j] = 1.0 / (1.0 + exp(-sum));
}
/* calculate error signals */
for (j=0; j<J; j++)
{
temp = target[p][j]-out2[p][j];
temp1= pow(out2[p][j]*(1.0-out2[p][j]),1.0/S);
delta2[p][j] = temp * temp1;
}
for (h=0; h<H; h++)
{
sum = 0.0;
temp1= pow(out1[p][h]*(1.0-out1[p][h]),1.0/S);
for (j=0; j<J; j++)
sum += delta2[p][j]*weights[H*(I+1)+j*(H+1)+h];
delta1[p][h] = sum * temp1;
}
}
/* calculate system error */
if (q==0) r=0;
for (p=0, error[r]=0.0; p<P; p++)
{
for (j=0; j<J; j++)
{
temp = out2[p][j]-target[p][j];
error[r] += temp * temp;
}
}
error[r] /= (P * J);
if (error[r] < ERRORLEVEL)
break;
if (error[r] < minerr)
{
minerr = error[r];
for (l=0; l<N; l++) optwts[l] = weights[l];
}
fprintf (stderr,"Iteration %5d/%-5d Error %lf minerr %lf\r",
q, nIterations, error[r], minerr);
/* calculate rate of change of error with respect to weights */
for (j=0; j<J; j++)
{
dE2[j][H] = 0.0;
for (p=0; p<P; p++)
dE2[j][H] += delta2[p][j];
for (h=0; h<H; h++)
{
dE2[j][h] = 0.0;
for (p=0; p<P; p++)
dE2[j][h] += delta2[p][j] * out1[p][h];
}
}
for (h=0; h<H; h++)
{
dE1[h][I] = 0.0;
for (p=0; p<P; p++)
dE1[h][I] += delta1[p][h];
for (i=0; i<I; i++)
{
dE1[h][i] = 0.0;
for (p=0; p < P; p++)
dE1[h][i] += delta1[p][h] * out0[p][i];
}
}
/* calculate weight update rule */
for (j=0; j<J; j++)
for (h=0; h<=H; h++)
{
dw = 0.0;
if (delw2[j][h] > threshold)
{
if (dE2[j][h] > 0.0)
dw += eta*dE2[j][h];
if (dE2[j][h] > (mu/(1.0+mu))*pre_dE2[j][h])
dw += mu*delw2[j][h];
else
dw +=dE2[j][h]/(pre_dE2[j][h]-dE2[j][h])*delw2[j][h];
}
else if (delw2[j][h] < -threshold)
{
if (dE2[j][h] <0.0)
dw += eta*dE2[j][h];
if (dE2[j][h] < (mu/(1.0+mu))*pre_dE2[j][h])
dw += mu*delw2[j][h];
else
dw +=dE2[j][h]/(pre_dE2[j][h]-dE2[j][h])*delw2[j][h];
}
else
dw += eta*dE2[j][h]+alpha*delw2[j][h];
weights[H*(I+1)+j*(H+1)+h] += dw;
pre_dE2[j][h] = dE2[j][h];
delw2[j][h] = dw;
}
for (h=0; h<H; h++)
for (i=0; i<=I; i++)
{
dw = 0.0;
if (delw1[h][i] > threshold)
{
if (dE1[h][i] >0.0)
dw += eta*dE1[h][i];
if (dE1[h][i] > (mu/(1+mu))*pre_dE1[h][i])
dw += mu*delw1[h][i];
else
dw +=dE1[h][i]/(pre_dE1[h][i]-dE1[h][i])*delw1[h][i];
}
else if (delw1[h][i] < -threshold)
{
if (dE1[h][i] >0.0)
dw += eta*dE1[h][i];
if (dE1[h][i] < (mu/(1+mu))*pre_dE1[h][i])
dw += mu*delw1[h][i];
else
dw +=dE1[h][i]/(pre_dE1[h][i]-dE1[h][i])*delw1[h][i];
}
else
dw += eta*dE1[h][i]+alpha*delw1[h][i];
weights[h*(I+1)+i] += dw;
pre_dE1[h][i] = dE1[h][i];
delw1[h][i] = dw;
}
fprintf (fpError, "%lf\n", error[r]);
if (++r==10) r=0;
}
/* end processing */
printf ("Iteration %5d/%-5d Error %lf minerr %lf\n",q-1,nIterations,error[r],minerr);
if (q-1 == NITERATIONS)
non_conv ++;
else
converge += q-1;
fprintf (stderr,"\n");
fprintf (fpError, "%lf\n", error[r]);
fclose(fpError);
}
end = time(NULL);
printf ("\nElapsed time = %ld sec\n",(long)end - (long)start);
printf ("The avg rate is %5.2lf, percentage of conv is %5.2lf\n\n",converge/(weightfile-non_conv),(double)(weightfile-non_conv)/weightfile*100);
if ((fpWeightsOut = fopen(szWeightsOut, "w")) == NULL)
{
fprintf(stderr, "can't write file %s\n", szWeightsOut);
exit(1);
}
for (h=0; h < H; h++)
for (i=0; i <= I; i++)
fprintf(fpWeightsOut, "%9.6f%c", weights[h*(I+1)+i],
(i == I) ? '\n':' ');
for (j=0; j < J; j++)
for (h=0; h <= H; h++)
fprintf(fpWeightsOut, "%9.6f%c", weights[H*(I+1)+j*(H+1)+h],
(h == H) ? '\n':' ');
fclose(fpWeightsOut);
if ((fpResults = fopen(szResults,"w")) == NULL)
{
fprintf(stderr, "can't write file %s\n", szResults);
fpResults = stderr;
}
for (p=0; p<P; p++)
{
fprintf(fpResults, "%d ", p);
for (j=0; j < J; j++)
fprintf(fpResults, " %lf", out2[p][j]);
fprintf (fpResults,"\n");
}
fclose(fpResults);
return 0;
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -