【问题标题】:How to do arithmetic modulo another number, without overflow?如何对另一个数字进行算术模运算,而不会溢出?
【发布时间】:2017-08-28 11:39:14
【问题描述】:

我正在尝试为 Rust 的 u32u64 数据类型实现快速素性测试。作为其中的一部分,我需要计算(n*n)%d,其中nd 分别是u32(或u64)。

虽然结果很容易适合数据类型,但我不知道如何计算它。据我所知,没有用于此的处理器原语。

对于u32,我们可以伪造它——转换为u64,这样乘积不会溢出,然后取模,然后再转换回u32,知道这不会溢出。但是,由于我没有 u128 数据类型(据我所知),因此此技巧不适用于 u64

所以对于u64,我能想到的最明显的实现方法是以某种方式计算x*y 以获得u64 的一对(carry, product),因此我们捕获溢出量而不是仅仅丢失它(或恐慌,或其他)。

有没有办法做到这一点?还是另一种解决问题的标准方法?

【问题讨论】:

标签: rust integer-overflow integer-arithmetic modular-arithmetic


【解决方案1】:

Richard Rast pointed out 维基百科版本仅适用于 63 位整数。我扩展了 Boiethios 提供的代码以使用全范围的 64 位无符号整数。

fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
    let msb = 0x8000_0000_0000_0000;
    let mut d = 0;
    let mp2 = m >> 1;
    x %= m;
    y %= m;

    if m & msb == 0 {
        for _ in 0..64 {
            d = if d > mp2 {
                (d << 1) - m
            } else {
                d << 1
            };
            if x & msb != 0 {
                d += y;
            }
            if d >= m {
                d -= m;
            }
            x <<= 1;
        }
        d
    } else {
        for _ in 0..64 {
            d = if d > mp2 {
                d.wrapping_shl(1).wrapping_sub(m)
            } else {
                // the case d == m && x == 0 is taken care of 
                // after the end of the loop
                d << 1
            };
            if x & msb != 0 {
                let (mut d1, overflow) = d.overflowing_add(y);
                if overflow {
                    d1 = d1.wrapping_sub(m);
                }
                d = if d1 >= m { d1 - m } else { d1 };
            }
            x <<= 1;
        }
        if d >= m { d - m } else { d }
    }
}

#[test]
fn test_mul_mod64() {
    let half = 1 << 16;
    let max = std::u64::MAX;

    assert_eq!(mul_mod64(0, 0, 2), 0);
    assert_eq!(mul_mod64(1, 0, 2), 0);
    assert_eq!(mul_mod64(0, 1, 2), 0);
    assert_eq!(mul_mod64(1, 1, 2), 1);
    assert_eq!(mul_mod64(42, 1, 2), 0);
    assert_eq!(mul_mod64(1, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 42), 0);
    assert_eq!(mul_mod64(42, 42, 41), 1);
    assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);

    assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
    assert_eq!(mul_mod64(half, half, half), 0);
    assert_eq!(mul_mod64(half+1, half+1, half), 1);

    assert_eq!(mul_mod64(max, max, max), 0);
    assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
    assert_eq!(mul_mod64(1239876, max, max), 0);
    assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
    assert_eq!(mul_mod64(max, 2948635, max), 0);
    assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
    assert_eq!(mul_mod64(max-1, max-1, max), 1);
    assert_eq!(mul_mod64(2, max/2, max-1), 0);
}

【讨论】:

  • @mcarton half 的一半是什么?
  • @Shepmaster 现在我想了想,没什么:)
  • 很遗憾的说我没有经历过理解这段代码的工作:(但是我已经彻底测试过了,效果很好。谢谢!
【解决方案2】:

这是另一种方法(现在有 u128 数据类型):

fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    let (a, b, m) = (a as u128, b as u128, m as u128);
    ((a * b) % m) as u64
}

这种方法只是依赖于 LLVM 的 128 位整数运算。

我喜欢这个版本的一点是,它真的很容易让自己相信该解决方案对整个域都是正确的。由于abu64s,因此产品保证适合u128,并且由于mu64,所以保证最后的downcast 是安全的。

我不知道与其他方法相比性能如何,但如果它显着变慢,我会感到非常惊讶。如果您真的关心性能,那么无论如何您都需要运行一些基准测试并尝试一些替代方案。

【讨论】:

    【解决方案3】:

    使用简单的数学:

    (n*n)%d = (n%d)*(n%d)%d
    

    要查看这确实是真的,请设置n = k*d+r

    n*n%d = k**2*d**2+2*k*d*r+r**2 %d = r**2%d = (n%d)*(n%d)%d
    

    【讨论】:

    • 严格来说n * n % d = (n % d) * (n % d) % d
    • 如果d 很大,这仍然会溢出。
    • @interjay 具体来说,如果d &gt; 2^16会溢出
    【解决方案4】:

    red75prime added a useful comment。下面是计算两个相乘数的模的 Rust 代码,取自 Wikipedia:

    fn mul_mod(mut x: u64, mut y: u64, m: u64) -> u64 {
        let mut d = 0_u64;
        let mp2 = m >> 1;
        x %= m;
        y %= m;
    
        for _ in 0..64 {
            d = if d > mp2 {
                (d << 1) - m
            } else {
                d << 1
            };
            if x & 0x8000_0000_0000_0000_u64 != 0 {
                d += y;
            }
            if d > m {
                d -= m;
            }
            x <<= 1;
        }
        d
    }
    

    【讨论】:

    • @red75prime 如果你想发表你自己的答案,我删除我的。
    • 不,没关系。问题是这个算法不正确。我发现一个错误if d &gt; m ... 应该是if d &gt;= m ...。另一个导致 ` (d
    • 它似乎也给出了不正确的答案。 11552001120680995*15777587326414455(mod 18442563521290148565)应该是844062957336182220,算法给出13054753449364403936
    • 请注意,维基百科部分指出所有参数(xym)最多只能为 63 位(即 $
    猜你喜欢
    • 2021-01-12
    • 2011-05-15
    • 1970-01-01
    • 2023-04-08
    • 1970-01-01
    • 1970-01-01
    • 2011-05-11
    相关资源
    最近更新 更多