我尝试了三种不同的方法来解决这个问题——优化的蛮力、动态编程方法和贪心算法。前两个无法处理n > 17 的输入,但生成了最优解,因此我可以使用它们来验证贪心方法的平均性能。我将首先从动态规划方法开始,然后描述贪婪的方法。
动态编程
首先,请注意,如果我们确定(1, 2, 3, 4) 和(5, 6, 7, 8) 的总和小于(3, 4, 5, 6) 和(1, 2, 7, 8),那么您的最佳解决方案绝对不能同时包含(3, 4, 5, 6) 和(1, 2, 7, 8)-因为您可以将它们换成前者,并且金额较小。扩展这个逻辑,(a, b, c, d) 和 (e, f, g, h) 的最佳组合将导致所有x0, x1, x2, x3, x4, x5, x6, x7 组合的和最小,因此我们可以排除所有其他组合。
利用这些知识,我们可以通过暴力破解x0, x1, x2, x3 的所有组合的总和,将集合[0, n) 中的所有x0, x1, x2, x3, x4, x5, x6, x7 组合映射到它们的最小总和。然后,我们可以使用这些映射从x0, x1, x2, x3, x4, x5, x6, x7 和x0, x1, x2, x3 对中重复x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11 的过程。我们重复这个过程,直到我们获得x0, x1 ... x_(4*m-1) 的所有最小总和,然后我们对其进行迭代以找到最小总和。
def dp_solve(const_dict, n, m):
lookup = {comb:(comb,) for comb in const_dict.keys()}
keys = set(range(n))
for size in range(8, 4 * m + 1, 4):
for key_total in combinations(keys, size):
key_set = set(key_total)
min_keys = (key_total[:4], key_total[4:])
min_val = const_dict[min_keys[0]] + const_dict[min_keys[1]]
key1, key2 = min(zip(combinations(key_total, 4), reversed(list(combinations(key_total, size - 4)))), key=lambda x:const_dict[x[0]]+const_dict[x[1]])
k = tuple(sorted(x for x in key1 + key2))
const_dict[k] = const_dict[key1] + const_dict[key2]
lookup[k] = lookup[key1] + lookup[key2]
key, val = min(((key, val) for key, val in const_dict.items() if len(key) == 4 * m), key=lambda x: x[1])
return lookup[key], val
诚然,这个实现相当粗糙,因为我不断地进行微优化,希望能够让它足够快,而不必切换到贪婪的方法。
贪婪
这可能是您关心的,因为它可以快速处理相当大的输入,并且非常准确。
首先为部分和构建一个列表,然后通过增加值开始迭代字典中的元素。对于每个元素,找到所有不与其键产生任何冲突的部分和,并将它们“组合”成一个新的部分和,并附加到列表中。在这样做的过程中,您构建了一个最小部分和列表,该列表可以从字典中的最小 k 值创建。为了加快这一切,我使用哈希集来快速检查哪些部分和包含相同的键对。
在“快速”贪婪方法中,您将在找到密钥长度为 4 * m(或等效地为 m 4 元组)的部分总和时中止。根据我的经验,这通常会产生相当好的结果,但如果需要,我想添加一些逻辑以使其更准确。为此,我添加了两个因素-
-
extra_runs - 它规定了在中断之前需要多少额外的迭代来寻找更好的解决方案
-
check_factor - 指定当前搜索“深度”的倍数,以向前扫描 single 新整数,该整数会在当前状态下创建更好的解决方案。这与上面的不同之处在于它不会“保留”每个检查的新整数 - 它只会快速求和以查看它是否创建了一个新的最小值。这大大加快了速度,但代价是其他 m - 1 4 元组必须已经存在于其中一个部分和中。
结合起来,这些检查似乎总能找到真正的最小总和,代价是运行时间延长了大约 5 倍(尽管仍然相当快)。要禁用它们,只需为这两个因素传递 0。
def greedy_solve(const_dict, n, m, extra_runs=10, check_factor=2):
pairs = sorted(const_dict.items(), key=lambda x: x[1])
lookup = [set([]) for _ in range(n)]
nset = set([])
min_sums = []
min_key, min_val = None, None
for i, (pkey, pval) in enumerate(pairs):
valid = set(nset)
for x in pkey:
valid -= lookup[x]
lookup[x].add(len(min_sums))
nset.add(len(min_sums))
min_sums.append(((pkey,), pval))
for x in pkey:
lookup[x].update(range(len(min_sums), len(min_sums) + len(valid)))
for idx in valid:
comb, val = min_sums[idx]
for key in comb:
for x in key:
lookup[x].add(len(min_sums))
nset.add(len(min_sums))
min_sums.append((comb + (pkey,), val + pval))
if len(comb) == m - 1 and (not min_key or min_val > val + pval):
min_key, min_val = min_sums[-1]
if min_key:
if not extra_runs: break
extra_runs -= 1
for pkey, pval in pairs[:int(check_factor*i)]:
valid = set(nset)
for x in pkey:
valid -= lookup[x]
for idx in valid:
comb, val = min_sums[idx]
if len(comb) < m - 1:
nset.remove(idx)
elif min_val > val + pval:
min_key, min_val = comb + (pkey,), val + pval
return min_key, min_val
我针对n < 36 和m < 9 进行了测试,它似乎运行得相当快(最坏的情况下只需几秒钟即可完成)。我想它应该很快适用于您的情况12 <= n <= 24。