TLDR;不,for 循环并不是一概而论的“坏”,至少,并非总是如此。 说某些向量化操作比迭代慢可能更准确,而不是说迭代比某些向量化操作快。了解何时以及为什么是从代码中获得最大性能的关键。简而言之,这些是值得考虑替代矢量化 pandas 函数的情况:
- 当您的数据较小时(...取决于您在做什么),
- 处理
object/mixed dtypes时
- 使用
str/regex 访问器函数时
让我们逐一检查这些情况。
小数据上的迭代与向量化
Pandas 在其 API 设计中遵循 "Convention Over Configuration" 方法。这意味着已经安装了相同的 API 来满足广泛的数据和用例。
当调用 pandas 函数时,函数必须在内部处理以下事情(其中包括),以确保正常工作
- 索引/轴对齐
- 处理混合数据类型
- 处理缺失数据
几乎每个函数都必须在不同程度上处理这些问题,这会带来开销。数字函数的开销较小(例如,Series.add),而字符串函数的开销更大(例如,Series.str.replace)。
另一方面,for 循环比您想象的要快。更好的是list comprehensions(通过for 循环创建列表)更快,因为它们是针对列表创建优化的迭代机制。
列表推导遵循模式
[f(x) for x in seq]
seq 是 pandas 系列或 DataFrame 列。或者,在对多列进行操作时,
[f(x, y) for x, y in zip(seq1, seq2)]
其中seq1 和seq2 是列。
数值比较
考虑一个简单的布尔索引操作。列表理解方法已针对Series.ne (!=) 和query 计时。以下是函数:
# Boolean indexing with Numeric value comparison.
df[df.A != df.B] # vectorized !=
df.query('A != B') # query (numexpr)
df[[x != y for x, y in zip(df.A, df.B)]] # list comp
为简单起见,我使用perfplot 包来运行本文中的所有 timeit 测试。上述操作的时间安排如下:
对于中等大小的 N,列表解析的性能优于 query,甚至对于小 N 的性能优于矢量化不等于比较。不幸的是,列表解析是线性扩展的,因此对于较大的 N,它不会提供太多的性能提升。
注意
值得一提的是,列表理解的大部分好处来自不必担心索引对齐,
但这意味着如果您的代码依赖于索引对齐,
这将打破。在某些情况下,向量化操作在
底层的 NumPy 数组可以被认为是引入了“最好的
两个世界”,允许向量化没有所有不需要的熊猫函数的开销。这意味着你可以将上面的操作重写为
df[df.A.values != df.B.values]
其性能优于 pandas 和列表理解等价物:
NumPy 矢量化超出了本文的范围,但如果性能很重要,它绝对值得考虑。
价值计算
再举一个例子 - 这一次,使用另一个比 for 循环 更快 的 vanilla python 构造 - collections.Counter。一个常见的要求是计算值计数并将结果作为字典返回。这是通过value_counts、np.unique 和Counter 完成的:
# Value Counts comparison.
ser.value_counts(sort=False).to_dict() # value_counts
dict(zip(*np.unique(ser, return_counts=True))) # np.unique
Counter(ser) # Counter
结果更加明显,Counter 在更大范围的小 N (~3500) 上胜过两种矢量化方法。
注意
更多琐事(礼貌@user2357112)。 Counter 是用 C
accelerator 实现的,
所以虽然它仍然必须使用 python 对象而不是
底层 C 数据类型,它仍然比for 循环快。 Python
力量!
当然,从这里得出的结论是性能取决于您的数据和用例。这些示例的重点是说服您不要将这些解决方案排除为合法选项。如果这些仍然不能为您提供所需的性能,那么总会有cython 和numba。让我们将此测试添加到组合中。
from numba import njit, prange
@njit(parallel=True)
def get_mask(x, y):
result = [False] * len(x)
for i in prange(len(x)):
result[i] = x[i] != y[i]
return np.array(result)
df[get_mask(df.A.values, df.B.values)] # numba
Numba 将循环 python 代码的 JIT 编译为非常强大的矢量化代码。了解如何让 numba 发挥作用涉及到学习曲线。
混合/object dtypes 的操作
基于字符串的比较
回顾第一节中的过滤示例,如果要比较的列是字符串怎么办?考虑上面相同的 3 个函数,但输入 DataFrame 转换为字符串。
# Boolean indexing with string value comparison.
df[df.A != df.B] # vectorized !=
df.query('A != B') # query (numexpr)
df[[x != y for x, y in zip(df.A, df.B)]] # list comp
那么,发生了什么变化?这里要注意的是,字符串操作本质上很难向量化。Pandas 将字符串视为对象,所有对对象的操作都会退回到缓慢、循环的实现。
现在,由于这种循环实现被上述所有开销所包围,因此这些解决方案之间存在恒定的量级差异,即使它们的规模相同。
对于可变/复杂对象的操作,没有可比性。列表理解优于所有涉及字典和列表的操作。
按键访问字典值
以下是从字典列中提取值的两个操作的时间安排:map 和列表推导。设置在附录中,标题为“代码片段”。
# Dictionary value extraction.
ser.map(operator.itemgetter('value')) # map
pd.Series([x.get('value') for x in ser]) # list comprehension
位置列表索引
从列列表中提取第 0 个元素(处理异常)、map、str.get accessor method 和列表理解的 3 次操作的计时:
# List positional indexing.
def get_0th(lst):
try:
return lst[0]
# Handle empty lists and NaNs gracefully.
except (IndexError, TypeError):
return np.nan
ser.map(get_0th) # map
ser.str[0] # str accessor
pd.Series([x[0] if len(x) > 0 else np.nan for x in ser]) # list comp
pd.Series([get_0th(x) for x in ser]) # list comp safe
注意
如果索引很重要,您会想要这样做:
pd.Series([...], index=ser.index)
在重构系列时。
列表扁平化
最后一个例子是扁平化列表。这是另一个常见的问题,在这里展示了纯 python 的强大。
# Nested list flattening.
pd.DataFrame(ser.tolist()).stack().reset_index(drop=True) # stack
pd.Series(list(chain.from_iterable(ser.tolist()))) # itertools.chain
pd.Series([y for x in ser for y in x]) # nested list comp
itertools.chain.from_iterable 和嵌套列表推导式都是纯 Python 构造,并且比stack 解决方案具有更好的扩展性。
这些时间强烈表明 pandas 不具备使用混合 dtype 的能力,您可能应该避免使用它。在可能的情况下,数据应以标量值(整数/浮点数/字符串)的形式出现在单独的列中。
最后,这些解决方案的适用性在很大程度上取决于您的数据。因此,最好的办法是在决定使用什么之前对您的数据进行这些操作测试。请注意我没有在这些解决方案上计时apply,因为它会使图表歪斜(是的,就是这么慢)。
正则表达式操作和.str 访问器方法
Pandas 可以应用正则表达式操作,例如 str.contains、str.extract 和 str.extractall,以及其他“矢量化”字符串操作(例如 str.split、str.find、str.translate 等) 在字符串列上。这些函数比列表推导式要慢,并且比其他任何函数都更方便。
预编译正则表达式模式并使用re.compile 迭代数据通常要快得多(另请参阅Is it worth using Python's re.compile?)。等同于str.contains 的列表组合看起来像这样:
p = re.compile(...)
ser2 = pd.Series([x for x in ser if p.search(x)])
或者,
ser2 = ser[[bool(p.search(x)) for x in ser]]
如果你需要处理 NaN,你可以这样做
ser[[bool(p.search(x)) if pd.notnull(x) else False for x in ser]]
相当于str.extract(不带组)的列表组合看起来像:
df['col2'] = [p.search(x).group(0) for x in df['col']]
如果您需要处理不匹配和 NaN,您可以使用自定义函数(更快!):
def matcher(x):
m = p.search(str(x))
if m:
return m.group(0)
return np.nan
df['col2'] = [matcher(x) for x in df['col']]
matcher 函数具有很强的可扩展性。它可以根据需要为每个捕获组返回一个列表。只需提取查询匹配器对象的group 或groups 属性即可。
对于str.extractall,将p.search 更改为p.findall。
字符串提取
考虑一个简单的过滤操作。想法是如果前面有一个大写字母,则提取 4 位数字。
# Extracting strings.
p = re.compile(r'(?<=[A-Z])(\d{4})')
def matcher(x):
m = p.search(x)
if m:
return m.group(0)
return np.nan
ser.str.extract(r'(?<=[A-Z])(\d{4})', expand=False) # str.extract
pd.Series([matcher(x) for x in ser]) # list comprehension
更多示例
完全披露 - 我是下面列出的这些帖子的作者(部分或全部)。
结论
如上例所示,迭代在处理小行数据帧、混合数据类型和正则表达式时大放异彩。
您获得的加速取决于您的数据和您的问题,因此您的里程可能会有所不同。最好的办法是仔细运行测试,看看付出的努力是否值得。
“矢量化”函数以其简单性和可读性而著称,因此,如果性能不重要,您绝对应该更喜欢这些。
另一方面,某些字符串操作处理有利于使用 NumPy 的约束。以下是 NumPy 向量化优于 python 的两个示例:
此外,有时仅通过.values 对底层数组进行操作,而不是在 Series 或 DataFrames 上操作,可以为大多数常见场景提供足够健康的加速(请参阅 中的 Note上面的数字比较部分)。因此,例如df[df.A.values != df.B.values] 将显示即时性能提升超过df[df.A != df.B]。使用.values 可能并不适用于所有情况,但它是一个有用的技巧。
如上所述,由您决定这些解决方案是否值得实施。
附录:代码片段
import perfplot
import operator
import pandas as pd
import numpy as np
import re
from collections import Counter
from itertools import chain
# Boolean indexing with Numeric value comparison.
perfplot.show(
setup=lambda n: pd.DataFrame(np.random.choice(1000, (n, 2)), columns=['A','B']),
kernels=[
lambda df: df[df.A != df.B],
lambda df: df.query('A != B'),
lambda df: df[[x != y for x, y in zip(df.A, df.B)]],
lambda df: df[get_mask(df.A.values, df.B.values)]
],
labels=['vectorized !=', 'query (numexpr)', 'list comp', 'numba'],
n_range=[2**k for k in range(0, 15)],
xlabel='N'
)
# Value Counts comparison.
perfplot.show(
setup=lambda n: pd.Series(np.random.choice(1000, n)),
kernels=[
lambda ser: ser.value_counts(sort=False).to_dict(),
lambda ser: dict(zip(*np.unique(ser, return_counts=True))),
lambda ser: Counter(ser),
],
labels=['value_counts', 'np.unique', 'Counter'],
n_range=[2**k for k in range(0, 15)],
xlabel='N',
equality_check=lambda x, y: dict(x) == dict(y)
)
# Boolean indexing with string value comparison.
perfplot.show(
setup=lambda n: pd.DataFrame(np.random.choice(1000, (n, 2)), columns=['A','B'], dtype=str),
kernels=[
lambda df: df[df.A != df.B],
lambda df: df.query('A != B'),
lambda df: df[[x != y for x, y in zip(df.A, df.B)]],
],
labels=['vectorized !=', 'query (numexpr)', 'list comp'],
n_range=[2**k for k in range(0, 15)],
xlabel='N',
equality_check=None
)
# Dictionary value extraction.
ser1 = pd.Series([{'key': 'abc', 'value': 123}, {'key': 'xyz', 'value': 456}])
perfplot.show(
setup=lambda n: pd.concat([ser1] * n, ignore_index=True),
kernels=[
lambda ser: ser.map(operator.itemgetter('value')),
lambda ser: pd.Series([x.get('value') for x in ser]),
],
labels=['map', 'list comprehension'],
n_range=[2**k for k in range(0, 15)],
xlabel='N',
equality_check=None
)
# List positional indexing.
ser2 = pd.Series([['a', 'b', 'c'], [1, 2], []])
perfplot.show(
setup=lambda n: pd.concat([ser2] * n, ignore_index=True),
kernels=[
lambda ser: ser.map(get_0th),
lambda ser: ser.str[0],
lambda ser: pd.Series([x[0] if len(x) > 0 else np.nan for x in ser]),
lambda ser: pd.Series([get_0th(x) for x in ser]),
],
labels=['map', 'str accessor', 'list comprehension', 'list comp safe'],
n_range=[2**k for k in range(0, 15)],
xlabel='N',
equality_check=None
)
# Nested list flattening.
ser3 = pd.Series([['a', 'b', 'c'], ['d', 'e'], ['f', 'g']])
perfplot.show(
setup=lambda n: pd.concat([ser2] * n, ignore_index=True),
kernels=[
lambda ser: pd.DataFrame(ser.tolist()).stack().reset_index(drop=True),
lambda ser: pd.Series(list(chain.from_iterable(ser.tolist()))),
lambda ser: pd.Series([y for x in ser for y in x]),
],
labels=['stack', 'itertools.chain', 'nested list comp'],
n_range=[2**k for k in range(0, 15)],
xlabel='N',
equality_check=None
)
# Extracting strings.
ser4 = pd.Series(['foo xyz', 'test A1234', 'D3345 xtz'])
perfplot.show(
setup=lambda n: pd.concat([ser4] * n, ignore_index=True),
kernels=[
lambda ser: ser.str.extract(r'(?<=[A-Z])(\d{4})', expand=False),
lambda ser: pd.Series([matcher(x) for x in ser])
],
labels=['str.extract', 'list comprehension'],
n_range=[2**k for k in range(0, 15)],
xlabel='N',
equality_check=None
)