【问题标题】:Matlab taking too much time in computing mex function for large arrayMatlab 在计算大型数组的 mex 函数时花费了太多时间
【发布时间】:2016-08-26 21:32:29
【问题描述】:

我编写了一个 MATLAB 脚本,其中我传递了几个标量和一个行向量作为 mex 函数的输入参数,在进行一些计算之后,它返回一个标量作为输出。必须对大小为 1 X 1638400 的数组的所有元素执行此过程。下面是相应的代码:

ans=0;
for i=0:1638400-1
    temp = sub_imed(r,i,diff);
    ans  = ans + temp*diff(i+1); 
end

其中 r,i 是标量,diff 是大小为 1 X 1638400 的向量,sub_imed 是执行以下工作的 MEX 函数:

void sub_imed(double r,mwSize base, double* diff, mwSize dim, double* ans)              
{                                                                                           
     mwSize i,k,l,k1,l1;
     double d,g,temp;

     for(i=0; i<dim; i++)
     {   
          k = (base/200) + 1;
          l = (base%200) + 1;
          k1 = (i/200) + 1;
          l1 = (i%200) + 1;

          d = sqrt(pow((k-k1),2) + pow((l-l1),2));

          g=(1/(2*pi*pow(r,2)))*exp(-(pow(d,2))/(2*(pow(r,2))));   

          temp = temp + diff[i]*g;
     }

     *ans  = temp;
}

void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[]) 
{
    double *diff;           /* Input data vectors */
    double r;               /* Value of r (input) */
    double* ans;            /* Output ImED distance */
    size_t base,ncols;      /* For storing the size of input vector and base */

    /* Checking for proper number of arguments */
    if(nrhs!=3) 
       mexErrMsgTxt("Error..Three inputs required.");

    if(nlhs!=1) 
       mexErrMsgTxt("Error..Only one output required.");

    /* make sure the first input argument(value of r) is scalar */
    if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1) 
       mexErrMsgTxt("Error..Value of r must be a scalar."); 

    /* make sure that the input value of base is a scalar */
    if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1) 
       mexErrMsgTxt("Error..Value of base must be a scalar."); 

    /* make sure that the input vector diff is of type double */
    if(!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]))    
       mexErrMsgTxt("Error..Input vector must be of type double.");

    /* check that number of rows in input arguments is 1 */
    if(mxGetM(prhs[2])!=1) 
       mexErrMsgTxt("Error..Inputs must be row vectors."); 

    /* Get the value of r */
    r = mxGetScalar(prhs[0]);
    base = mxGetScalar(prhs[1]);

    /* Getting the input vectors */
    diff = mxGetPr(prhs[2]);
    ncols = mxGetN(prhs[2]);

    /* Creating link for the scalar output */
    plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
    ans = mxGetPr(plhs[0]); 

    sub_imed(r,base,diff,(mwSize)ncols,ans);
}

有关问题和下划线算法的更多详细信息,请关注线程Euclidean distance between images

我对我的 MATLAB 脚本进行了分析,发现它需要 63 秒。仅用于对 sub_imed() mex 函数的 387 次调用。因此,对于 sub_imed 的 1638400 次调用,理想情况下需要大约 74 小时,这太长了。

有人可以通过建议一些替代方法来帮助我优化代码以减少计算时间。

提前致谢。

【问题讨论】:

  • 这个 mex 函数是你写的吗?你为什么使用 mex 而不是 MATLAB?
  • 是的..实际上没有 mex 需要更长的时间..所以我认为为内部循环编写一个 mex 函数是个好主意,它可能会降低我的运行成本。
  • 好的,dim 是什么?我猜是size(diff)?
  • 是的..你是对的。
  • 几个计算:例如。 l = (base%200) + 1;(1/(2*pi*pow(r,2)))/(2*pow(r,2)) 可以移到 for 循环之外...

标签: arrays matlab optimization


【解决方案1】:
  • 您是否使用优化标志进行编译? (见标志-O
  • 不要在 C/C++ 中使用 pow(x,2),而应编写 x*x
  • 我几乎可以肯定,它缓慢的原因不是因为您显示的代码,而是因为您在mexFunction 中所做的事情。如果我不得不猜测,我会说你在毫无意义地复制 diff 中的内存,但我们需要查看整个 Mex 函数才能确定。

使用mex -O myfile.cpp 尝试以下 C++ 代码:

void sub_imed( double r, size_t base, const double *diff, size_t dim, double& ans)
{
    double d, g;

    // these need to be double to avoid underflow
    double k = base / 200;
    double l = base % 200;

    r = 2*r*r;
    for(; dim; --dim, ++diff )
    {
        d = k - i/200;
        g = l - i%200;

        ans += (*diff) * exp( - (d*d + g*g)/r ) / (pi*r);
    }
}

void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
{
    /* Checking for proper number of arguments */
    if(nrhs!=3)
        mexErrMsgTxt("Error..Three inputs required.");

    if(nlhs!=1)
        mexErrMsgTxt("Error..Only one output required.");

    /* make sure the first input argument(value of r) is scalar */
    if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1 )
        mexErrMsgTxt("Error..Value of r must be a scalar.");

    /* make sure that the input value of base is a scalar */
    if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1 )
        mexErrMsgTxt("Error..Value of base must be a scalar.");

    /* make sure that the input vector diff is of type double */
    if( !mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]) )
        mexErrMsgTxt("Error..Input vector must be of type double.");

    /* check that number of rows in input arguments is 1 */
    if( mxGetM(prhs[2])!=1 )
        mexErrMsgTxt("Error..Inputs must be row vectors.");

    /* Get the value of r */
    double r    = mxGetScalar(prhs[0]);
    size_t base = static_cast<size_t>(mxGetScalar(prhs[1]);

    /* Getting the input vectors */
    const double *diff = mxGetPr(prhs[2]);
    size_t nrows = static_cast<size_t>(mxGetN(prhs[2]));

    /* Creating link for the scalar output */
    plhs[0] = mxCreateDoubleScalar(0.0);
    sub_imed( r, base, diff, nrows, *mxGetPr(plhs[0]) );
}

