【问题标题】:Recursive matrix multiplication递归矩阵乘法
【发布时间】:2012-10-16 19:24:13
【问题描述】:

我正在阅读 CLRS 的算法简介。本书展示了简单的分治矩阵乘法的伪代码:

n = A.rows
let c be a new n x n matrix
if n == 1
    c11 = a11 * b11
else partition A, B, and C
    C11 = SquareMatrixMultiplyRecursive(A11, B11)
        + SquareMatrixMultiplyRecursive(A12, B21)
    //...
return C

例如,A11 是大小为 n/2 x n/2 的 A 的子矩阵。 作者还暗示我应该使用索引计算而不是创建新矩阵来表示子矩阵,所以我这样做了:

#include <iostream>
#include <vector>

template<class T>
struct Matrix
{
    Matrix(size_t r, size_t c)
    {
        Data.resize(c, std::vector<T>(r, 0));
    }    

    void SetSubMatrix(const int r, const int c, const int n, const Matrix<T>& A, const Matrix<T>& B)
    {
        for(int _c=c; _c<n; ++_c)
        {
            for(int _r=r; _r<n; ++_r)
            {
                Data[_c][_r] = A.Data[_c][_r] + B.Data[_c][_r];
            }
        }
    }

    static Matrix<T> SquareMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, int ar, int ac, int br, int bc, int n)
    {
        Matrix<T> C(n, n);

        if(n == 1)
        {
            C.Data[0][0] = A.Data[ac][ar] * B.Data[bc][br];
        }
        else
        {
            C.SetSubMatrix(0, 0, n / 2,
                           SquareMultiplyRecursive(A, B, ar, ac, br, bc, n / 2),
                           SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc, n / 2));

            C.SetSubMatrix(0, n / 2, n / 2,
                           SquareMultiplyRecursive(A, B, ar, ac, br, bc + (n / 2), n / 2),
                           SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));

            C.SetSubMatrix(n / 2, 0, n / 2,
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc, n / 2),
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc, n / 2));

            C.SetSubMatrix(n / 2, n / 2, n / 2,
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc + (n / 2), n / 2),
                           SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
        }

        return C;
    }

    void Print()
    {
        for(int c=0; c<Data.size(); ++c)
        {
            for(int r=0; r<Data[0].size(); ++r)
            {
                std::cout << Data[c][r] << " ";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }

    std::vector<std::vector<T> > Data;
};

int main()
{
    Matrix<int> A(2, 2);
    Matrix<int> B(2, 2);
    A.Data[0][0] = 2;
    A.Data[0][1] = 1;
    A.Data[1][0] = 1;
    A.Data[1][1] = 2;

    B.Data[0][0] = 2;
    B.Data[0][1] = 1;
    B.Data[1][0] = 1;
    B.Data[1][1] = 2;

    A.Print();
    B.Print();

    Matrix<int> C(Matrix<int>::SquareMultiplyRecursive(A, B, 0, 0, 0, 0, 2));

    C.Print();
}

它给了我不正确的结果,但我不确定我做错了什么......

【问题讨论】:

    标签: c++ algorithm matrix matrix-multiplication clrs


    【解决方案1】:
    // Recursive naive matrix multiplication in C, not strassen.
    // 2013-Feb-15 Fri 12:28 moshahmed/at/gmail
    
    #include <assert.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <time.h>
    
    #define M 2
    #define N (1<<M)
    
    typedef int mat[N][N]; // mat[2**M,2**M]  for divide and conquer mult.
    typedef struct { int ra, rb, ca, cb; } corners; // for tracking rows and columns.
    
    // set A[a] = k
    void set(mat A, corners a, int k){
      int i,j;
      for(i=a.ra;i<a.rb;i++)
        for(j=a.ca;j<a.cb;j++)
          A[i][j] = k;
    }
    
    // set A[a] = [random(l..h)].
    void randk(mat A, corners a, int l, int h){
      int i,j;
      for(i=a.ra;i<a.rb;i++)
        for(j=a.ca;j<a.cb;j++)
          A[i][j] = l + rand()% (h-l);
    }
    
    // Print A[a]
    void print(mat A, corners a, char *name) {
      int i,j;
      printf("%s = {\n",name);
      for(i=a.ra;i<a.rb;i++){
        for(j=a.ca;j<a.cb;j++)
          printf("%4d, ", A[i][j]);
        printf("\n");
      }
      printf("}\n");
    }
    
    // Return 1/4 of the matrix: top/bottom , left/right.
    void find_corners(corners a, int i, int j, corners *b) {
      int rm = a.ra + (a.rb - a.ra)/2 ;
      int cm = a.ca + (a.cb - a.ca)/2 ;
      *b = a;
      if (i==0)  b->rb = rm;     // top rows
      else       b->ra = rm;     // bot rows
      if (j==0)  b->cb = cm;     // left cols
      else       b->ca = cm;     // right cols
    }
    
    // Naive Multiply: A[a] * B[b] => C[c], recursively.
    void mul(mat A, mat B, mat C, corners a, corners b, corners c) {
      corners aii[2][2], bii[2][2], cii[2][2];
      int i, j, m, n, p;
    
      // Check: A[m n] * B[n p] = C[m p]
      m = a.rb - a.ra; assert(m==(c.rb-c.ra));
      n = a.cb - a.ca; assert(n==(b.rb-b.ra));
      p = b.cb - b.ca; assert(p==(c.cb-c.ca));
      assert(m>0);
    
      if (n==1) {
        C[c.ra][c.ca] += A[a.ra][a.ca] * B[b.ra][b.ca];
        return;
      }
    
      // Create the smaller matrices:
      for(i=0;i<2;i++) {
      for(j=0;j<2;j++) {
            find_corners(a, i, j, &aii[i][j]);
            find_corners(b, i, j, &bii[i][j]);
            find_corners(c, i, j, &cii[i][j]);
          }
      }
    
      // Now do the 8 sub matrix multiplications.
      // C00 = A00*B00 + A01*B10
      // C01 = A00*B01 + A01*B11
      // C10 = A10*B00 + A11*B10
      // C11 = A10*B01 + A11*B11
    
      mul( A, B, C, aii[0][0], bii[0][0], cii[0][0] );
      mul( A, B, C, aii[0][1], bii[1][0], cii[0][0] );
    
      mul( A, B, C, aii[0][0], bii[0][1], cii[0][1] );
      mul( A, B, C, aii[0][1], bii[1][1], cii[0][1] );
    
      mul( A, B, C, aii[1][0], bii[0][0], cii[1][0] );
      mul( A, B, C, aii[1][1], bii[1][0], cii[1][0] );
    
      mul( A, B, C, aii[1][0], bii[0][1], cii[1][1] );
      mul( A, B, C, aii[1][1], bii[1][1], cii[1][1] );
    
    }
    
    int main() {
      mat A, B, C;
      corners ai = {0,N,0,N};
      corners bi = {0,N,0,N};
      corners ci = {0,N,0,N};
      //set(A,ai,2);
      //set(B,bi,2);
      srand(time(0));
      randk(A,ai, 0, 2);
      randk(B,bi, 0, 2);
      set(C,ci,0); // set to zero before mult.
      print(A, ai, "A");
      print(B, bi, "B");
      mul(A,B,C, ai, bi, ci);
      print(C, ci, "C");
      return 0;
    }  
    

    【讨论】:

      【解决方案2】:

      我找到了解决方案... SetSubMatrix 完全不正确:

      void SetSubMatrix(const int r, const int c, const int rn, const int cn, const Matrix<T>& A, const Matrix<T>& B)
      {
          for(int _c=c; _c<cn; ++_c)
          {
              for(int _r=r; _r<rn; ++_r)
              {
                  Data[_c][_r] = A.Data[_c-c][_r-r] + B.Data[_c-c][_r-r];
              }
          }
      }
      
      static Matrix<T> SquareMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, int ar, int ac, int br, int bc, int n)
      {
          Matrix<T> C(n, n);
      
          if(n == 1)
          {
              C.Data[0][0] = A.Data[ac][ar] * B.Data[bc][br];
          }
          else
          {
              C.SetSubMatrix(0, 0, n / 2, n / 2,
                             SquareMultiplyRecursive(A, B, ar, ac, br, bc, n / 2),
                             SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc, n / 2));
      
              C.SetSubMatrix(0, n / 2, n / 2, n,
                             SquareMultiplyRecursive(A, B, ar, ac, br, bc + (n / 2), n / 2),
                             SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
      
              C.SetSubMatrix(n / 2, 0, n, n / 2,
                             SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc, n / 2),
                             SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc, n / 2));
      
              C.SetSubMatrix(n / 2, n / 2, n, n,
                             SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc + (n / 2), n / 2),
                             SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
          }
      
          return C;
      }
      

      【讨论】:

        【解决方案3】:

        这是我的答案,取决于递归矩阵乘法。 仅适用于 N = 2 ^ M,其中 M >= 2

        template <std::size_t size>
        int matrix_mul_recursive(int N, int i, int j, const int (&A)[size][size], const int (&B)[size][size], int (&C)[size][size]) {
            if (N == 1) {
                return *(const_cast<int*>(&(A[0][0])) + i) * (*(const_cast<int*>(&(B[0][0])) + j));
            }
            else {
                const int H = N / 2;
                const int T = (size * H);
        
                int r = i / size;
                int c = 0;
                if (j < size) {
                    c = j;
                }
                else {
                    c = j % size;
                }
        
                C[r][c] += matrix_mul_recursive<size>(H, i, j, A, B, C) + 
                    matrix_mul_recursive<size>(H, i + H, T + j, A, B, C);
                C[r][c + H] += matrix_mul_recursive<size>(H, i, j + H, A, B, C) +
                    matrix_mul_recursive<size>(H, i + H, T + j + H, A, B, C);
                C[r + H][c] += matrix_mul_recursive<size>(H, T + i, j, A, B, C) +
                    matrix_mul_recursive<size>(H, T + i + H, T + j, A, B, C);
                C[r + H][c + H] += matrix_mul_recursive<size>(H, T + i, j + H, A, B, C) +
                    matrix_mul_recursive<size>(H, T + i + H, T + j + H, A, B, C);
            }
            return 0;
        }
        

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2018-11-27
          • 1970-01-01
          • 2020-04-11
          • 1970-01-01
          • 1970-01-01
          • 2017-07-16
          • 1970-01-01
          • 2018-04-11
          相关资源
          最近更新 更多