【问题标题】:Algorithm for subset-sum with subtraction带减法的子集和算法
【发布时间】:2015-01-31 03:32:32
【问题描述】:

我有一个子集和问题,您可以在其中添加或减去项。例如,如果我有五个项(1、2、3、4、5),我想知道我可以通过多少种方式添加/减去这些项来得到 7:

  • 3 + 4
  • 2 + 5
  • 1 + 2 + 4
  • 5 - 2 + 4

我用Python写了一些代码,但是一旦有很多术语就很慢:

import itertools
from collections import OrderedDict

sum_answer = 1
terms = {"T1": 1, "T2": -2, "T3": 3, "T4": -4, "T5": 5}
numlist = [v for v in terms.values()]
zerlist = [x for x in itertools.repeat(0, len(numlist))]
opslist = [item for item in itertools.product((1, -1), repeat=len(numlist))]


res_list = []
for i in range(1, len(numlist)):
    combos = itertools.combinations(numlist, i)

    for x in combos:
        prnlist = list(x) + zerlist[:len(numlist) - len(x)]

        for o in opslist:
            operators = list(o)
            result = []
            res_sum = 0

            for t in range(len(prnlist)):
                if operators[t] == 1:
                    ops = "+"
                else:
                    ops = "-"
                if prnlist[t] != 0:
                    result += [ops, list(terms.keys())[list(terms.values()).index(prnlist[t])]]
                res_sum += operators[t] * prnlist[t]

            if sum_answer == res_sum:
                res_list += [" ".join(result)]

for ans in OrderedDict.fromkeys(res_list).keys():
    print(ans)

我意识到一百万个嵌套循环非常低效,那么我可以使用更好的算法来加速任何部分吗?

【问题讨论】:

  • 既然你有一个可行的解决方案,你最好把它发布到 CodeReview 而不是这个网站:codereview.stackexchange.com
  • 您真的想要一个所有解决方案的列表,还是只是一个计数?
  • @PatrickBeeson 我不同意,他有一个可行的解决方案,但速度很慢。这是一个有待解决的客观问题。
  • @HughBothwell:目的是尝试找出数据库中的哪些字段用于计算报表的总计,所以我需要所有解决方案。

标签: python algorithm subset-sum


【解决方案1】:

类似于“常规”子集求和问题 - 在您使用 DP 解决问题的情况下,您也将在这里使用它,但需要有更多的可能性 - 减少当前元素而不是添加它。

f(0,i) = 1               //successive subset
f(x,0) = 0    x>0        //failure subset
f(x,i) = f(x+element[i],i-1) + f(x-element[i],i-1) + f(x,i-1)
                                 ^^^
               This is the added option for substraction

将其转换为自下而上的 DP 解决方案时,您需要创建一个大小为 (SUM+1) * (2n+1) 的矩阵,其中 SUM 是所有元素的总和,n 是元素的数量。

【讨论】:

  • OP 的实现似乎打印了实际的解决方案,而不仅仅是计算它们。
【解决方案2】:

我认为您的想法大体上是正确的:生成术语的每个组合,进行求和,看看它是否成功。不过你可以优化你的代码。

问题在于,一旦生成 1 + 2,您会发现它与您想要的总和不匹配,然后将其丢弃。但是,如果您将4 添加到它,这是一个解决方案。但是,在生成 1 + 2 + 4 之前,您将无法获得该解决方案,届时您将从头开始计算总和。您还可以为每个组合从头开始添加运算符,出于同样的原因,这也做了很多冗余工作。

您还使用了很多列表操作,这可能会很慢。

我会这样做:

def solve(terms_list, stack, current_s, desired_s):
    if len(terms_list) == 0:
        if current_s == desired_s:
            print(stack)
        return

    for w in [0, 1, -1]: # ignore term (0), add it (1), subtract it (-1)
        stack.append(w)
        solve(terms_list[1:], stack, current_s + w * terms_list[0], desired_s)
        stack.pop()

例如,初始调用是solve([1,2,3,4,5], [], 0, 7)

请注意,这很复杂 O(3^n)(有点,请继续阅读),因为每个术语都可以添加、减去或忽略。

我的实际实现的复杂度是O(n*3^n),因为递归调用会复制terms_list 参数。但是,您可以避免这种情况,但我想让代码更简单,并将其留作练习。您也可以避免在打印之前构造实际表达式,而是以增量方式构造它,但您可能需要更多参数。

但是,O(3^n) 仍然很多,无论你做什么,你都不应该期望它对大型 n 做得很好。

【讨论】:

    【解决方案3】:

    现在您正试图从一行中暴力破解所有可能的字段值组合(然后针对其他行对每个组合进行有效性测试)。

    我想你有很多行数据可以玩;我建议您通过获取一堆行(至少与您要求解的字段一样多)并应用像numpy.linalg.lstsq 这样的近似矩阵求解器来利用这一点。

    这有很多重要的优势:

    • 让您能够巧妙地处理舍入错误问题(如果您的任何字段为非整数,则必须这样做)

    • 允许您轻松处理系数不在{-1, 0, 1} 中的字段,即系数可能类似于0.12 的税率

    • 使用完全受支持的代码,您无需调试或维护

    • 使用高度优化的代码,运行速度会快得多(**很可能,取决于编译 numpy 时使用的选项)

    • 具有更好的时间复杂度(类似于 O(n ** 2.8) 而不是 O(3 ** n)),这意味着它应该扩展到更多的领域

    所以,一些测试数据:

    import numpy as np
    
    # generate test data
    def make_test_data(coeffs, mean=20.0, base=0.05):
        w      = len(coeffs)    # number of fields
        h      = int(1.5 * w)   # number of rows of data
        rows   = np.random.exponential(mean - base, (h, w)) + base
        totals = data.dot(coeffs)
        return rows.round(2), totals.round(2)
    

    这给了我们类似的东西

    >>> rows, totals = make_test_data([0, 1, 1, 0, -1, 0.12])
    
    >>> print(rows)
    [[  1.45  17.63  22.54   5.54  37.06   1.47]
     [ 11.71  80.43  26.43  18.48  11.08   8.8 ]
     [ 16.09  11.34  63.74   3.31  13.2   13.35]
     [ 11.96  12.17  10.23   8.15  73.3    0.42]
     [  4.03   8.01  20.84  21.46   2.76  18.98]
     [  3.24   6.6   35.06  23.17   9.03   8.58]
     [ 25.05  33.72   6.82   0.49  46.76  12.21]
     [ 70.27   1.48  23.05   0.69  31.11  43.13]
     [  9.04  10.45  15.08   4.32  52.94  11.13]]
    
    >>> print(totals)
    [  3.29  96.84  63.48 -50.85  28.37  33.66  -4.75  -1.4  -26.07]
    

    以及求解器代码,

    >>> sol = np.linalg.lstsq(rows, totals)    # one line!
    
    >>> print(sol[0])       # note the solutions are not *exact*
    [ -1.485730e-04  1.000072e+00  9.999334e-01 -7.992023e-05 -9.999552e-01  1.203379e-01]
    
    >>> print(sol[0].round(3))      # but they are *very* close
    [ 0.    1.    1.    0.   -1.    0.12]
    

    【讨论】:

      猜你喜欢
      • 2011-05-20
      • 2021-12-30
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2023-04-05
      • 1970-01-01
      相关资源
      最近更新 更多