【问题标题】:Python tail recursion optimization in this algo?这个算法中的 Python 尾递归优化?
【发布时间】:2020-11-09 20:47:17
【问题描述】:

我遇到了以下问题: 让我们考虑以下编码表:

0 -> A
1 -> B
2 -> C
...
  • 输入:包含整数列表的字符串
  • 输出:int 表示可以从输入编码的字数
  • 示例:
"0" -> 1 (word is A)
"10" -> 2 (words are BA and K)
"121" -> 3 (words are BCB, MB and BV) 

我写了这个算法

import string
sys.setrecursionlimit(100)  # limit the number of recursions to 100
# {0: 'a', 1: 'b', etc.}
dict_values = dict(zip(range(26), string.ascii_lowercase))

def count_split(list_numbers: list) -> int:
    if len(list_numbers) == 0:
        return 0

    elif len(list_numbers) == 1:
        return 1

    elif len(list_numbers) == 2:
        if int(list_numbers[0]) == 0:
            return 1

        elif 10 <= int("".join(list_numbers)) <= 25:
            return 2

        else:
            return 1

    else:  # list_numbers contains at least three elements
        if count_split(list_numbers[:2]) == 2:
            return count_split(list_numbers[1:]) + count_split(list_numbers[2:])

        else:
            return count_split(list_numbers[1:])

def get_nb_words(code: str) -> int:
    return count_split(list(code))

# print(get_nb_words("0124")) -> 3
# print(get_nb_words("124")) -> 3
# print(get_nb_words("368")) -> 1
# print(get_nb_words("322")) -> 2
# print(get_nb_words("12121212121212121212")) -> 10946

令人惊讶的是,这个算法适用于最后一个例子"12121212121212121212"。我预计会超过递归次数,因为在每一步,函数count_split 在列表中被调用两次。因此,调用次数远远超过100(甚至超过1000)!

同时我发现this post on stackoverflow说尾递归在Python中没有优化,所以我有点惊讶!

有人能解释一下为什么这个算法没有超过递归限制吗?

【问题讨论】:

  • 对我来说有点晚了,但我认为你的最大递归 depth (这很重要)只会与输入字符串的长度有关,不是吗?注意,recursionlimit 有点误导,它实际上是调用堆栈的最大深度。
  • 是的,所以我猜递归depth(默认设置为1000)不是函数被调用的次数而是调用函数的输入数量。我做对了吗?
  • 不,这不是调用函数的输入数量。这是调用堆栈的最大深度,还是您的意思?再说一次,对不起,我来晚了
  • 哦,确定代表连续调用的树的深度,对吧?
  • 这不是一棵树,它是一个堆栈

标签: python recursion


【解决方案1】:

您关心递归深度,即调用堆栈的最大深度(高度?)。

这是一种经验方法:(测量深度的代码)

import string, sys
sys.setrecursionlimit(100)  # limit the number of recursions to 100

dict_values = dict(zip(range(26), string.ascii_lowercase))
stack = []
max_depth = 0
def count_split(list_numbers: list) -> int:
    global max_depth
    stack.append(None)
    max_depth = max(max_depth, len(stack))
    if len(list_numbers) == 0:
        return 0

    elif len(list_numbers) == 1:
        return 1

    elif len(list_numbers) == 2:
        if int(list_numbers[0]) == 0:
            return 1

        elif 10 <= int("".join(list_numbers)) <= 25:
            return 2

        else:
            return 1

    else:  # list_numbers contains at least three elements
        if count_split(list_numbers[:2]) == 2:
            stack.pop()
            result = count_split(list_numbers[1:]) + count_split(list_numbers[2:])
            stack.pop(); stack.pop()
            return result

        else:
            result = count_split(list_numbers[1:])
            stack.pop()
            return result

def get_nb_words(code: str) -> int:
    return count_split(list(code))

print(get_nb_words("12121212121212121212"))
print(max_depth) # 20

【讨论】:

    猜你喜欢
    • 2012-11-15
    • 1970-01-01
    • 1970-01-01
    • 2010-12-03
    • 2014-06-23
    • 2012-08-11
    • 2019-10-25
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多