【问题标题】:Fast bignum square computation快速大数平方计算
【发布时间】:2013-08-30 04:39:30
【问题描述】:

为了加快我的 bignum divisons,我需要加快操作 y = x^2 的 bigints,它表示为无符号 DWORD 的动态数组。说清楚:

DWORD x[n+1] = { LSW, ......, MSW };
  • 其中 n+1 是使用的 DWORD 数
  • 所以数字x = x[0]+x[1]<<32 + ... x[N]<<32*(n)的值

问题是:如何在不损失精度的情况下尽可能快地计算y = x^2 - 使用 C++ 和整数运算(带进位的 32 位)。

我目前的方法是应用乘法 y = x*x 并避免多次乘法。

例如:

x = x[0] + x[1]<<32 + ... x[n]<<32*(n)

为简单起见,让我重写一下:

x = x0+ x1 + x2 + ... + xn

其中index代表数组内部的地址,所以:

y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)

y0     = x0*x0
y1     = x1*x0 + x0*x1
y2     = x2*x0 + x1*x1 + x0*x2
y3     = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n  ) + x(n-1)*x(n-1) + x(n  )*x(n-2)
y(2n-2) = xn(n-1)*x(n  ) + x(n  )*x(n-1)
y(2n-1) = xn(n  )*x(n  )

仔细一看,很明显几乎所有xi*xj都出现了两次(不是第一次也不是最后一次),这意味着N*N乘法可以被(N+1)*(N/2)乘法代替。附言32bit*32bit = 64bit 所以每个mul+add 操作的结果都被处理为64+1 bit

有没有更好的方法来快速计算?我在搜索过程中发现的只是 sqrts 算法,而不是 sqr...

快速平方

!!!请注意,我的代码中的所有数字都是首先是 MSW,... 不像上面的测试(为了简化方程式,首先是 LSW,否则会是索引混乱)。

当前的功能 fsqr 实现

void arbnum::sqr(const arbnum &x)
{
    // O((N+1)*N/2)
    arbnum c;
    DWORD h, l;
    int N, nx, nc, i, i0, i1, k;
    c._alloc(x.siz + x.siz + 1);
    nx = x.siz - 1;
    nc = c.siz - 1;
    N = nx + nx;
    for (i=0; i<=nc; i++)
        c.dat[i]=0;
    for (i=1; i<N; i++)
        for (i0=0; (i0<=nx) && (i0<=i); i0++)
        {
            i1 = i - i0;
            if (i0 >= i1)
                break;
            if (i1 > nx)
                continue;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            l = x.dat[nx-i1];
            if (!l)
                continue;
            alu.mul(h, l, h, l);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k], l);
            k--;
            if (k>=0)
                alu.adc(c.dat[k], c.dat[k],h);
            k--;
            for (; (alu.cy) && (k>=0); k--)
                alu.inc(c.dat[k]);
        }
        c.shl(1);
        for (i = 0; i <= N; i += 2)
        {
            i0 = i>>1;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            alu.mul(h, l, h, h);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k],l);
            k--;
            if (k>=0)
                alu.adc(c.dat[k], c.dat[k], h);
            k--;
            for (; (alu.cy) && (k >= 0); k--)
                alu.inc(c.dat[k]);
        }
        c.bits = c.siz<<5;
        c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
        c.sig = sig;
        *this = c;
    }

使用 Karatsuba 乘法

(感谢卡尔皮斯)

我实现了 Karatsuba 乘法,但结果比使用简单的 O(N^2) 乘法要慢得多,这可能是因为我看不到任何方法可以避免的可怕递归。它的权衡必须是非常大的数字(大于数百位数)......但即便如此,也有很多内存传输。有没有办法避免递归调用(非递归变体,......几乎所有递归算法都可以这样做)。尽管如此,我还是会尝试调整一下,看看会发生什么(避免规范化等......,这也可能是代码中的一些愚蠢的错误)。无论如何,在为 case x*x 解决 Karatsuba 之后,性能并没有太大提升。

优化的 Karatsuba 乘法

y = x^2 looped 1000x times, 0.9 &lt; x &lt; 1 ~ 32*98 bits 的性能测试:

