博登,
我很好奇下面这个函数的性能与您的方法相比如何。
def func2(df):
    """Fill ``df['bS']`` row by row using ``zip`` over the input columns.

    For each row the value is ``sA + sB`` when ``pA == pB``, ``sA`` when
    ``pA > pB``, and ``sB`` otherwise.

    Args:
        df: DataFrame with columns ``pA``, ``pB``, ``sA`` and ``sB``.

    Returns:
        The same ``df`` object, with the ``bS`` column added or overwritten.
    """
    list2 = []
    # Read from the ``df`` argument rather than a module-level frame so the
    # function works on whatever frame the caller passes in.
    for p_a, p_b, s_a, s_b in zip(df['pA'], df['pB'], df['sA'], df['sB']):
        if p_a == p_b:
            list2.append(s_a + s_b)
        elif p_a > p_b:
            list2.append(s_a)
        else:
            # Catch-all branch: covers pB > pA and incomparable values such
            # as NaN, so every row contributes exactly one entry and the
            # list length always matches the frame.
            list2.append(s_b)
    df['bS'] = list2
    return df
以下是我在自己系统上运行的代码及相应的结果。我惯用的做法(go-to)是一个使用 iterrows() 的 for 循环。经过测试,发现它比您的 np.where() 慢之后,我又尝试了 zip(),其性能似乎略快一些。
import numpy as np
import pandas as pd
import timeit
# Benchmark frame: N rows of random integers, with pA/pB drawn from [0, 5)
# and sA/sB drawn from [0, 100). Columns are generated in the order listed.
N = 1000 * 1000
t = pd.DataFrame({
    column: np.random.randint(0, upper, size=N)
    for column, upper in (('pA', 5), ('pB', 5), ('sA', 100), ('sB', 100))
})

# Vectorised result column: sA + sB where pA equals pB, sA where pA is
# larger, and sB in every remaining row.
t['bS'] = np.where(
    t['pA'].eq(t['pB']),
    t['sA'].add(t['sB']),
    np.where(t['pA'].gt(t['pB']), t['sA'], t['sB']),
)
def func1(df):
    """Fill ``df['bS']`` row by row using ``DataFrame.iterrows``.

    For each row the value is ``sA + sB`` when ``pA == pB``, ``sA`` when
    ``pA > pB``, and ``sB`` otherwise.

    Args:
        df: DataFrame with columns ``pA``, ``pB``, ``sA`` and ``sB``.

    Returns:
        The same ``df`` object, with the ``bS`` column added or overwritten.
    """
    list1 = []
    for _, row in df.iterrows():
        if row['pA'] == row['pB']:
            list1.append(row['sA'] + row['sB'])
        elif row['pA'] > row['pB']:
            list1.append(row['sA'])
        else:
            # Catch-all branch: covers pB > pA and incomparable values such
            # as NaN, so every row contributes exactly one entry and the
            # list length always matches the frame.
            list1.append(row['sB'])
    df['bS'] = list1
    return df
def func2(df):
    """Fill ``df['bS']`` row by row using ``zip`` over the input columns.

    For each row the value is ``sA + sB`` when ``pA == pB``, ``sA`` when
    ``pA > pB``, and ``sB`` otherwise.

    Args:
        df: DataFrame with columns ``pA``, ``pB``, ``sA`` and ``sB``.

    Returns:
        The same ``df`` object, with the ``bS`` column added or overwritten.
    """
    list2 = []
    # Read from the ``df`` argument rather than a module-level frame so the
    # function works on whatever frame the caller passes in.
    for p_a, p_b, s_a, s_b in zip(df['pA'], df['pB'], df['sA'], df['sB']):
        if p_a == p_b:
            list2.append(s_a + s_b)
        elif p_a > p_b:
            list2.append(s_a)
        else:
            # Catch-all branch: covers pB > pA and incomparable values such
            # as NaN, so every row contributes exactly one entry and the
            # list length always matches the frame.
            list2.append(s_b)
    df['bS'] = list2
    return df
# Source text handed to timeit as its ``setup`` argument: it builds a small
# (N = 10) random frame ``t`` and defines func1/func2 for the timed
# statements. The embedded code must be validly indented Python, because
# timeit compiles it, and func2 reads from its ``df`` argument rather than
# the global ``t``.
setup = '''
import numpy as np
import pandas as pd
import timeit
N = 10
t = pd.DataFrame({'pA' : np.random.randint(0,5,size = N),
                  'pB' : np.random.randint(0,5,size = N),
                  'sA' : np.random.randint(0,100,size = N),
                  'sB' : np.random.randint(0,100,size = N)})
t['bS'] = np.where(t['pA'] == t['pB'],
                   t['sA'] + t['sB'],
                   np.where(t['pA'] > t['pB'],
                            t['sA'], t['sB']))
def func1(df):
    list1 = []
    for index, row in df.iterrows():
        if row['pA'] == row['pB']:
            list1.append(row['sA'] + row['sB'])
        elif row['pA'] > row['pB']:
            list1.append(row['sA'])
        else:
            list1.append(row['sB'])
    df['bS'] = list1
    return df
def func2(df):
    list2 = []
    for r in zip(df['pA'],df['pB'],df['sA'],df['sB']):
        if r[0] == r[1]:
            list2.append(r[2] + r[3])
        elif r[0] > r[1]:
            list2.append(r[2])
        else:
            list2.append(r[3])
    df['bS'] = list2
    return df
'''
# Benchmark runs. Each call executes its statement 1000 times after running
# the `setup` source once. The `Out[n]` lines appear to be the values
# recorded from an interactive session (timeit reports total seconds).

# 1) Vectorised nested np.where expression.
timeit.timeit("t['bS'] = np.where(t['pA'] == t['pB'], t['sA'] + t['sB'],np.where(t['pA'] > t['pB'], t['sA'], t['sB']))", setup = setup, number = 1000)
Out[0]: 0.6907481750604347
# 2) func1: the row loop built on DataFrame.iterrows().
timeit.timeit("func1(t)", setup = setup, number = 1000)
Out[1]: 1.7969895842306869
# 3) func2: the row loop built on zip() over the columns.
timeit.timeit("func2(t)", setup = setup, number = 1000)
Out[2]: 0.40988909450607025