【问题标题】:How to speed up this code with Cython?如何使用 Cython 加速此代码?
【发布时间】:2015-03-15 17:14:48
【问题描述】:
def check(barcode):
    if not len(barcode) == 13:
        return False

    digits = list(map(int, barcode))

    # 1. Add the values of the digits in the 
    # even-numbered positions: 2, 4, 6, etc.
    even_sum = digits[1] + digits[3] + digits[5] + digits[7] + digits[9] + digits[11]

    # 2. Multiply this result by 3.
    even_sum_three = even_sum * 3

    # 3. Add the values of the digits in the 
    # odd-numbered positions: 1, 3, 5, etc.
    odd_sum = digits[0] + digits[2] + digits[4] + digits[6] + digits[8] + digits[10]

    # 4. Sum the results of steps 2 and 3.
    total_sum = even_sum_three + odd_sum

    # 5. The check character is the smallest number which,
    # when added to the result in step 4, produces a multiple of 10.
    next_ten = (math.ceil(total_sum / 10)) * 10
    check_digit = next_ten - total_sum

    # 6. if the check digit and the last digit of the 
    # barcode are OK return true;
    if check_digit == digits[12]:
        return True
    else:
        return False

我有上面的代码。它计算一个 ean 数的校验和。我试图通过使用 Cython 来加速它。此代码在循环中使用。可变条形码是一个字符串。

但我没能提高速度。我试过了:

  • np.array(list(map(int, barcode))) - 这让它稍微慢了一点
  • np.ceil() 而不是 math.ceil() - 也让它稍微慢了一点
  • cdef bool def check(.... - 也没有帮助

我还能做什么?

【问题讨论】:

  • barcode 的初始格式是什么?
  • @orlp:一个 13 个字符的 UPC。
  • @IgnacioVazquez-Abrams 我的意思是在 Python 中,没有称为 UPC 的对象类型 :) 它是字符串,还是?
  • @orlp: UPCs 是 13 个数字的序列,由 12 个识别数字和一个校验数字组成。
  • @IgnacioVazquez-Abrams 我在问输入类型是否应为字符串、字符元组等 - 在运行时。我不知道 UPC 是什么。

标签: python performance cython


【解决方案1】:

您可以从优化函数本身开始。看看

    # 5. The check character is the smallest number which,
    # when added to the result in step 4, produces a multiple of 10.
    next_ten = (math.ceil(total_sum / 10)) * 10
    check_digit = next_ten - total_sum

您还可以执行next_ten = 10 - total_sum % 10 之类的操作来获取检查字符。在那里你摆脱了math.ceil。另外,这段代码也适用于 Python 2。所以最后,折叠所有代码,你最终可能只是做

def check(barcode):
    if not len(barcode) == 13:
        return False
    return not (3*sum(map(int,barcode[1::2]))+sum(map(int,barcode[::2])))%10

当然,您也可以使用列表理解或itertools,但在我的测试中并没有太大的区别。但“最”有效的方法是根本不使用任何花哨的 Python 编程技术:

def check(bc):
    if not len(bc) == 13:
        return False
    return not (3*(int(bc[1])+int(bc[3])+int(bc[5])+int(bc[7])+int(bc[9])+int(bc[11]))+int(bc[0])+int(bc[2])+int(bc[4])+int(bc[6])+int(bc[8])+int(bc[10])+int(bc[12]))%10

最后的注释:通过我的所有测试,Python 2 更快。

【讨论】:

    猜你喜欢
    • 2023-03-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-07-28
    相关资源
    最近更新 更多