【问题标题】:finding the count of number of sub arrays of size K whose sum is divisible by M?查找大小为 K 且总和可被 M 整除的子数组的数量?
【发布时间】:2020-06-26 17:04:23
【问题描述】:

Alice 有 N 个硬币,数量从 0 到 (N-1)。 Bob 想从中取出 k 个硬币,但 Alice 只会在这组 K 个硬币有趣的情况下给。

如果一组硬币的总和能被一个唯一整数 M 整除,那么它们就很有趣。现在 Bob 想知道他可以通过多少种方式获得 K 个硬币。

按 answer%(10^9+7) 打印结果

输入格式:- 三个空格分隔的整数 N,K,M

约束:-

  • 1
  • 1

样本输入:- 4 2 2

示例输出:- 2({1,3},{2,4})

我尝试使用 python 库中的组合来解决问题,但结果是 超出了内存限制。 后来我对它使用了递归方法,但它也导致超过时间限制。因为每个私人测试用例需要10秒的时间。

任何人都可以帮助解决这种有效的方式吗?

递归方法代码:

cou=0
n,k,m = input().split()
out_ = solve(k,m,n)
print(out_)


def getCombinations(data,n,k,m):
    val=[0]*k
    combiUtil(data,n,k,0,val,0,m)

def combiUtil(data,n,k,ind,val,i,m):
    global cou
    if(ind==k):
        if(sum(val)%m==0):
            cou+=1
        return
    if(i>=n):
        return
    val[ind]=data[i]
    combiUtil(data,n,k,ind+1,val,i+1,m)
    combiUtil(data,n,k,ind,val,i+1,m)

def solve(k,m,n):
    global cou
    k=int(k)
    m=int(m)
    n=int(n)
    ans =0
    data = [j for j in range(1,n+1)]
    getCombinations(data,n,k,m)   
    return cou%(10**9+7)

【问题讨论】:

  • 这看起来更像是一个滑动窗口问题,因为它说 大小为 k 的子数组
  • 你的问题陈述说数组是[0 to N-1],但代码暗示[1 to N],对吗?
  • 当前的解决方案都是指数级的,你要求 小于nmk - 所以O(nmk) 对你来说是不可行的?
  • 这是哪里来的?是否在网上某个地方进行测试?

标签: python arrays list algorithm data-structures


【解决方案1】:

如果您查看“尝试所有组合”蛮力解决方案的时间复杂度,它等于 O((N choose K) * K) = O(K * N^K),因为有 N choose K 方法可以从 1 to N 中选择 K 不同的整数,并且评估他们的总和需要K-1 加法。除了 NK 的微小值之外,这是天文数字。

更好的解决方案:动态规划

动态编程是一种更快、更简单的解决方案。我们可以把它写成一个 3D 动态规划问题:

Let dp[i][j][r], 0 <= i <= N; 0 <= j <= K; 0 <= r < M
be the number of combinations of j ints from [1, 2, ..., i] 
with sum congruent to r modulo M. We want dp[N][K][0]

dp[i][j][r] = 1 if i == j == r == 0
              0 if i == 0 and (j /= 0 or r /= 0)
              1 if j == 0 and r == 0
              0 if j == 0 and r /= 0
              dp[i-1][j][r] + dp[i-1][j-1][(r-i) % M] otherwise

公式中添加了很多边缘情况,但最重要的方面是最后一种情况:我们的动态规划子问题最多依赖于 2 个其他子问题,因此总运行时间是我们的 DP 数组的大小:@987654330 @。这是一个 Python 实现:

