【问题标题】:Why is this naive matrix multiplication faster than base R's?为什么这种简单的矩阵乘法比基数 R 更快?
【发布时间】:2018-12-05 20:16:06
【问题描述】:

在 R 中,矩阵乘法非常优化,即实际上只是对 BLAS/LAPACK 的调用。然而,令我惊讶的是,这个用于矩阵向量乘法的非常幼稚的 C++ 代码似乎可靠地快了 30%。

 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0

base R 的%*% 是否会执行某种我跳过的数据检查?

编辑:

在了解发生了什么之后(感谢 SO!),值得注意的是,对于 R 的 %*% 而言,这是最坏的情况,即逐个向量矩阵。例如,@RalfStubner 指出,使用矩阵向量乘法的 RcppArmadillo 实现甚至比我演示的幼稚实现还要快,这意味着比基本 R 快得多,但实际上与基本 R 的 %*% 矩阵矩阵相同相乘(当两个矩阵都很大且为正方形时):

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  

因此,R 的当前 (v3.5.0) %*% 对于矩阵矩阵来说几乎是最佳的,但如果您可以跳过检查,则可以显着加快矩阵向量的速度。

【问题讨论】:

  • 它可能无法解释所有这些,但 R 的方法必须处理 NA 值。此外,基于我对计算中的数值方法了解的非常少,您的幼稚方法很可能在某些情况下最终会变得无法接受,因此其他方法会以一些速度换取更高的准确性。
  • 查看:getAnywhere(%*%),我们有:function (x, y) .Primitive("%*%")。因此,这是与 C 库的接口,但正如 @joran 指出的那样,您没有考虑到 NA 处理。
  • @joran:据我所知,这可以正确处理NA。我能看到的唯一区别是这会导致向量而不是矩阵。
  • 这个post 已经过时了,自从他写了这篇文章以来,Radford 可能已经成功地对 R 进行了一些改进,我认为这至少总结了处理 NA、Inf 和 NaN 并不总是那么简单,并且确实需要一些工作。
  • 您可以通过使用线性代数库进行矩阵-矩阵乘法来获得巨大的改进,因为它们可以更好地处理内存和缓存。对于矩阵向量乘法,内存问题不是问题,因此优化更小。参见例如this

标签: r performance rcpp matrix-multiplication


【解决方案1】:

快速浏览names.c (here in particular) 会指向do_matprod,这是由%*% 调用的C 函数,位于文件array.c 中。 (有趣的是,事实证明,crossprodtcrossprod 也分派到同一个函数)。 Here is a linkdo_matprod的代码。

滚动浏览该函数,您可以看到它处理了一些您的幼稚实现无法处理的事情,包括:

  1. 在有意义的地方保留行名和列名。
  2. 当通过调用%*% 操作的两个对象属于已提供此类方法的类时,允许分派到替代S4 方法。 (这就是函数的this portion 中发生的事情。)
  3. 处理实数和复数矩阵。
  4. 实现了一系列规则,用于处理矩阵与矩阵、向量与矩阵、矩阵与向量、向量与向量的乘法。 (回想一下,在 R 中的交叉乘法下,LHS 上的向量被视为行向量,而在 RHS 上,它被视为列向量;这是实现这一点的代码。)

Near the end of the function,它分派到matprodcmatprod。有趣的是(至少对我而言),对于实数矩阵,if 任一矩阵可能包含 NaNInf 值,然后 matprod 调度 (here) 到一个名为simple_matprod 与您自己的一样简单明了。否则,它会分派到几个 BLAS Fortran 例程中的一个,如果可以保证一致地“表现良好”的矩阵元素,这些例程可能会更快。

【讨论】:

  • 有趣(+1)。如果这些是唯一的区别,那么意味着如果我知道我正在做香草矩阵x向量运算,我应该使用my_mm。这让我很惊讶。
  • @CliffAB 通过 RcppArmadillo 直接或间接使用适当的 BLAS 函数并使用多线程 BLAS,您可能会获得更多收益。
