?? init_pot.c
字號:
sequence = malloc(NZB * 2 * sizeof(int));
bigTable = malloc(NZB * sizeof(double));
samemask = malloc(sdim * sizeof(int));
diffmask = malloc(diffdim * sizeof(int));
bCumprod = malloc(bdim * sizeof(int));
sCumprod = malloc(sdim * sizeof(int));
weight = malloc(ND * sizeof(int));
ssubv = malloc(sdim * sizeof(int));
count = 0;
count1 = 0;
for(i=0; i<bdim; i++){
match = 0;
for(j=0; j<sdim; j++){
if(pbDomain[i] == psDomain[j]){
samemask[count] = i;
match = 1;
count++;
break;
}
}
if(match == 0){
diffmask[count1] = i;
count1++;
}
}
bCumprod[0] = 1;
for(i=0; i<bdim-1; i++){
bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
}
sCumprod[0] = 1;
for(i=0; i<sdim-1; i++){
sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
}
count = 0;
compute_fixed_weight(weight, pbSize, diffmask, bCumprod, ND, diffdim);
for(i=0; i<NZS; i++){
sindex = sir[i];
ind_subv(sindex, sCumprod, sdim, ssubv);
temp = 0;
for(j=0; j<sdim; j++){
temp += ssubv[j] * bCumprod[samemask[j]];
}
for(j=0; j<ND; j++){
bindex = weight[j] + temp;
bigTable[nzCounts] = spr[i];
sequence[count] = bindex;
count++;
sequence[count] = nzCounts;
nzCounts++;
count++;
}
}
pTemp = mxGetField(bigPot, 0, "T");
if(pTemp)mxDestroyArray(pTemp);
qsort(sequence, nzCounts, sizeof(int) * 2, compare);
pTemp = convert_ill_table_to_sparse(bigTable, sequence, nzCounts, NB);
mxSetField(bigPot, 0, "T", pTemp);
free(sequence);
free(bigTable);
free(samemask);
free(diffmask);
free(bCumprod);
free(sCumprod);
free(weight);
free(ssubv);
}
/*
 * Multiply, in place, a potential whose table "T" is sparse (bigPot) by a
 * potential whose table "T" is full (smallPot).
 *
 * Every stored entry of bigPot.T is mapped to the entry of smallPot.T that
 * shares the same values for smallPot's domain variables.  The product is
 * kept only when the smallPot.T factor is non-zero.  The old bigPot.T is
 * destroyed and replaced by a newly built sparse table.
 *
 * bigPot   - struct with fields "domain", "sizes" and a sparse "T";
 *            its "T" field is replaced.
 * smallPot - struct with fields "domain", "sizes" and a full "T".
 *
 * NOTE(review): every variable in smallPot's domain is assumed to appear in
 * bigPot's domain; otherwise entries of `mask` are left uninitialised.
 */
