【问题标题】:Efficient way to multiply a large set of small numbers乘以大量小数的有效方法
【发布时间】:2014-04-15 17:50:12
【问题描述】:

这个问题是在一次采访中被问到的。

你有一个小整数数组。你必须将它们全部相乘。您不必担心溢出,您对此有充分的支持。你可以做些什么来加快你机器上的乘法速度?

在这种情况下多次添加会更好吗?

我建议使用分而治之的方法进行乘法,但面试官并没有留下深刻的印象。最好的解决方案是什么?

【问题讨论】:

  • 由于集合很大,整数很小,所以应该有很多重复。您可以对数字进行排序,然后对重复项使用二进制取幂,然后将结果相乘。
  • @EvgenyKluev 这听起来是个糟糕的主意——这怎么可能比单纯地把它们相乘更快?
  • @A.Webb 那又怎样?您仍然需要对整个数组进行排序,而不是仅仅将每个项目相乘。
  • 对于大量的小整数,排序并不可怕。只计算排序。
  • 使用计数排序,您可以在每个槽同时建立二进制求幂片段。

标签: algorithm multiplication


【解决方案1】:

以下是一些想法:

多线程分治法:将输入分成 n 个大小为 b 的不同块,并递归地将每个块中的所有数字相乘。然后,递归地将所有 n / b 块相乘。如果您有多个内核并且可以并行运行其中的一部分,则总体上可以节省大量时间。

字级并行:假设您的数字都从上方以某个数字 U 为界,该数字恰好是 2 的幂。现在,假设您想将 a、b、c 和 d 相乘。首先计算 (4U2a + b) × (4U2c + d) = 16U4ac + 4U2 ad + 4U2bc + bd。现在,请注意这个表达式 mod U2 只是 bd。 (由于 bd 2,我们不需要担心 mod U2 步骤搞砸了)。这意味着如果我们计算这个产品并取它 mod U2,我们会得到 bd。由于 U2 是 2 的幂,因此可以使用位掩码来完成。

接下来,请注意

4U2ad + 4U2bc + bd 4 + 4U4 + U24 4

这意味着如果我们将整个表达式除以 16U4 并向下取整,我们最终会得到 ad。这种除法可以通过位移来完成,因为 16U4 是 2 的幂。

因此,通过一次乘法,您可以通过应用后续的位移位和位掩码来取回 ac 和 bd 的值。一旦有了 ac 和 bd,就可以直接将它们相乘,得到 abcd 的值。假设位掩码和位移位比乘法更快,这将必要的乘法次数减少了 33%(这里是两个而不是三个)。

希望这会有所帮助!

【讨论】:

  • 这个词级并行性是否有好的来源?我不太明白除以 U2 并向下取整部分。
  • @Sohaib 我已经更新了我的答案(看起来数学中有一个初始错误)。希望这会有所帮助!
  • 如果我们除以16U^4 会不会得到0?我在这里有点困惑。即使我们右移 n 位考虑 16U^4=2^n 为什么整个表达式不会变成 0?
  • 如果将 16U^4 ac + 4U^2 ac + 4U^2 bc + bd 除以 16U^4,则三个小项都归零。在第一项中,16U^4 项被除掉,留下 ac(因为 16U^4 ac / 16U^4 = ac)
  • 好的,我明白了!谢谢!
【解决方案2】:

您的分而治之的建议是一个好的开始。只是需要更多解释才能给人留下深刻印象。

使用用于乘法大数(大整数)的快速乘法算法,相乘大小相似的被乘数比一系列大小不匹配的被乘数要高效得多。

这是 Clojure 中的一个示例

