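Wikipedia offers both a poor explanation and a less-than-ideal algorithm. But let's use it as a starting point.
First, let's look at the backtracking algorithm. Rather than placing the cells of the matrix "in some order", let's fill in everything in the first row, then everything in the second row, then everything in the third row, and so on. Clearly this will work.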
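For concreteness, here is a rough sketch of that cell-by-cell backtracking in row-major order. The name `count_cell_by_cell` and the flat `cells` tuple are my own choices for illustration (they are not from Wikipedia), and this is only practical for tiny n.

```python
def count_cell_by_cell(n, cells=()):
    """Count n x n 0-1 matrices with n/2 ones in every row and column.

    `cells` holds the values chosen so far, in row-major order.
    (Hypothetical helper for illustration; assumes n is even.)
    """
    half = n // 2
    pos = len(cells)
    if pos == n * n:
        # Every row and column has n entries and neither symbol exceeded
        # n/2, so each holds exactly n/2 zeros and n/2 ones.
        return 1
    r, c = divmod(pos, n)
    total = 0
    for value in (0, 1):
        row = cells[r * n:] + (value,)  # current row, including the new cell
        col = cells[c::n] + (value,)    # current column, including the new cell
        # Backtrack as soon as either symbol appears more than n/2 times.
        if max(row.count(0), row.count(1)) > half:
            continue
        if max(col.count(0), col.count(1)) > half:
            continue
        total += count_cell_by_cell(n, cells + (value,))
    return total


print(count_cell_by_cell(2))  # 2
print(count_cell_by_cell(4))  # 90
```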
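Now let's modify the backtracking algorithm slightly. Instead of going cell by cell, we'll go row by row. So we make a list of the n choose n/2 possible rows that are half 0s and half 1s. Then we have a recursive function that looks something like this: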
def count_0_1_matrices(n, filled_rows=None):
    if filled_rows is None:
        filled_rows = []
    if some_column_exceeds_threshold(n, filled_rows):
        # Cannot have more than n/2 0s or 1s in any column
        return 0
    elif len(filled_rows) == n:
        # All n rows are placed and no column overran: one valid matrix.
        return 1
    else:
        answer = 0
        for row in possible_rows(n):
            answer = answer + count_0_1_matrices(n, filled_rows + [row])
        return answer
This is a backtracking algorithm just like the one we had before. We are simply placing whole rows at a time instead of cells.
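The row-by-row function relies on two helpers, `possible_rows` and `some_column_exceeds_threshold`, that are not spelled out at this point. One possible way to fill them in is sketched below. The bodies are my assumption of what they should do, and `count_by_rows` is a hypothetical, self-contained copy of the recursion with an explicit "all rows placed" base case so the block runs on its own.

```python
from itertools import combinations


def possible_rows(n):
    """All length-n rows with exactly n/2 ones, as tuples (n assumed even)."""
    rows = []
    for ones in combinations(range(n), n // 2):
        rows.append(tuple(1 if i in ones else 0 for i in range(n)))
    return rows


def some_column_exceeds_threshold(n, filled_rows):
    """True if any column already holds more than n/2 zeros or n/2 ones."""
    half = n // 2
    for col in zip(*filled_rows):
        ones = sum(col)
        if ones > half or len(col) - ones > half:
            return True
    return False


def count_by_rows(n, filled_rows=()):
    # Hypothetical stand-in for the row-by-row recursion.
    if some_column_exceeds_threshold(n, filled_rows):
        return 0
    if len(filled_rows) == n:
        return 1
    return sum(count_by_rows(n, filled_rows + (row,))
               for row in possible_rows(n))


print(len(possible_rows(4)))  # 6, i.e. 4 choose 2
print(count_by_rows(4))       # 90
```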
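But notice that we're passing around more information than we need. There is no need to pass in the exact arrangement of rows. All we need to know is how many 1s are still needed in each column. So we can make the algorithm look more like this: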
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [n // 2 for _ in range(n)]
    # Did we overrun any column?
    for i in still_needed:
        if i < 0:
            return 0
    # Did we reach the end of our matrix?
    if 0 == sum(still_needed):
        return 1
    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, next_still_needed)
    return answer
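This version is almost the recursive function in the Wikipedia version. The main difference is the base case: ours says that once every row is finished we need nothing more, while Wikipedia would have us code the base case to check the last row after every other row is finished.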
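To see what the `still_needed` bookkeeping looks like in practice, here is a small trace for one particular 4x4 matrix. The helper `trace_still_needed` and the example rows are made up for illustration.

```python
def trace_still_needed(rows):
    """Return the still_needed vector before and after each placed row."""
    n = len(rows[0])
    still_needed = [n // 2] * n
    history = [tuple(still_needed)]
    for row in rows:
        # Each 1 in the row uses up one of the 1s that column still needs.
        still_needed = [need - bit for need, bit in zip(still_needed, row)]
        history.append(tuple(still_needed))
    return history


rows = [
    (1, 1, 0, 0),
    (1, 0, 1, 0),
    (0, 1, 0, 1),
    (0, 0, 1, 1),
]
for state in trace_still_needed(rows):
    print(state)
# (2, 2, 2, 2)
# (1, 1, 2, 2)
# (0, 1, 1, 2)
# (0, 0, 1, 1)
# (0, 0, 0, 0)
```

The last state is all zeros with nothing negative along the way, which is exactly the "we need nothing more" base case.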
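To get from this to a top-down DP, you only need to memoize the function. In Python you can do that by defining a memoize decorator and then applying it with @memoize. Like this: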
from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap
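One practical detail: this decorator uses the argument tuple as a dictionary key, so every argument has to be hashable. A state passed as a tuple works and a list does not. The sketch below shows both cases. `total_needed` is a throwaway function invented for the demonstration. (The standard library's `functools.lru_cache(maxsize=None)` does the same job and has the same hashability requirement.)

```python
from functools import wraps


def memoize(func):
    cache = {}

    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap


calls = []


@memoize
def total_needed(still_needed):
    # Hypothetical example function; records each real (non-cached) call.
    calls.append(still_needed)
    return sum(still_needed)


total_needed((1, 1, 2, 2))
total_needed((1, 1, 2, 2))  # served from the cache
print(len(calls))           # 1

try:
    total_needed([1, 1, 2, 2])
except TypeError as exc:
    # Lists are unhashable, so they cannot be used as cache keys.
    print("list argument rejected:", exc)
```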
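But remember that I criticized the Wikipedia algorithm? Let's start improving it! Here is the first big improvement. Did you notice that the order of the elements of still_needed doesn't matter, only their values? So simply sorting the elements stops you from doing the calculation separately for each permutation. (There can be a lot of permutations!)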
@memoize
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [n // 2 for _ in range(n)]
    # Did we overrun any column?
    for i in still_needed:
        if i < 0:
            return 0
    # Did we reach the end of our matrix?
    if 0 == sum(still_needed):
        return 1
    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        # Pass a tuple: memoize needs hashable arguments.
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))
    return answer
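To get a feel for how many states the sort collapses, here is a rough experiment that counts the distinct non-negative `still_needed` states reachable from the start, with and without sorting. Everything in it (`reachable_states`, `sort_state`, and this `possible_rows` built on `itertools.combinations`) is my own scaffolding for illustration, not part of the algorithm itself.

```python
from itertools import combinations


def possible_rows(n):
    """All length-n rows with exactly n/2 ones (assumed helper)."""
    return [tuple(1 if i in ones else 0 for i in range(n))
            for ones in combinations(range(n), n // 2)]


def sort_state(state):
    return tuple(sorted(state))


def reachable_states(n, canonicalize):
    """Count distinct non-negative still_needed states reachable from the start."""
    start = canonicalize((n // 2,) * n)
    seen = {start}
    frontier = [start]
    rows = possible_rows(n)
    while frontier:
        state = frontier.pop()
        for row in rows:
            nxt = tuple(s - b for s, b in zip(state, row))
            if min(nxt) < 0:
                continue  # a column was overrun; the recursion returns 0 here
            nxt = canonicalize(nxt)
            if nxt not in seen:
                seen.add(nxt)
                frontier.append(nxt)
    return len(seen)


for n in (4, 6):
    print(n, reachable_states(n, tuple), reachable_states(n, sort_state))
```

The gap between the two counts widens quickly as n grows, because each sorted state stands in for all of its distinct permutations.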
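That innocuous little sorted doesn't look important, but it saves a lot of work! And now that we know still_needed is always sorted, we can simplify our checks for whether we are done and whether anything went negative. Plus, we can add an easy check to filter out the case where a column has too many 0s.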
@memoize
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [n // 2 for _ in range(n)]
    # Did we overrun any column? still_needed is sorted ascending,
    # so the smallest entry is first.
    if still_needed[0] < 0:
        return 0
    total = sum(still_needed)
    if 0 == total:
        # We reached the end of our matrix.
        return 1
    elif total * 2 // n < still_needed[-1]:
        # We have total*2/n rows left, but won't get enough 1s for a
        # column (the largest entry is last).
        return 0
    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        # Pass a tuple: memoize needs hashable arguments.
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))
    return answer
And, assuming you implement possible_rows, this should both work and be significantly more efficient than what Wikipedia offered.
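The "too many 0s" check is just arithmetic: every remaining row contributes exactly n/2 ones, so the number of rows left is total*2/n, and no column can need more 1s than there are rows left. A tiny sketch of that reasoning follows. The helper names `rows_left` and `cannot_finish` are mine.

```python
def rows_left(n, still_needed):
    # Each remaining row supplies exactly n/2 ones in total, so
    # sum(still_needed) == rows_left * n/2.
    return 2 * sum(still_needed) // n


def cannot_finish(n, still_needed):
    # A column can gain at most one 1 per remaining row.
    return rows_left(n, still_needed) < max(still_needed)


print(rows_left(4, (0, 1, 1, 2)))      # 2
print(cannot_finish(4, (0, 1, 1, 2)))  # False
print(cannot_finish(4, (0, 0, 0, 2)))  # True: 1 row left, a column needs 2 ones
```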
=====
Here is a complete working implementation. On my machine it calculated the 6th term in under 4 seconds.
#!/usr/bin/env python3
from sys import argv
from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap

@memoize
def count_0_1_matrices(n, still_needed=None):
    if 0 == n:
        return 1
    if still_needed is None:
        still_needed = [n // 2 for _ in range(n)]
    # Did we overrun any column? (still_needed is sorted ascending.)
    if still_needed[0] < 0:
        return 0
    total = sum(still_needed)
    if 0 == total:
        # We reached the end of our matrix.
        return 1
    elif total * 2 // n < still_needed[-1]:
        # We have total*2/n rows left, but won't get enough 1s for a
        # column.
        return 0
    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))
    return answer

@memoize
def possible_rows(n):
    return list(_possible_rows(n, n // 2))

def _possible_rows(n, k):
    # Yield every length-n tuple of 0s and 1s containing exactly k 1s.
    if 0 == n:
        yield tuple()
    else:
        if k < n:
            for row in _possible_rows(n-1, k):
                yield row + (0,)
        if 0 < k:
            for row in _possible_rows(n-1, k-1):
                yield row + (1,)

n = 2
if 1 < len(argv):
    n = int(argv[1])
print(count_0_1_matrices(2*n))
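If you want to sanity-check the output for small sizes, a plain brute force is a handy independent reference. The sketch below is my own addition, not part of the implementation above: it tries every combination of balanced rows and keeps those whose columns are balanced too, so it is only usable for very small matrices.

```python
from itertools import product


def brute_force_count(n):
    """Count n x n 0-1 matrices with n/2 ones per row and column (n even)."""
    half = n // 2
    rows = [r for r in product((0, 1), repeat=n) if sum(r) == half]
    count = 0
    for matrix in product(rows, repeat=n):
        # zip(*matrix) iterates over the columns of the candidate matrix.
        if all(sum(col) == half for col in zip(*matrix)):
            count += 1
    return count


print(brute_force_count(2))  # 2
print(brute_force_count(4))  # 90
```

Note that the script takes n on the command line and counts 2n x 2n matrices, so running it with the argument 2 should agree with `brute_force_count(4)`.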