【解决方案2】:

Josh 的回答解释了为什么 R 的矩阵乘法不如这种幼稚的方法快。我很想知道使用 RcppArmadillo 可以获得多少收益。代码很简单:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

基准测试:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

所以 RcppArmadillo 为我们提供了更好的语法和更好的性能。

好奇心战胜了我。这里有一个直接使用 BLAS 的解决方案:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")

基准测试:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

犰狳和 BLAS(在我的例子中是 OpenBLAS)几乎相同。 BLAS 代码也是 R 最终所做的。所以 R 所做的 2/3 是错误检查等。

【讨论】:

  • 并且可能要启动 OpenMP(如果您的操作系统/编译器支持它)。
  • @Dirk 我原以为 Armadillo 会将如此简单的事情直接转发给 BLAS(在我的情况下也是多线程的)。至少它们同样快......
  • 非常有趣。检查成本不会像矩阵-矩阵的计算那样快速扩展是有道理的,所以在这种情况下这个成本就消失了。
  • @CliffAB 是的。此外,对于矩阵-矩阵,在您的 BLAS 实现中使用幼稚的方法来超越内存访问将更加困难,c.f.上面由 F.Prive 提供的链接。
【解决方案3】:

要对Ralf Stubner的解决方案再补充一点,那么你可以使用下面的C++版本来

  1. 同时处理多个列,以避免多次重新读取输出向量。
  2. 添加 __restrict__ 以潜在地允许向量操作(在这里可能无关紧要,因为我猜它只是读取)。
#include <Rcpp.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

我的 Makevars 文件和 gcc-8.3 中带有 -O3 的结果是

set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
#R> [1] TRUE
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))
#R> [1] TRUE

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
#R> # A tibble: 5 x 13
#R>   expression        min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory                 time          gc               
#R>   <bch:expr>   <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>                 <list>        <list>           
#R> 1 R             147.9ms    151ms      6.57    78.2KB        0     4     0      609ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [4]>  <tibble [4 × 3]> 
#R> 2 OP's version   56.9ms   57.1ms     17.4     78.2KB        0     9     0      516ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 3 BLAS           90.1ms   90.7ms     11.0     78.2KB        0     6     0      545ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [6]>  <tibble [6 × 3]> 
#R> 4 C++ vanilla    57.2ms   57.4ms     17.4     78.2KB        0     9     0      518ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 5 C++              51ms   51.4ms     19.3     78.2KB        0    10     0      519ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [10]> <tibble [10 × 3]>

所以略有改善。结果可能非常依赖于 BLAS 版本。我使用的版本是

sessionInfo()
#R> #...
#R> Matrix products: default
#R> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
#R> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
#R> ...

Rcpp::sourceCpp()ed 的整个文件是

#include <Rcpp.h>
#include <R_ext/BLAS.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export(rng = false)]]
NumericVector my_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  double v_j;
  for(int j = 0; j < nCol; j++){
    v_j = v[j];
    for(int i = 0; i < nRow; i++){
      ans[i] += m(i,j) * v_j;
    }
  }
  return(ans);
}

// [[Rcpp::export(rng = false)]]
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}

/*** R
set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
*/

【讨论】:

  • 有趣:在你的结果中,BLAS 比简单的 C++ 版本(你的或我的)慢得多。 @RalfStubner 的 BLAS 结果大约是我的两倍。 Ralf 的 BLAS 可以使用 2 个(或更多)线程吗?还是不同的版本?
  • RalfStubner 表示他正在使用 OpenBLAS。我使用的是默认的 BLAS,所以我认为这是造成差异的原因。我怀疑这只是实现,但可能是他使用了更多线程。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-01-22
  • 2018-12-22
  • 2019-03-16
  • 1970-01-01
  • 2018-01-04
相关资源
最近更新 更多