x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication

x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]

x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]

对 Karatsuba 进行优化后,代码比以前快了很多。尽管如此,对于较小的数字,它仍略低于我的 O(N^2) 乘法的一半速度。对于更大的数字,布斯乘法的复杂性给出的比率会更快。乘法的阈值约为 32*98 位,sqr 的阈值约为 32*389 位,因此如果输入位的总和超过此阈值,则将使用 Karatsuba 乘法来加速乘法,sqr 也是如此。

顺便说一句,包括优化:

  • 通过太大的递归参数来减少堆垃圾
  • 避免使用任何带进位的 bignum 算术 (+,-) 32 位 ALU。
  • 忽略 0*yx*00*0 情况
  • 将输入 x,y 数字大小重新格式化为 2 的幂以避免重新分配
  • z1 = (x0 + x1)*(y0 + y1) 实现模乘以最小化递归

将 Schönhage-Strassen 乘法修改为 sqr 实现

我已经测试了使用 FFTNTT 变换来加速 sqr 计算。结果如下:

  1. FFT

    失去准确性,因此需要高精度的复数。这实际上大大减慢了速度,因此不存在加速。结果不精确(可能舍入错误),因此 FFT 不可用(暂时)

  2. NTT

    NTT 是有限域DFT,因此不会发生精度损失。它需要对无符号整数进行模运算:modpow, modmul, modaddmodsub

    我使用DWORD(32 位无符号整数)。 NTT 输入/输出向量大小因溢出问题而受到限制!!!对于 32 位模运算,N 被限制为 (2^32)/(max(input[])^2) 所以bigint 必须被划分成更小的块(我使用BYTES 所以处理的bigint 的最大大小是

    (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)
    

    sqr 仅使用 1xNTT + 1xINTT 而不是 2xNTT + 1xINTT 进行乘法运算,但 NTT 使用速度太慢,而且阈值大小太大,无法在我的实现中实际使用(对于 @987654360 @ 和 sqr)。

    有可能甚至超过了溢出限制,因此应该使用 64 位模运算,这会进一步减慢速度。所以 NTT 对我来说也无法使用。

一些测量结果:

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

我的实现:

void arbnum::sqr_NTT(const arbnum &x)
{
    // O(N*log(N)*(log(log(N)))) - 1x NTT
    // Schönhage-Strassen sqr
    // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
    int i, j, k, n;
    int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
    i = x.siz;
    for (n = 1; n < i; n<<=1)
        ;
    if (n + n > 0x3000) {
        _error(_arbnum_error_TooBigNumber);
        zero();
        return;
    }
    n <<= 3;
    DWORD *xx, *yy, q, qq;
    xx = new DWORD[n+n];
    #ifdef _mmap_h
    if (xx)
        mmap_new(xx, (n+n) << 2);
    #endif
    if (xx==NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        zero();
        return;
    }
    yy = xx + n;

    // Zero padding (and split DWORDs to BYTEs)
    for (i--, k=0; i >= 0; i--)
    {
        q = x.dat[i];
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++;
    }
    for (;k<n;k++)
        xx[k] = 0;

    //NTT
    fourier_NTT ntt;

    ntt.NTT(yy,xx,n);    // init NTT for n

    // Convolution
    for (i=0; i<n; i++)
        yy[i] = modmul(yy[i], yy[i], ntt.p);

    //INTT
    ntt.INTT(xx, yy);

    //suma
    q=0;
    for (i = 0, j = 0; i<n; i++) {
        qq = xx[i];
        q += qq&0xFF;
        yy[n-i-1] = q&0xFF;
        q>>=8;
        qq>>=8;
        q+=qq;
    }

    // Merge WORDs to DWORDs and copy them to result
    _alloc(n>>2);
    for (i = 0, j = 0; i<siz; i++)
    {
        q  =(yy[j]<<24)&0xFF000000; j++;
        q |=(yy[j]<<16)&0x00FF0000; j++;
        q |=(yy[j]<< 8)&0x0000FF00; j++;
        q |=(yy[j]    )&0x000000FF; j++;
        dat[i] = q;
    }

    #ifdef _mmap_h
    if (xx)
        mmap_del(xx);
    #endif
    delete xx;
    bits = siz<<5;
    sig = s;
    exp = exp0 + (siz<<5) - 1;
        // _normalize();
    }

结论

对于较小的数字,这是我快速sqr 方法的最佳选择,之后 阈值 Karatsuba 乘法更好。但我仍然认为应该有一些我们忽略的微不足道的东西。还有其他想法吗?

NTT 优化

经过大规模优化(主要是 NTT):堆栈溢出问题Modular arithmetics and NTT (finite field DFT) optimizations

一些值发生了变化:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul

所以现在 NTT 乘法在大约 1500*32 位阈值之后终于快于 Karatsuba

一些测量结果和发现的错误

a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[  58.656 ms ] fast sqr
sqr2[  13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[  28.916 ms ] Karatsuba mul Error
mul3[  19.470 ms ] NTT mul

我发现我的 Karatsuba(上/下)流过 bignum 的每个 DWORD 段的 LSB。等我研究好了再更新代码……

此外,在进一步 NTT 优化之后,阈值发生了变化,因此对于 NTT sqr,它是 操作数310*32 bits = 9920 bits,对于 NTT mul 它是 result1396*32 bits = 44672 bits(操作数位的总和)。

感谢@greybeard 修复了 Karatsuba 代码

//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
    // Recursion for Karatsuba
    // z[2n] = x[n]*y[n];
    // n=2^m
    int i;
    for (i=0; i<n; i++)
        if (x[i]) {
            i=-1;
            break;
        } // x==0 ?

    if (i < 0)
        for (i = 0; i<n; i++)
            if (y[i]) {
                i = -1;
                break;
            } // y==0 ?

    if (i >= 0) {
        for (i = 0; i < n + n; i++)
            z[i]=0;
            return;
        } // 0.? = 0

    if (n == 1) {
        alu.mul(z[0], z[1], x[0], y[0]);
        return;
    }

    if (n< 1)
        return;
    int n2 = n>>1;
    _mul_karatsuba(z+n, x+n2, y+n2, n2);                         // z0 = x0.y0
    _mul_karatsuba(z  , x   , y   , n2);                         // z2 = x1.y1
    DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
    BYTE cx,cy;
    if (q == NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        return;
    }
    #define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
    #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
    qq = q;
    q0 = x + n2;
    q1 = x;
    i = n2 - 1;
    _add;
    cx = alu.cy; // =x0+x1

    qq = q + n2;
    q0 = y + n2;
    q1 = y;
    i = n2 - 1;
    _add;
    cy = alu.cy; // =y0+y1

    _mul_karatsuba(q + n, q + n2, q, n2);                       // =(x0+x1)(y0+y1) mod ((2^N)-1)

    if (cx) {
        qq = q + n;
        q0 = qq;
        q1 = q + n2;
        i = n2 - 1;
        _add;
        cx = alu.cy;
    }// += cx*(y0 + y1) << n2

    if (cy) {
        qq = q + n;
        q0 = qq;
        q1 = q;
        i = n2 -1;
        _add;
        cy = alu.cy;
    }// +=cy*(x0+x1)<<n2

    qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub;  // -=z0
    qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub;  // -=z2
    qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add;  // z1=(x0+x1)(y0+y1)-z0-z2

    DWORD ccc=0;

    if (alu.cy)
        ccc++;    // Handle carry from last operation
    if (cx || cy)
        ccc++;    // Handle carry from before last operation
    if (ccc)
    {
        i = n2 - 1;
        alu.add(z[i], z[i], ccc);
        for (i--; i>=0; i--)
            if (alu.cy)
                alu.inc(z[i]);
            else
                break;
    }

    delete[] q;
    #undef _add
    #undef _sub
    }

//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
    // O(3*(N)^log2(3)) ~ O(3*(N^1.585))
    // Karatsuba multiplication
    //
    int s = x.sig*y.sig;
    arbnum a, b;
    a = x;
    b = y;
    a.sig = +1;
    b.sig = +1;
    int i, n;
    for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
        ;
    a._realloc(n);
    b._realloc(n);
    _alloc(n + n);
    for (i=0; i < siz; i++)
        dat[i]=0;
    _mul_karatsuba(dat, a.dat, b.dat, n);
    bits = siz << 5;
    sig = s;
    exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
    //    _normalize();
    }