void multiply_spPot_by_fuPot(mxArray *bigPot, const mxArray *smallPot){
    int i, j, count, bdim, sdim, NB, NZB, bindex, sindex, nzCounts = 0;
    int *mask, *index, *bCumprod, *sCumprod, *bsubv, *ssubv;
    /* mxGetIr/mxGetJc return mwIndex arrays.  Reading them through int* is
     * wrong whenever mwIndex is wider than int (64-bit builds). */
    mwIndex *bir, *bjc;
    double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
    mxArray *pTemp;

    pTemp = mxGetField(bigPot, 0, "domain");
    pbDomain = mxGetPr(pTemp);
    bdim = (int)mxGetNumberOfElements(pTemp);
    pTemp = mxGetField(smallPot, 0, "domain");
    psDomain = mxGetPr(pTemp);
    sdim = (int)mxGetNumberOfElements(pTemp);
    pTemp = mxGetField(bigPot, 0, "sizes");
    pbSize = mxGetPr(pTemp);
    pTemp = mxGetField(smallPot, 0, "sizes");
    psSize = mxGetPr(pTemp);

    /* Number of cells in the dense version of the big table. */
    NB = 1;
    for (i = 0; i < bdim; i++) {
        NB *= (int)pbSize[i];
    }

    /* Only column 0 of the sparse bigPot.T is read; jc[1] is its entry count. */
    pTemp = mxGetField(bigPot, 0, "T");
    bpr = mxGetPr(pTemp);
    bir = mxGetIr(pTemp);
    bjc = mxGetJc(pTemp);
    NZB = (int)bjc[1];
    pTemp = mxGetField(smallPot, 0, "T");
    spr = mxGetPr(pTemp);

    /* The product can have at most NZB stored entries. */
    bigTable = calloc(NZB, sizeof *bigTable);
    index = malloc(NZB * sizeof *index);
    mask = malloc(sdim * sizeof *mask);
    bCumprod = malloc(bdim * sizeof *bCumprod);
    sCumprod = malloc(sdim * sizeof *sCumprod);
    bsubv = malloc(bdim * sizeof *bsubv);
    ssubv = malloc(sdim * sizeof *ssubv);

    /* mask[k] = position in bigPot's domain of smallPot's k-th variable. */
    count = 0;
    for (i = 0; i < sdim; i++) {
        for (j = 0; j < bdim; j++) {
            if (psDomain[i] == pbDomain[j]) {
                mask[count] = j;
                count++;
                break;
            }
        }
    }

    /* Cumulative-product strides used by ind_subv/subv_ind (presumably
     * column-major linear index <-> subscript conversion). */
    bCumprod[0] = 1;
    for (i = 0; i < bdim - 1; i++) {
        bCumprod[i + 1] = bCumprod[i] * (int)pbSize[i];
    }
    sCumprod[0] = 1;
    for (i = 0; i < sdim - 1; i++) {
        sCumprod[i + 1] = sCumprod[i] * (int)psSize[i];
    }

    for (i = 0; i < NZB; i++) {
        bindex = (int)bir[i];
        ind_subv(bindex, bCumprod, bdim, bsubv);
        /* Project the big subscript onto smallPot's variables. */
        for (j = 0; j < sdim; j++) {
            ssubv[j] = bsubv[mask[j]];
        }
        sindex = subv_ind(sdim, sCumprod, ssubv);
        value = spr[sindex];
        /* A zero factor makes the product zero: do not store it. */
        if (value != 0) {
            bigTable[nzCounts] = bpr[i] * value;
            index[nzCounts] = bindex;
            nzCounts++;
        }
    }

    /* mxSetField does not free the previous value, so destroy it first.
     * bpr/bir/bjc dangle after this point and are no longer used. */
    pTemp = mxGetField(bigPot, 0, "T");
    if (pTemp) mxDestroyArray(pTemp);
    pTemp = convert_table_to_sparse(bigTable, index, nzCounts, NB);
    mxSetField(bigPot, 0, "T", pTemp);

    free(bigTable);
    free(index);
    free(mask);
    free(bCumprod);
    free(sCumprod);
    free(bsubv);
    free(ssubv);
}
void multiply_spPot_by_spPot(mxArray *bigPot, const mxArray *smallPot){
int i, j, count, bdim, sdim, NB, NZB, NZS, position, bindex, sindex, nzCounts=0;
int *mask, *index, *result, *bir, *sir, *bjc, *sjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
mxArray *pTemp;
pTemp = mxGetField(bigPot, 0, "domain");
pbDomain = mxGetPr(pTemp);
bdim = mxGetNumberOfElements(pTemp);
pTemp = mxGetField(smallPot, 0, "domain");
psDomain = mxGetPr(pTemp);
sdim = mxGetNumberOfElements(pTemp);
pTemp = mxGetField(bigPot, 0, "sizes");
pbSize = mxGetPr(pTemp);
pTemp = mxGetField(smallPot, 0, "sizes");
psSize = mxGetPr(pTemp);
NB = 1;
for(i=0; i<bdim; i++){
NB *= (int)pbSize[i];
}
pTemp = mxGetField(bigPot, 0, "T");
bpr = mxGetPr(pTemp);
bir = mxGetIr(pTemp);
bjc = mxGetJc(pTemp);
NZB = bjc[1];
pTemp = mxGetField(smallPot, 0, "T");
spr = mxGetPr(pTemp);
sir = mxGetIr(pTemp);
sjc = mxGetJc(pTemp);
NZS = sjc[1];
bigTable = malloc(NZB * sizeof(double));
index = malloc(NZB * sizeof(double));
mask = malloc(sdim * sizeof(int));
bCumprod = malloc(bdim * sizeof(int));
sCumprod = malloc(sdim * sizeof(int));
bsubv = malloc(bdim * sizeof(int));
ssubv = malloc(sdim * sizeof(int));
for(i=0; i<NZB; i++){
bigTable[i] = 0;
}
count = 0;
for(i=0; i<sdim; i++){
for(j=0; j<bdim; j++){
if(psDomain[i] == pbDomain[j]){
mask[count] = j;
count++;
break;
}
}
}
bCumprod[0] = 1;
for(i=0; i<bdim-1; i++){
bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
}
sCumprod[0] = 1;
for(i=0; i<sdim-1; i++){
sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
}
for(i=0; i<NZB; i++){
value = bpr[i];
bindex = bir[i];
ind_subv(bindex, bCumprod, bdim, bsubv);
for(j=0; j<sdim; j++){
ssubv[j] = bsubv[mask[j]];
}
sindex = subv_ind(sdim, sCumprod, ssubv);
result = (int *) bsearch(&sindex, sir, NZS, sizeof(int), compare);
if(result){
position = result - sir;
value *= spr[position];
bigTable[nzCounts] = value;
index[nzCounts] = bindex;
nzCounts++;
}
}
pTemp = mxGetField(bigPot, 0, "T");
if(pTemp)mxDestroyArray(pTemp);
pTemp = convert_table_to_sparse(bigTable, index, nzCounts, NB);
mxSetField(bigPot, 0, "T", pTemp);
free(bigTable);
free(index);
free(mask);
free(bCumprod);
free(sCumprod);
free(bsubv);
free(ssubv);
}
/*
 * MEX gateway.  Builds one potential struct per clique, then multiplies each
 * supplied potential into the clique it is assigned to.
 *
 * prhs[0] - struct with fields "cliques" (cell array; each cell holds the
 *           1-based node ids of a clique's domain) and "eff_node_sizes"
 *           (size of each node, indexed by node id - 1).
 * prhs[1] - numeric vector; element k is the 1-based index of the clique
 *           that receives potential k.
 * prhs[2] - cell array of potential structs with "domain", "sizes", "T".
 *
 * plhs[0] - nCliques-by-1 cell of structs {domain, T, sizes}.  "T" stays
 *           empty for a clique that receives no potential.
 * plhs[1] - empty nCliques-by-nCliques cell array.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
    int i, j, c, loop, nNodes, nCliques, ndomain;
    /* mxCreateCellArray takes mwSize dimensions.  An int array is read
     * incorrectly when mwSize is wider than int (64-bit builds). */
    mwSize cellDims[2];
    double *pClqs, *pr, *pt, *pSize;
    mxArray *pTemp, *pTemp1, *pStruct, *pCliques, *pBigpot, *pSmallpot;
    const char *field_names[] = {"domain", "T", "sizes"};

    if (nrhs < 3) {
        mexErrMsgTxt("Three input arguments are required.");
    }
    nNodes = (int)mxGetNumberOfElements(prhs[1]);
    if ((int)mxGetNumberOfElements(prhs[2]) < nNodes) {
        mexErrMsgTxt("Fewer potentials than clique assignments.");
    }

    pCliques = mxGetField(prhs[0], 0, "cliques");
    nCliques = (int)mxGetNumberOfElements(pCliques);
    pTemp = mxGetField(prhs[0], 0, "eff_node_sizes");
    pSize = mxGetPr(pTemp);

    /* One struct per clique: the domain is copied from the clique, the sizes
     * are looked up per node, and T is left unset until a potential is
     * multiplied in. */
    cellDims[0] = (mwSize)nCliques;
    cellDims[1] = 1;
    plhs[0] = mxCreateCellArray(2, cellDims);
    for (i = 0; i < nCliques; i++) {
        pStruct = mxCreateStructMatrix(1, 1, 3, field_names);
        mxSetCell(plhs[0], i, pStruct);
        pTemp = mxGetCell(pCliques, i);
        ndomain = (int)mxGetNumberOfElements(pTemp);
        pt = mxGetPr(pTemp);
        pTemp1 = mxDuplicateArray(pTemp);
        mxSetField(pStruct, 0, "domain", pTemp1);
        pTemp = mxCreateDoubleMatrix(1, ndomain, mxREAL);
        mxSetField(pStruct, 0, "sizes", pTemp);
        pr = mxGetPr(pTemp);
        for (j = 0; j < ndomain; j++) {
            pr[j] = pSize[(int)pt[j] - 1];
        }
    }

    /* Multiply each potential into its clique.  The first potential fills an
     * empty T (multiply_null_*); later ones multiply into the existing
     * sparse T.  The variant is chosen by whether the incoming T is sparse. */
    pClqs = mxGetPr(prhs[1]);
    for (loop = 0; loop < nNodes; loop++) {
        c = (int)pClqs[loop] - 1;
        if (c < 0 || c >= nCliques) {
            mexErrMsgTxt("Clique index out of range.");
        }
        pSmallpot = mxGetCell(prhs[2], loop);
        pTemp = mxGetField(pSmallpot, 0, "T");
        pBigpot = mxGetCell(plhs[0], c);
        pTemp1 = mxGetField(pBigpot, 0, "T");
        if (pTemp1) {
            if (mxIsSparse(pTemp))
                multiply_spPot_by_spPot(pBigpot, pSmallpot);
            else
                multiply_spPot_by_fuPot(pBigpot, pSmallpot);
        } else {
            if (mxIsSparse(pTemp))
                multiply_null_by_spPot(pBigpot, pSmallpot);
            else
                multiply_null_by_fuPot(pBigpot, pSmallpot);
        }
    }

    cellDims[0] = (mwSize)nCliques;
    cellDims[1] = (mwSize)nCliques;
    plhs[1] = mxCreateCellArray(2, cellDims);
}
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -