【问题标题】:How to count number of combinations?如何计算组合的数量?
【发布时间】:2014-02-08 15:43:31
【问题描述】:

我有一个问题,我想计算满足以下条件的组合的数量:

 a < b < a+d < c < b+d

其中a, b, c 是列表的元素,d 是固定增量。

这是一个普通的实现:

def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s

这是一个测试:

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4)) # Gone through everything by hand.

问题

如何加快速度?我正在查看 200 万的列表大小。

补充信息

我正在处理 [-pi, pi] 范围内的浮点数。例如,这限制了a &lt; 0

到目前为止我所拥有的:

我有一些实现,我在其中构建用于bc 的索引。但是,以下代码在某些情况下会失败。 (即这是错误的)。

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
            for indB in range(indA+1, low[indA]+1):
                    s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for  elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind


def upper(l, d=pi):
    ''' Returns first index where l[i] > l + d'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind

原来的问题

最初的问题来自一个著名的数学/计算机科学竞赛。比赛要求您不要在网上发布解决方案。但那是两周前的事了。

我可以用这个函数生成列表:

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = points(n)
    angles.sort()
    return count(angles, pi)

【问题讨论】:

  • 列表总是排序的吗?
  • 最初没有排序,但我们可以做到。
  • 对我来说,这是一项测试:a &lt; (c - d) &lt; b。因此,首先,您永远不必查看小于或等于ab。这会大大减少你的比较。
  • 因为对于给定范围内的整数,您可以使用段树。你不能用花车做到这一点。一般来说,对于此类问题,请提供您拥有的所有信息。
  • 你能链接到问题陈述或简要描述它吗?我认为即使在你给出的一般公式中也可以在多线性时间内解决这个问题,但我认为比赛不需要这样做

标签: python algorithm combinations


【解决方案1】:

有一种方法可以产生O(n log n) 算法来解决您的问题。让X 成为一组值。现在让我们修复b。设A_b 为值集{ x in X: b - d &lt; x &lt; b }C_b 为值集{ x in X: b &lt; x &lt; b + d }。如果我们能快速找到|{ (x,y) : A_b X C_b | y &gt; x + d }|,我们就解决了问题。

如果我们对X 进行排序,我们可以将A_bC_b 表示为指向已排序数组的指针,因为它们是连续的。如果我们以非递减顺序处理b 候选,我们可以因此使用sliding window algorithm 维护这些集合。它是这样的:

  1. 排序X。让X = { x_1, x_2, ..., x_n }x_1 &lt;= x_2 &lt;= ... &lt;= x_n
  2. 设置left = i = 1 并设置right 以便C_b = { x_{i + 1}, ..., x_right }。设置count = 0
  3. i1 迭代到n。在每次迭代中,我们找出有效三元组 (a,b,c)b = x_i 的数量。为此,请尽可能增加leftright,以使A_b = { x_left, ..., x_{i-1} }C_b = { x_{i + 1}, ..., x_right } 仍然有效。在此过程中,您基本上从虚构集合A_bC_b 中添加和删除元素。 如果您删除或添加一个元素到其中一个集合,请检查您添加或销毁了多少对 (a, c)c &gt; a + da 来自 A_bc 来自 C_b(这可以通过在另一组中进行简单的二进制搜索)。相应地更新 count 以使不变量 count = |{ (x,y) : A_b X C_b | y &gt; x + d }| 仍然成立。
  4. 在每次迭代中总结count 的值。这是最终结果。

复杂度为O(n log n)

如果你想用这个算法解决欧拉问题,你必须避免浮点问题。我建议使用仅使用整数算术的自定义比较函数(使用 2D 矢量几何)按角度对点进行排序。也可以仅使用整数运算来实现|a-b| &lt; d 比较。此外,由于您正在以 2*pi 为模工作,因此您可能必须引入每个角度 a 的三个副本:a - 2*piaa + 2*pi。然后,您只需在[0, 2*pi) 范围内查找b 并将结果除以三。

UPDATE OP 在 Python 中实现了这个算法。显然它包含一些错误,但它展示了总体思路:

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b. 
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
            # Find boundaries of C s.t. b < c < b + d
            while c_l < length and X[c_l] <= b:
                c_l += 1  # this removes an element from C_b
                ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count -= (ind - a_l)
            while c_r  < length and X[c_r] < b + d:
                c_r += 1 # this adds an element to C_b
                ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count += (ind - a_l)
            s += count
    return s

【讨论】:

  • 不遵循这个,因为它不是用python编写的。回复:“...介绍每个角度 a 的三个副本:a - pi、a 和 a + pi。”如果 2\*pi 取模,a - pia + pi 不一样吗?
  • @hughdbrown:首先,我的意思是a - 2*pia + 2*pi。当然可以,但是该算法不知道模数,因此我们需要对其进行模拟以使用滑动窗口。我现在不能用 Python 写这个,因为我正在度假,而且我也不认为有必要查看源代码来理解总体思路(相反)。
  • @hughdbrown:实际上最后一段并不真正适用于这里提出的问题,而是更适用于潜在的欧拉项目问题的解决方案。
  • 好的,据我了解,在O(n) 时间迭代b 的值,ac 的值可以在O(logn) 时间选择?我已经证明,可以在恒定时间内计算任意范围内的元素计数(累积和查找表)。
  • 您在任何时候迭代b 并维护a 的有效值集和c 的有效值集(即那些acbd 范围内)。这很简单。现在我们想知道在任何时候有多少 (a,c) 对满足 c - a &gt; d。我们也可以保持这个数量。每次我们插入a 候选集时,我们都会查看另一个集中有多少c 候选集有c - a &gt; d。我们可以通过二分查找(log n)来做到这一点。如果我们插入c 候选集,反过来也是一样的。删除也非常相似。
【解决方案2】:
from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()
  1. 去除重复的相同值

  2. 仅迭代值所需的范围

  3. 使用两个索引的累积计数来消除c上的循环

  4. x + d 上的缓存查找

这不再是 O(n^3) 而是更像 O(n^2)`。

这显然还没有扩展到 200 万。以下是我使用 cython 加速执行的较小浮点数据集(即很少或没有重复)的时间:

50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds

这是我的基准测试代码(不包括上面的操作代码):

from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()

【讨论】:

    【解决方案3】:

    由于l 已排序且a &lt; b &lt; c 必须为真,因此您可以使用itertools.combinations() 减少循环次数:

    sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
    

    查看组合只会将此循环减少到 816 次迭代。

    >>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    >>> d = 4
    >>> sum(1 for a, b, c in combinations(l, r=3))
    816
    >>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
    32
    

    a &lt; b 测试是多余的。

    【讨论】:

    • 我认为 OP 正在寻找算法改进。这只是编写循环约 10^19 次迭代的更紧凑的方式。
    • @DSM:我刚刚添加了一个。
    • 不过,我认为这不会扩展到 OP 的 2M 案例。
    • @MartijnPieters 这非常优雅,但仍然是 O(n^3)。我想我可以根据a+d &lt; c &lt; b+d、@DSM 的范围做一些事情,它甚至不能扩展到 40k。
    • @Unapiedra:进一步减少了所需的迭代次数。
    【解决方案4】:

    1) 为了减少每个级别的迭代次数,您可以从列表中删除不通过每个级别的条件的元素
    2) 使用setcollections.counter 可以通过删除重复项来减少迭代:

    from collections import Counter
    def count(l, d):
        n = Counter(l)
        l = set(l)
        s = 0
        for a in l:
            for b in (i for i in l if a < i < a+d):
                for c in (i for i in l if a+d < i < b+d):
                    s += (n[a] * n[b] * n[c])
        return s
    
    >>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    >>> count(l, 4)
    32
    

    为您的版本测试的迭代次数(a、b、c):

    >>> count1(l, 4)
    18 324 5832
    

    我的版本:

    >>> count2(l, 4)
    9 16 7
    

    【讨论】:

    • 重要的是计算与原始代码相比,您正在执行多少次内部迭代。
    • 还有:i for i in l if a &lt; (i-d) &lt; b
    • @hughdbrown 内部迭代次数完全取决于列表内容...两种算法都是 O(n^3) 但我的版本应该更快
    • 而且两者都太慢了。至少需要 O(n^(2-eps)。
    • @NiklasB。你知道更好的解决方案吗?为什么是-1?
    【解决方案5】:

    基本思路是:

    1. 摆脱重复的相同值
    2. 让每个值只在它必须迭代的范围内迭代。

    这样你就可以无条件增加s,性能大概是O(N),N是数组的大小。

    import collections
    
    def count(l, d):
        s = 0
        # at first we get rid of repeated items
        counter = collections.Counter(l)
        # sort the list
        uniq = sorted(set(l))
        n = len(uniq)
        # kad is the index of the first element > a+d
        kad = 0 
        # ka is the index of a
        for ka in range(n):
            a = uniq[ka]
            while uniq[kad] <= a+d:
                kad += 1
                if kad == n:
                    return s
    
            for kb in range( ka+1, kad ):
                # b only runs in the range [a..a+d)
                b = uniq[kb]
                if b  >= a+d:
                    break
                for kc in range( kad, n ):
                    # c only rund from (a+d..b+d)
                    c = uniq[kc]
                    if c >= b+d:
                        break
                    print( a, b, c )
                    s += counter[a] * counter[b] * counter[c]
        return s
    

    编辑:对不起,我搞砸了提交。固定。

    【讨论】:

    • 这绝对不是 O(n)。
    • 这取决于您对 n 采取的措施。是O(M),M是最终结果。
    • 这是一种特殊的算法估计方法。无论实现如何,斐波那契计算都会产生一个数字,但这并不是说所有实现都是O(1),对吧?查看输出的数量并不是估算运行时间的合适方法。查看输入的大小是更常用的方法。
    • 在这种情况下,复杂度估计告诉您没有比这里描述的更快的基于计数的算法。
    • @pentadecagon Niklas B 和 hughdbrown 的算法都快得多。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-11-03
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多