def get_combinations_dp(max_part_size: int, total_num_parts: int, mod: int) -> int:
    BIG_MOD = 10 ** 9 + 7

    # Optimization if no partitions exist
    if total_num_parts > max_part_size:
        return 0

    largest_sum = ((max_part_size * (max_part_size + 1) // 2)
                   - ((max_part_size - total_num_parts) *
                      (max_part_size - total_num_parts + 1) // 2))
    # Optimization if largest partition sum still smaller than mod
    if largest_sum < mod:
        return 0

    dp = [[0 for _ in range(mod)] for _ in range(total_num_parts + 1)]
    dp[0][0] = 1

    for curr_max_part in range(1, max_part_size + 1):
        for curr_num_parts in reversed(range(0, total_num_parts)):
            for rem in range(mod):
                dp[curr_num_parts + 1][(rem + curr_max_part) % mod] += dp[curr_num_parts][rem]
                dp[curr_num_parts + 1][(rem + curr_max_part) % mod] %= BIG_MOD
    return dp[total_num_parts][0]

参数为N, K, M,重命名为max_part_size, total_num_parts, mod,如果没有分区,则进行一些可选的预检查以立即返回0

现在,假设您想比O(nmk) 做得更好。在这里,事情变得棘手。如果你想做得更好,我能想象的唯一方法是找到这些分区的生成函数,并使用 FFT 或其他快速多项式乘法模10**9 + 7。为了开始研究如何做到这一点,我建议在数学 stackexchange 上使用this thread,它涉及根据更知名的分区号精确计算这些分区,其生成函数是已知的。即便如此,我也找不到任何关于这个生成函数是否具有简短表示的信息,并且直接使用分区号并不能提高 O(nmk) 的复杂性。

使用组合数学

如果您仍想使用这种动态编程方法,可以使用组合数学进行小修改,当N 大于M*K 时可能会渐近更快:它运行时间O((M*K)^2),不依赖于N。我们的想法是使用我们的 DP 公式,但不是从 [1, ... N] 中选择 K 个不同的整数,而是从 [0, ... M-1] 中选择 K 个(可能是重复的)残基类别。

这是如何工作的?首先,我们需要计算 [1, ... N] 中有多少 int 属于每个残基类 i mod M。拨打此号码R[i],拨打0 &lt;= i &lt; M。你可以这样计算

R[i] = floor(N/M) + (1 if 0 < i <= N%M else 0)

现在我们可以写一个稍微修改的动态规划定义和公式:

Let dp[i][j][r], 0 <= i <= M; 0 <= j <= K; 0 <= r < M
be the number of combinations with replacement of j ints from 
residue classes [0, 1, ... i-1] with sum congruent to r modulo M. 
We want dp[M][K][0]:

dp[i][j][r] = 1 if i == j == r == 0
              0 if i == 0 and (j /= 0 or r /= 0)
              0 if i < 0 or j < 0
              F(i, j, r) otherwise

F(i, j, r) = Sum from p = 0 to min(R[i], j) of:
(R[i] choose p) * dp[i-1][j-p][(r - i*p) % M]

【讨论】:

    【解决方案2】:

    @kcsquared 解决方案的三个 NumPy 版本,可以在 10 秒的时间限制内轻松解决最坏的情况:

    def numpy1(n, k, m):
        dp = np.zeros((k+1, m), np.int32)
        dp[0][0] = 1
        for i in range(1, n+1):
            dp[1:,] += dp[:-1, (np.arange(m) + i) % m]
            dp %= 10**9 + 7
        return dp[k][0]
    
    def numpy2(n, k, m):
        dp = np.zeros((k+1, m), np.int32)
        dp[0][0] = 1
        i = range(m)
        for _ in range(n):
            i = np.roll(i, 1)
            dp[1:,] += dp[:-1, i]
            dp %= 10**9 + 7
        return dp[k][0]
    
    def numpy3(n, k, m):
        dp = np.zeros((k+1, m), np.int32)
        dp[0][0] = 1
        for i in range(n):
            dp[1:,] += np.roll(dp[:-1,], i, axis=1)
            dp %= 10**9 + 7
        return dp[k][0]
    

    小型、中型和最坏情况的基准:

    n = 19   k = 11   m = 13
    -------------------------------------------------------
    22856.8 μs  23409.3 μs  23421.4 μs  naive
      496.9 μs    500.2 μs    524.7 μs  dtjc
      918.6 μs    928.6 μs    936.3 μs  kcsquared
      173.5 μs    183.6 μs    191.9 μs  numpy1
      402.2 μs    403.5 μs    411.1 μs  numpy2
      297.8 μs    318.4 μs    320.1 μs  numpy3
    
    n = 200   k = 100   m = 200
    -------------------------------------------------------
    2033.6 ms  2177.3 ms  2178.1 ms  dtjc
    1410.6 ms  1420.2 ms  1430.5 ms  kcsquared
      19.5 ms    19.8 ms    20.3 ms  numpy1
      22.5 ms    22.9 ms    23.0 ms  numpy2
      26.8 ms    27.3 ms    27.3 ms  numpy3
    
    n = 1000   k = 100   m = 1000
    -------------------------------------------------------
    508.0 ms  516.1 ms  519.2 ms  numpy1
    518.3 ms  518.8 ms  526.3 ms  numpy2
    495.1 ms  496.4 ms  499.2 ms  numpy3
    

    基准代码 (Try it online!):

    from timeit import repeat
    from itertools import combinations
    from functools import lru_cache
    import numpy as np
    
    def naive(n, k, m):
        return sum(sum(combi) % m == 0
                   for combi in combinations(range(1, n+1), k)) % (10**9 + 7)
    
    @lru_cache(None)
    def dtjc(n, k, m, r=0):
        if k > n:
            return 0
        if k == 0:
            return 1 if r == 0 else 0
        return (dtjc(n-1, k, m, r) + dtjc(n-1, k-1, m, (r+n) % m)) % (10**9 + 7)
    
    def kcsquared(max_part_size: int, total_num_parts: int, mod: int) -> int:
        BIG_MOD = 10 ** 9 + 7
    
        # Optimization if no partitions exist
        if total_num_parts > max_part_size:
            return 0
    
        largest_sum = ((max_part_size * (max_part_size + 1) // 2)
                       - ((max_part_size - total_num_parts) *
                          (max_part_size - total_num_parts + 1) // 2))
        # Optimization if largest partition sum still smaller than mod
        if largest_sum < mod:
            return 0
    
        dp = [[0 for _ in range(mod)] for _ in range(total_num_parts + 1)]
        dp[0][0] = 1
    
        for curr_max_part in range(1, max_part_size + 1):
            for curr_num_parts in reversed(range(0, total_num_parts)):
                for rem in range(mod):
                    dp[curr_num_parts + 1][(rem + curr_max_part) % mod] += dp[curr_num_parts][rem]
                    dp[curr_num_parts + 1][(rem + curr_max_part) % mod] %= BIG_MOD
        return dp[total_num_parts][0]
    
    def numpy1(n, k, m):
        dp = np.zeros((k+1, m), np.int32)
        dp[0][0] = 1
        for i in range(1, n+1):
            dp[1:,] += dp[:-1, (np.arange(m) + i) % m]
            dp %= 10**9 + 7
        return dp[k][0]
    
    def numpy2(n, k, m):
        dp = np.zeros((k+1, m), np.int32)
        dp[0][0] = 1
        i = range(m)
        for _ in range(n):
            i = np.roll(i, 1)
            dp[1:,] += dp[:-1, i]
            dp %= 10**9 + 7
        return dp[k][0]
    
    def numpy3(n, k, m):
        dp = np.zeros((k+1, m), np.int32)
        dp[0][0] = 1
        for i in range(n):
            dp[1:,] += np.roll(dp[:-1,], i, axis=1)
            dp %= 10**9 + 7
        return dp[k][0]
    
    def test(args, solutions, number, format_time):
        print('n = %d   k = %d   m = %d' % args)
        print('-' * 55)
        for _ in range(1):
            results = set()
            for func in solutions:
                times = sorted(repeat(lambda: dtjc.cache_clear() or results.add(func(*args)), number=number))[:3]
                print(*(format_time(t / number) for t in times), func.__name__)
            print('results set:', results)
            assert len(results) == 1
            print()
    
    test((19, 11, 13),
         [naive, dtjc, kcsquared, numpy1, numpy2, numpy3],
         10,
         lambda time: '%7.1f μs ' % (time * 1e6))
    test((200, 100, 200),
         [dtjc, kcsquared, numpy1, numpy2, numpy3],
         1,
         lambda time: '%6.1f ms ' % (time * 1e3))
    test((1000, 100, 1000),
         [numpy1, numpy2, numpy3],
         1,
         lambda time: '%5.1f ms ' % (time * 1e3))
    

    【讨论】:

    • 这很酷;我想知道自上而下和自下而上 DP 之间的速度盈亏平衡点在哪里。我没有花太多精力优化这个 DP 解决方案 - 我确实尝试使用来自 the math stack thread 的想法,我链接甚至实现了参考教科书 Ruskey 的 Combinatorial Generation 中的算法。但是它要慢得多,并且作者暗示没有什么渐近更快的算法是已知的,所以我停在那里。
    • @kcsquared 是的,我还没有查看其他内容,但是 NumPy 版本让它看起来非常简单(尤其是第二个和第三个),我想也许有类似通过表示来计算这个的东西这些操作作为矩阵乘法然后进行快速矩阵幂左右,类似于如何以这种方式计算斐波那契数。但即使这样可行,它也需要更大的指数,即更大的 n,才能有益。鉴于 O(nmk) 显然足以使用正确的“语言”,我怀疑这就是问题作者预期的复杂性。
    • 完全可以想象,基于矩阵的 DP 解决方案可以提供更好的缩放行为。然而,看看所有计算分区的公式和技术,生成函数技术似乎更受欢迎。由于这个问题是计算受限分区,因此计算分区的常见注意事项似乎适用 - 与快速近似相比,封闭形式的解决方案很少见,并且递归关系和 GF 是已知用于精确计算的最佳方法。
    • @kcsquared 我不会假装我理解所有这些,甚至几乎不记得生成函数这个术语:-)。顺便说一句,我不会对我的自上而下的 DP 与你的自下而上的 DP 的盈亏平衡点给予太大的重视,不确定它们有多大的可比性。我的内存至少更差,因为我占用了 O(nmk) 空间。我没有努力,只是想在阅读其他人的解决方案之前自己解决得足够好。主要是因为我懒得剪掉它。
    • @kcsquared 哦,还把它留在那里进行正确性检查,总是很好地相互验证独立编写的解决方案的结果。而且我喜欢这种“级联”验证方式,从小案例开始,包括一个明显正确的幼稚解决方案,然后只用足够快的解决方案就升级到更大的案例。构建一种“信任链”。
    【解决方案3】:

    我希望你可能已经解决了这个问题。不过,我还是会为那些觉得有帮助的人回答这个问题。

    您曾尝试自己获取组合,但您可以使用库来获取所有可能的组合,然后简单地遍历并检查条件。如果目的不仅仅是为了学习,使用已经可用的代码总是可以的。

    无论如何,看看代码。谢谢。

    from itertools import combinations
    def getCombi(n, k, m):
        count = 0
        #Required Combinations
        reqcombis = []
        array = [i for i in range(1, n+1)]
        #getting all possible combinations
        totalcombins = combinations(array, k)
        for i in totalcombins:
            if sum(i) % m == 0 and sum(i) <= n:
                count+=1
                reqcombis.append(i) 
        return count, reqcombis
    
    if __name__ == "__main__":
        n, k, m = input().split(",")
        n, k, m = int(n), int(k), int(m)
        print(getCombi(n, k, m))
    

    【讨论】:

    • combinations(array, k) 返回 O(N^K) 元素。在给定N &lt;= 1000K &lt;= 100 的约束下,该算法将迭代1000^100 元素,这远远超过了宇宙中的原子数量。
    猜你喜欢
    • 2012-10-16
    • 2012-07-31
    • 2022-01-05
    • 1970-01-01
    • 1970-01-01
    • 2013-05-12
    • 1970-01-01
    • 2013-08-06
    • 2015-12-22
    相关资源
    最近更新 更多