/* File: ndtimes.c */
/*
Compute matrix product A*B for ND slice matrices
Usage
-----
C = ndtimes(A,B)
Inputs
-------
A Matrix (k x n x a1 x a2 x ... al), L = (l + 2) dimensional
B Matrix (n x p x b1 x b2 x ... bs), S = (s + 2) dimensional
CONDITIONS !!
ai = bi, i = 1,...., min(l , s)
Outputs
-------
C Matrix (k x p x ...), with the trailing dimensions of whichever of A and B has more dimensions
Example
-------
d = 4;
M = 3;
N = 10000;
Q = [3 1 0 0 ; 0 1 0 0 ; 0 0 3 1 ; 0 0 0 1];
Fk = randn(d , d , M);
Xk = randn(d , 1 , M , N);
Qk = Q(: , : , ones(1 , M));
Nk = randn(d , 1 , M , N);
C = reshape( ndtimes(Fk , Xk) + ndtimes(permute(ndchol(Qk) , [2 1 3]) , Nk) , [d M N]);
Compile with:
------------
mex ndtimes.c
or
mex -f mexopts_intel10amd.bat -output ndtimes.dll ndtimes.c
Author Sébastien PARIS (sebastien.paris@lsis.org) (5/4/08)
-------
*/
#include <limits.h>
#include <malloc.h>
#include "mex.h"
/* --------------------------------------- DECLARATION ------------------------------------- */
/* Slice-wise product kernel. Arguments, in the order used by the definition of ndtimes in this file: A, B, C, k, n, p, indA, indB, indC, sizC. */
void ndtimes(double *, double * , double * , int , int , int , int * , int * , int * , int);
/* -------------------------------------------------------------------------------------------- */
void mexFunction( int nlhs, mxArray *plhs[] , int nrhs, const mxArray *prhs[] )
{
const int *dimsA , *dimsB;
int *dimsC;
int *ind_tpA = NULL , *ind_tpB = NULL , *indA = NULL , *indB = NULL , *indC = NULL;
double *A, *B , *C;
int i, j , k , n , p , o, h , v , numDimsA = 0 , numDimsB = 0 , numDimsC = 0 , sizC = 1 , tpA = 1, tpB = 1 , rap_dim = 1;
/* Check nargin */
if(nrhs != 2)
{
mexErrMsgTxt("Two ND matrix are requiered");
}
/* ----------- Input ------------ */
A = mxGetPr(prhs[0]);
numDimsA = mxGetNumberOfDimensions(prhs[0]);
dimsA = mxGetDimensions(prhs[0]);
B = mxGetPr(prhs[1]);
numDimsB = mxGetNumberOfDimensions(prhs[1]);
dimsB = mxGetDimensions(prhs[1]);
k = dimsA[0];
n = dimsA[1];
o = dimsB[0];
p = dimsB[1];
if (n != o)
{
mexErrMsgTxt("Inner dimensions are not matching !! A(k x n x ...) and B(n x p x ...)");
}
if (numDimsA > numDimsB)
{
for (i = 2 ; i<numDimsB ; i++)
{
tpA *= dimsA[i];
tpB *= dimsB[i];
}
if (tpA != tpB)
{
mexErrMsgTxt("Dimensions > 2 are not matching");
}
for (i = numDimsB ; i <numDimsA ; i++)
{
tpA *= dimsA[i];
rap_dim *= dimsA[i];
}
}
if (numDimsA <= numDimsB)
{
for (i=2 ; i<numDimsA ; i++)
{
tpA *= dimsA[i];
tpB *= dimsB[i];
}
if (tpA != tpB)
{
mexErrMsgTxt("Dimensions > 2 are not matching");
}
for (i = numDimsA ; i <numDimsB ; i++)
{
tpB *= dimsB[i];
rap_dim *= dimsB[i];
}
}
ind_tpA = (int *)mxMalloc(tpA*sizeof(int));
for (i=0 ; i<tpA ; i++)
ind_tpA[i] = i;
ind_tpB = (int *)mxMalloc(tpB*sizeof(int));
for (i=0 ; i<tpB ; i++)
ind_tpB[i] = i;
if (numDimsA > numDimsB)
{
indA = (int *)mxMalloc(tpA*sizeof(int));
indB = (int *)mxMalloc(tpA*sizeof(int));
indC = (int *)mxMalloc(tpA*sizeof(int));
for (i=0 ; i<rap_dim ; i++)
{
h = i*tpB;
for (j=0 ; j<tpB ; j++)
{
v = j + h;
indA[v] = v;
indB[v] = ind_tpB[j];
indC[v] = v;
}
}
numDimsC = numDimsA;
dimsC = (int *)mxMalloc(numDimsC*sizeof(int));
dimsC[0] = k;
dimsC[1] = p;
sizC = tpA;
for (i = 2; i<numDimsC ; i++)
dimsC[i] = dimsA[i];
}
if (numDimsA < numDimsB)
{
indA = (int *)mxMalloc(tpB*sizeof(int));
indB = (int *)mxMalloc(tpB*sizeof(int));
indC = (int *)mxMalloc(tpB*sizeof(int));
for (i=0 ; i<rap_dim ; i++)
{
h = i*tpA;
for (j=0 ; j<tpA ; j++)
{
v = j + h;
indA[v] = ind_tpA[j];
indB[v] = v;
indC[v] = v;
}
}
numDimsC = numDimsB;
dimsC = (int *)mxMalloc(numDimsC*sizeof(int));
dimsC[0] = k;
dimsC[1] = p;
sizC = tpB;
for (i = 2; i<numDimsC ; i++)
dimsC[i] = dimsB[i];
}
if (numDimsA == numDimsB)
{
indA = (int *)mxMalloc(tpB*sizeof(int));
indB = (int *)mxMalloc(tpB*sizeof(int));
indC = (int *)mxMalloc(tpB*sizeof(int));
for (i=0 ; i<tpB ; i++)
{
indA[i] = i;
indB[i] = i;
indC[i] = i;
}
numDimsC = numDimsB;
dimsC = (int *)mxMalloc(numDimsC*sizeof(int));
dimsC[0] = k;
dimsC[1] = p;
sizC = tpA;
for (i = 2; i<numDimsC ; i++)
dimsC[i] = dimsA[i];
}
plhs[0] = mxCreateNumericArray(numDimsC, dimsC, mxDOUBLE_CLASS, mxREAL);
C = mxGetPr(plhs[0]);
/* ----------------- Array multiplication ------------ */
ndtimes(A, B , C , k , n , p , indA , indB , indC , sizC);
/* ----------------------- Free space ---------------- */
mxFree(indA);
mxFree(indB);
mxFree(indC);
mxFree(ind_tpA);
mxFree(ind_tpB);
mxFree(dimsC);
}
/* ----------------------------------------------------------------------------------------------- */
/*
 * Multiply sizC pairs of column-major 2-D slices.
 *
 * For each v in [0 , sizC), slice indC[v] of C (k x p) is overwritten with the
 * product of slice indA[v] of A (k x n) and slice indB[v] of B (n x p).
 * Slices are stored back to back, so slice r of A starts at offset r*k*n,
 * slice r of B at r*n*p and slice r of C at r*k*p.
 */
void ndtimes(double *A, double *B , double *C , int k , int n , int p , int *indA , int *indB , int *indC , int sizC)
{
    const int strideA = k*n;
    const int strideB = n*p;
    const int strideC = k*p;
    int slice , row , col , inner;
    int baseA , baseB , baseC , colB;
    double *out;

    for (slice = 0 ; slice < sizC ; slice++)
    {
        /* Offsets of the first element of the current slices */
        baseA = indA[slice]*strideA;
        baseB = indB[slice]*strideB;
        baseC = indC[slice]*strideC;

        for (row = 0 ; row < k ; row++)
        {
            for (col = 0 ; col < p ; col++)
            {
                out  = C + (baseC + row + col*k);
                colB = baseB + col*n;

                /* Dot product of row `row` of A with column `col` of B */
                *out = 0.0;
                for (inner = 0 ; inner < n ; inner++)
                {
                    *out += A[baseA + row + inner*k]*B[colB + inner];
                }
            }
        }
    }
}
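/*
Residue from the code-viewer web page (not part of the program): keyboard shortcuts
Copy code            : Ctrl + C
Search code          : Ctrl + F
Full-screen mode     : F11
Switch theme         : Ctrl + Shift + D
Show shortcuts       : ?
Increase font size   : Ctrl + =
Decrease font size   : Ctrl + -
*/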