//---------------------------------------------------------------------------

我的arbnum 号码表示:

// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
  • dat[siz] 是尾数。 LSDW 表示最不重要的 DWORD。
  • expdat[0] 的 MSB 的指数
  • 尾数中存在第一个非零位!!!

    // |-----|---------------------------|---------------|------|
    // | sig | MSB      mantisa      LSB |   exponent    | bits |
    // |-----|---------------------------|---------------|------|
    // | +1  | 0.(0      ...          0) | 2^0           |   0  | +zero
    // | -1  | 0.(0      ...          0) | 2^0           |   0  | -zero
    // |-----|---------------------------|---------------|------|
    // | +1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | +number
    // | -1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | -number
    // |-----|---------------------------|---------------|------|
    // | +1  | 1.0                       | 2^+0x7FFFFFFE |   1  | +infinity
    // | -1  | 1.0                       | 2^+0x7FFFFFFE |   1  | -infinity
    // |-----|---------------------------|---------------|------|
    

【问题讨论】:

  • 我的问题是您为什么决定实现自己的 bignum 实现? The GNU Multiple Precision Arithmetic Library 可能是最常用的 bignum 库之一,它的所有操作都应该是最佳的。
  • 出于兼容性原因,我使用自己的 bignum 库。将所有代码移植到不同的库比乍看之下更耗时(有时甚至因为编译器不兼容,特别是与 gcc 代码不兼容)。我目前只是在调整一些东西,......所有运行都按原样运行,但总是需要更快的速度:)
  • P.S.对于 NTT 使用,我强烈建议 NTT 的计算精度比输入值高 4 倍(因此对于 8 位数字,您需要将它们转换为 32 位数字)以在最大数组大小和速度之间取得折衷