【讨论】:

  • 我不认为从 plhs 复制到 diff 是没有意义的。无论如何,我已经编辑了我的问题,现在你可以看到完整的 mex 函数了。
  • 我不明白你的评论,你不是在复制它,当然也不是来自plhs。使用mxCreateDoubleScalar,你应该使用const double*,而不是double*diff
  • 什么是带有优化标志且没有调用pow的更新时间?
  • @Sh3lijohn:感谢您的代码。我将实施它,并随时向您发布有关其性能的信息。
  • @nagarwal 我是 Sheljohn :) 等待更新。
【解决方案2】:

我将您的代码移植回 MATLAB 并进行了一些小的调整,而结果应该保持不变。我引入了以下常量:

N = 8192;
step = 0.005;

请注意N / step = 1638400。有了它,您可以重写您的变量k(并将其重命名为baseDiv):

baseDiv = 1 + (0 : step : (N-step)).';

即它是1:8193,步长为0.005。 同理,l就是1:200=1:(1/0.005)),连续重复8192次,也就是(现在叫baseMod):

baseMod = (repmat(1:1:(1/step), 1, N)).';

您的变量k1l1 只是klith 元素,即baseDiv(i)baseMod(i)

使用向量baseDivbaseMod,可以计算dg 和您的临时变量tmp

d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2);
g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2));
tmp = sum(diffVec .* g);

我们可以把它放到你的 MATLAB for 循环中,这样整个程序就变成了

% Constants
N = 8192;
step = 0.005;

% Some example data
r = 2;
diffVec = rand(N/step,1);

base = (0:(numel(diffVec)-1)).';    
baseDiv = (1:step:1+N-step).';
baseMod = (repmat(1:1:(1/step), 1, N)).';

res = 0;
for k=1:(N/step)
    d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2);
    g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2));
    tmp = sum(diffVec .* g);
    res = res + tmp * diffVec(k);
end

通过消除内部 for 循环并以矢量化方式计算它,1000 次迭代仍然需要 11 秒,因此总运行时间为 5 小时。仍然 - 加速超过 10 倍。要获得更高的加速,您有两种可能性:

1) 完全向量化:您可以轻松地向量化剩余的 for 循环,方法是使用 bsxfun(@minus, baseDiv, baseDiv.')sum 在列上计算 all 值同时。不幸的是,我们遇到了一个小问题:1638400×1638400 双矩阵会占用 20TB 的 RAM,我假设你的笔记本电脑中没有 ;-)

2) 更少的样本:您正在做一些数学变换,分辨率为step=0.005。检查您是否真的,真的需要这种精度!如果你取 1/10 的精度:step=0.05,你会快 100 倍,并且在 3 分钟内完成

【讨论】:

  • sum(a.*b) 通常可以表示为点积,即a*b.',由于内存中的一次传递和融合乘加,这应该更快地计算。 (实际上在 Matlab 中进行转置,即使是一维数组,也可能需要在内存中进行完整的传递,我不确定)。
  • 先生,非常感谢您详尽而充实的回复。它真的帮了我很多,但我对你的方法没什么可说的。实际上,我有两个大小为 8192 x 200 的图像,我想计算一个矩阵 (1638400 x 1638400),其中包含每个像素之间的欧几里得距离。因此,在变量 k 和 L 中,我计算了第一张图像的行和列,在 k1 和 l1 中,计算了第二张图像的行和列,但这里 baseDiv 将包含一些浮点数据。无论如何,我可以为我纠正。
  • @nagarwal - 我怀疑有更好的方法来解决你的实际问题,你可能想把它作为一个新问题提出来,参考这个。
  • 对于你的最后两点,是的..在我的电脑上创建一个更大尺寸的矩阵是不可能的,事实上这就是让我在循环中进行所有这些计算的原因。早些时候,我还尝试对所有内容进行向量化,以便它可能只是一些向量和矩阵的乘法,但我没有考虑部分向量化。其次,由于它是一种计算图像之间的欧几里得距离,所以我有必要考虑 200 列和 8192 行,不少于。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2017-02-15
  • 2020-07-23
相关资源
最近更新 更多