; Define a vector of 100K random integers between 2 and 100 inclusive
(def xs (vec (repeatedly 100000 #(+ 2 (rand-int 99)))))

; Naive multiplication accumulating linearly through the array
(time (def a1 (apply *' xs)))
"Elapsed time: 7116.960557 msecs"

; Simple Divide and conquer algorithm
(defn d-c [v] 
  (let [m (quot (count v) 2)] 
    (if (< m 3) 
      (reduce *' v)
      (*' (d-c (subvec v 0 m)) (d-c (subvec v m))))))

; Takes less than 1/10th the time.
(time (def a2 (d-c xs)))
"Elapsed time: 600.934268 msecs"

(= a1 a2) ;=> true (same result)

请注意,这种改进并不依赖于对数组中整数大小的设定限制(任意选择 100 个并演示下一个算法),而仅依赖于它们的大小相似。这是一个非常简单的分治法。随着数字变得越来越大,乘起来的成本越来越高,花更多的时间按相似的大小迭代地对它们进行分组是有意义的。在这里,我依赖于随机分布和大小保持相似的可能性,但即使在最坏的情况下,它仍然比简单的方法要好得多。

正如 Evgeny Kluev 在 cmets 中所建议的那样,对于 largesmall 整数,将会有很多重复,因此有效的取幂也更好比天真的乘法。这更多地取决于相对参数而不是分而治之,也就是说,数字必须相对于计数足够小,才能累积足够多的重复项才能打扰,但这些参数肯定表现良好(100K 范围内的数字 2- 100)。

; Hopefully an efficient implementation
(defn pow [x n] (.pow (biginteger x) ^Integer n))

; Perform pow on duplications based on frequencies
(defn exp-reduce [v] (reduce *' (map (partial apply pow) (frequencies v))))

(time (def a3 (exp-reduce xs)))
"Elapsed time: 650.211789 msecs"

请注意,非常简单的分而治之在本次试验中表现稍好一点,但如果预期的重复次数更少,效果会相对更好。

当然我们也可以将两者结合起来:

(defn exp-d-c [v] (d-c (mapv (partial apply pow) (frequencies v))))

(time (def a4 (exp-d-c xs)))
"Elapsed time: 494.394641 msecs"

(= a1 a2 a3 a4) ;=> true (all equal)

请注意,有更好的方法可以将这两者结合起来,因为求幂步骤的结果会产生各种大小的被乘数。这样做所增加的复杂性的价值取决于输入中不同数字的预期数量。在这种情况下,可区分的数字很少,因此不会增加太多复杂性。

另请注意,如果有多个内核可用,这两者很容易并行化。

【讨论】:

    【解决方案3】:

    如果许多小整数多次出现,您可以从计算每个唯一整数开始。如果c(n)是整数n的出现次数,则乘积可以计算为

    P = 2 ^ c(2) * 3 ^ c(3) * 4 ^ c(4) * ...
    

    对于取幂步骤,您可以使用平方取幂,这可以大大减少乘法的次数。

    【讨论】:

    • 与 Evgeny Kluev 建议的相同。好办法。但是,templatetypedef 的答案并没有真正留下任何机会。
    【解决方案4】:

    如果与范围相比,数字的计数确实很大,那么我们已经看到提出了两种渐近解决方案来显着降低复杂性。一个是基于连续平方来计算每个数字 c 在 O(log k) 时间内的 c^k,如果最大数字是 C,则给出 O(C mean(log k)) 时间,并且 k 给出 1 之间每个数字的指数和 C。如果每个数字出现相同的次数,则平均值(log k)项最大化,所以如果你有 N 个数字,那么复杂度变为 O(C log(N/C)),它非常弱地依赖于 N并且本质上只是 O(C),其中 C 指定了数字的范围。

    我们看到的另一种方法是按数字出现的次数对数字进行排序,并跟踪前导数字的乘积(从所有数字开始)并将其提高到一个幂,以便从最不频繁的数字中删除数组,然后更新数组中剩余元素的指数并重复。如果所有数字出现相同的次数 K,那么这给出了 O(C + log K),这是对 O(C log K) 的改进。但是,假设第 k 个数字出现 2^k 次。如果 C > log(N/C),这仍然会给出 O(C^2 + C log(N/C)) 时间,这在技术上比之前的方法 O(C log(N/C)) 更差。因此,如果您没有关于每个数字出现的均匀分布的良好信息,您应该采用第一种方法,只需通过使用连续平方来取乘积中出现的每个不同数字的适当幂,然后取结果的产物。如果有 C 个不同的数字和 N 个总数,总时间 O(C log (N/C))。

    【讨论】:

      【解决方案5】:

      要回答这个问题,我们需要以某种方式解释 OP 的假设:need not worry about overflow。在这个答案的大部分中,它被解释为“忽略溢出”。但我从一些关于其他解释的想法开始(“使用多精度算术”)。在这种情况下,乘法过程可能大致分为 3 个阶段:

      1. 将一组小的小数相乘,得到一大组不那么小的数。此处可能会使用此答案第二部分中的一些想法。
      2. 将这些数字相乘得到一组大数字。可以使用平凡(二次时间)算法或Toom–Cook/Karatsuba(次二次时间)方法。
      3. 将大数相乘。可以使用Fürer's 或 Schönhage–Strassen 算法。这使得整个过程的时间复杂度为 O(N polylog N)。

      二进制取幂可能会带来一些(不是很显着)的速度提升,因为这里提到的大多数(如果不是全部)复杂乘法算法的平方比两个不等数的乘法要快。我们也可以分解每个“小”数并仅对素因数使用二进制取幂。对于均匀分布的“小”数,这将减少乘方次数log(number_of_values),并略微改善平方/乘法的平衡。

      当数字均匀分布时,分而治之是可以的。否则(例如,当输入数组被排序或使用二进制求幂时)我们可以通过将所有被乘数放入优先级队列中做得更好,按数字长度排序(可能近似排序)。然后我们可以将两个最短的数相乘并将结果放回队列中(这个过程非常类似于霍夫曼编码)。无需使用此队列进行平方。此外,我们不应该在数字不够长时使用它。

      更多信息请访问the answer by A. Webb


      如果可以忽略溢出,我们可以将数字与线性时间或更好的算法相乘。

      如果输入数组已排序或输入数据以元组集合 {值,出现次数} 的形式呈现,则亚线性时间算法是可能的。在后一种情况下,我们可以对每个值执行二进制取幂并将结果相乘。时间复杂度为 O(C log(N/C)),其中C 是数组中不同值的数量。 (另见this answer)。

      如果输入数组已排序,我们可以使用二进制搜索来查找值发生变化的位置。这允许确定每个值在数组中出现的次数。然后我们可以对每个值执行二进制取幂并将结果相乘。时间复杂度为 O(C log N)。我们可以在这里使用单向二分搜索做得更好。在这种情况下,时间复杂度为 O(C log(N/C))。

      但是如果输入数组没有排序,我们必须检查每个元素,所以 O(N) 时间复杂度是我们能做的最好的。我们仍然可以使用并行(多线程、SIMD、字级并行)来提高速度。在这里我比较了几种这样的方法。

      为了比较这些方法,我选择了非常小的(3 位)值,这些值非常紧凑(每个 8 位整数一个值)。并以低级语言 (C++11) 实现它们,以便更轻松地访问位操作、特定 CPU 指令和 SIMD。

      以下是所有算法:

      1. accumulate 来自标准库。
      2. 使用 4 个累加器的简单实现。
      3. 乘法的字级并行性,如the answer by templatetypedef 中所述。对于 64 位字长,这种方法最多允许 10 位值(只有 3 次乘法而不是每次 4 次),或者它可以应用两次(我在测试中这样做)最多 5 位值(只需要每 8 次乘法中的 5 次)。
      4. 表查找。在测试中,每 8 个乘法中有 7 个被单表查找代替。如果值大于这些测试中的值,替代乘法的数量会减少,从而减慢算法速度。大于 11-12 位的值使这种方法毫无用处。
      5. 二进制取幂(请参阅下面的详细信息)。大于 4 位的值使这种方法无用。
      6. SIMD (AVX2)。此实现最多可以使用 8 位值。

      Here are sources for all tests on Ideone。请注意,SIMD 测试需要英特尔的 AVX2 指令集。查表测试需要 BMI2 指令集。其他测试不依赖于任何特定的硬件(我希望如此)。我在 64 位 Linux 上运行这些测试,使用 gcc 4.8.1 编译,优化级别 -O2

      以下是二进制取幂测试的更多细节:

          for (size_t i = 0; i < size / 8; i += 2)
          {
              auto compr = (qwords[i] << 4) | qwords[i + 1];
              constexpr uint64_t lsb = 0x1111111111111111;
              if ((compr & lsb) != lsb) // if there is at least one even value
              {
                  auto b = reinterpret_cast<uint8_t*>(qwords + i);
                  acc1 *= accumulate(b, b + 16, acc1, multiplies<unsigned>{});
                  if (!acc1)
                      break;
              }
              else
              {
                  const auto b2 = compr & 0x2222222222222222;
                  const auto b4 = compr & 0x4444444444444444;
                  const auto b24 = b4 & (b2 * 2);
                  const unsigned c7 = __builtin_popcountll(b24);
                  acc3 += __builtin_popcountll(b2) - c7;
                  acc5 += __builtin_popcountll(b4) - c7;
                  acc7 += c7;
              }
          }
          const auto prod4 = acc1 * bexp<3>(acc3) * bexp<5>(acc5) * bexp<7>(acc7);
      

      此代码比输入数组更密集地打包值:每个字节两个值。每个值的低位处理方式不同:由于我们可以在此处找到 32 个零位(结果为“零”)后停止,因此这种情况不会对性能产生太大影响,因此通过最简单的(库)算法处理。

      在剩余的 4 个值中,“1”没有意义,因此我们只需要计算“3”、“5”和“7”的出现次数,并使用按位操作和固有的“人口计数”。

      结果如下:

        source size:    4 Mb:         400 Mb:
      1. accumulate: 0.835392 ns    0.849199 ns
      2.  accum * 4: 0.290373 ns    0.286915 ns
      3. 2 mul in 1: 0.178556 ns    0.182606 ns
      4. mult table: 0.130707 ns    0.176102 ns
      5. binary exp: 0.128484 ns    0.119241 ns
      6.       AVX2: 0.0607049 ns   0.0683234 ns
      

      在这里我们可以看到accumulate 库算法非常慢:由于某种原因,gcc 无法展开循环并使用 4 个独立的累加器。

      “手动”进行这种优化并不难。结果不是特别快。但如果我们为此任务分配 4 个线程,CPU 将大致匹配内存带宽(2 个通道,DDR3-1600)。

      乘法的字级并行性几乎快两倍。 (我们只需要 3 个线程来匹配内存带宽)。

      查表速度更快。但是当输入数组无法放入 L3 缓存时,它的性能会下降。 (我们需要 3 个线程来匹配内存带宽)。

      二进制取幂的速度大致相同。但是对于较大的输入,这种性能不会降低,甚至会略有提高,因为与值计数相比,求幂本身使用的时间更少。 (我们只需要 2 个线程来匹配内存带宽)。

      正如预期的那样,SIMD 是最快的。当输入数组无法放入 L3 缓存时,其性能会略有下降。这意味着我们接近单线程的内存带宽。

      【讨论】:

        【解决方案6】:

        我有一个解决方案。让我们与其他解决方案一起讨论。

        问题的关键部分是如何减少乘法次数。整数很小,但集合很大。

        我的解决方案:

        • 使用一个小数组来记录每个数字出现的次数。
        • 从数组中删除数字 1。你不需要计算它。
        • 找出出现次数最少的数 n。然后将所有数字相乘,得到结果K。然后数K^n。
        • 删除这个数字(例如,您可以将其与数组的最后一个数字切换,并将数组的大小减小为 1)。所以下次你不会再考虑这个数字了。同时,其他号码的出现次数也需要随着号码被移除的次数而减少。
        • 再次获取出现次数最少的数字。执行与第 2 步相同的操作。
        • 重复执行步骤 2-4 并完成计数。

        让我用一个例子来说明我们需要做多少次乘法:假设 我们有 5 个数字 [1, 2, 3, 4, 5]。数字 1 出现 100 次,数字 2 出现 150 次,数字 3 出现 200 次,数字 4 出现 300 次 次,数字 5 出现 400 次。

        方法一:直接相乘或分治法 我们需要 100+150+200+300+400-1 = 1149 乘以得到结果。

        方法 2:我们做 (1^100)(2^150)(3^200)(4^300)(5^400) (100-1)+(150-1)+(200-1)+(300-1)+(400-1)+4 = 1149.[同方法1] 因为 n^m 实际上会做 m-1 乘法。另外,您需要时间来查看所有数字,尽管时间很短。

        本帖中的方法: 首先,您需要时间来处理所有数字。与乘法时间相比,它可以被丢弃。

        您进行的真正计数是: ((2*3*4*5)^150)*((3*4*5)^50)*((4*5)^100)*(5^100)

        然后你需要乘以 3+149+2+49+1+99+99+3 = 405 次

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 2023-03-07
          • 2014-09-23
          • 2022-08-21
          • 2015-02-17
          • 1970-01-01
          • 1970-01-01
          相关资源
          最近更新 更多