标签: c++ algorithm multiplication bignum sqr


【解决方案1】:

如果您想编写一个新的更好的指数,您可能必须在汇编中编写它。这是来自 golang 的代码。

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s

【讨论】:

  • 我主要使用 borland C++ 编译器...汇编函数比普通 C++ 实现慢(不支持为什么,也许是状态推送/弹出),我只搜索 x^2,对于 pow i我正在使用 exp2,log2 但无论如何都不错的 asm 代码
【解决方案2】:

如果我正确理解您的算法,似乎O(n^2) 其中n 是位数。

你看过Karatsuba Algorithm吗? 它使用分而治之的方法加速乘法。可能值得一看。

【讨论】:

  • 很好,这大大加快了速度...因为 x=y ...在编码之前很难假设复杂性。
  • 另一方面,为 x*x 解决 karatsuba 的结果与我的方法相同 :( 如果更多递归方法更好,我会尝试...我的复杂性现在来自 O(n^ 2) ~O(0.5*N^2) 但根据该页面应该更低
  • 好的,我已经检查了 karatsuba 算法。加速乘法很好,但 x^2 仅适用于非常大的数字。我认为应该有一些比一般乘法更简单且速度更快的方法。
  • 我成功地测试了 Schönhage–Strassen 乘法,FFT 存在舍入问题,并且由于复数而有点慢,我是 NTT 新手,但现在我必须让它工作使用 Schönhage–Strassen 为 sqr 实现快速 NTT 并选择字长(目前仅在十进制字符串上测试)应该比乘法快 1/3(删除 1xNTT)所以我很好奇我的实现完成后会有多快。
  • 我也测试过NTT,但结果并不好。所以我的快速 sqr 和你的 karatsuba 赢得了比赛。我接受了你的回答。
猜你喜欢
  • 2017-09-07
  • 2016-02-21
  • 1970-01-01
  • 1970-01-01
  • 2016-05-19
  • 1970-01-01
  • 1970-01-01
  • 2015-01-15
  • 2010-11-01
相关资源
最近更新 更多