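You are implicitly creating several nested loops in this statement:

```python
agg = np.array([[x, np.sum(buy[buy[:, 0]>=x][:, 1]), np.sum(sell[sell[:, 0]<=x][:, 1])] for x in price_grid])
```

Most of the work is vectorized, but this will still bog down when the price grid or the number of orders gets large. For every price `x` in the grid, the boolean masks scan the entire `buy` and `sell` arrays, so the cost grows as (grid size) × (number of orders).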
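To make the hidden cost concrete, here is a sketch of what the one-liner computes, written with explicit Python loops. `aggregate_nested` is a hypothetical name used only for illustration, and the toy `buy`/`sell` arrays are made up. The original runs the inner scans as NumPy boolean masks, which execute in C, but each grid price still touches every order.

```python
import numpy as np

def aggregate_nested(buy, sell, price_grid):
    """Explicit-loop version of the one-liner (hypothetical helper).

    For every grid price x it scans *all* buy orders and *all* sell orders,
    so the work grows as len(price_grid) * (len(buy) + len(sell)).
    """
    rows = []
    for x in price_grid:              # outer loop over the grid prices
        demand = 0
        for px, qty in buy:           # inner scan over every buy order
            if px >= x:               # bids at or above x count as demand
                demand += qty
        supply = 0
        for px, qty in sell:          # inner scan over every sell order
            if px <= x:               # offers at or below x count as supply
                supply += qty
        rows.append([x, demand, supply])
    return np.array(rows)

# Toy (price, qty) orders, made up for the comparison
buy = np.array([[3, 5], [5, 2], [2, 4]])
sell = np.array([[1, 3], [4, 6]])
price_grid = np.arange(1, 6)

# The original one-liner, for reference
agg = np.array([[x, np.sum(buy[buy[:, 0]>=x][:, 1]), np.sum(sell[sell[:, 0]<=x][:, 1])] for x in price_grid])

same = np.array_equal(aggregate_nested(buy, sell, price_grid), agg)
print(same)  # True
```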
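You can do this in linear time by consolidating (binning) the orders first; I use dictionaries below. A second linear pass, walking down the price levels and looking up the binned `buys`, then builds the aggregate demand you need. The total work is linear in the number of orders plus the number of price levels, i.e. O(n) rather than O(n^2).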
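A minimal sketch of the two linear passes on a toy book. The prices and quantities are made up, and `MAX_PX` is shrunk to 5 for readability. The `int()` casts only keep the dictionary contents as plain Python ints so they print cleanly; the full code below uses the NumPy values directly, which hash and compare equal to the corresponding Python ints, so the lookups behave the same.

```python
import numpy as np

# Toy order book: each row is (price, quantity); values are made up.
buy = np.array([[3, 5], [5, 2], [3, 1], [2, 4]])
MIN_PX, MAX_PX = 1, 5

# Pass 1: bin buy quantity by price (one pass over the orders).
buys = {}
for px, qty in buy:
    buys[int(px)] = buys.get(int(px), 0) + int(qty)

# Pass 2: walk the price levels from the top down with a running total,
# so agg_demand[p] = total quantity bid at prices >= p (one pass over prices).
agg_demand = {MAX_PX: buys.get(MAX_PX, 0)}
for px in range(MAX_PX - 1, MIN_PX - 1, -1):
    agg_demand[px] = agg_demand[px + 1] + buys.get(px, 0)

print(buys)        # {3: 6, 5: 2, 2: 4}
print(agg_demand)  # {5: 2, 4: 2, 3: 8, 2: 12, 1: 12}
```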
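For larger n this makes a huge difference. I timed the original (orig) and the modified version (mod, below): with 5,000 orders and a 10K price grid (your values), mod is roughly 50 to 100 times faster (1.18 s vs. 0.021 s, about 57×, in the run shown below), and it does not use any NumPy operations.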
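The speedup comes from the algorithm rather than from NumPy. If you would rather stay in NumPy, the same bin-then-accumulate idea can be expressed with `np.bincount` and `np.cumsum`. This is a sketch of an alternative that I have not timed here; `aggregate_cumsum` is a hypothetical helper, and it assumes integer prices inside `[min_px, max_px]` with `TICK_SIZE == 1`.

```python
import numpy as np

def aggregate_cumsum(buy, sell, min_px, max_px):
    """Hypothetical vectorized variant of the bin-then-accumulate idea.

    Assumes integer prices in [min_px, max_px] and a tick size of 1.
    Returns (prices, demand, supply) with demand[i] = bids at >= prices[i]
    and supply[i] = offers at <= prices[i].
    """
    n = max_px - min_px + 1
    # Bin quantity per price level (bincount with weights returns floats)
    buy_q = np.bincount(buy[:, 0] - min_px, weights=buy[:, 1], minlength=n)
    sell_q = np.bincount(sell[:, 0] - min_px, weights=sell[:, 1], minlength=n)
    demand = np.cumsum(buy_q[::-1])[::-1].astype(np.int64)  # suffix sums
    supply = np.cumsum(sell_q).astype(np.int64)             # prefix sums
    prices = np.arange(min_px, max_px + 1)
    return prices, demand, supply

# Random toy data in the same (price, qty) layout as gen_orders
rng = np.random.default_rng(0)
grid = np.arange(1, 200)
buy = np.column_stack((rng.choice(grid, 300), rng.integers(0, 10, 300)))
sell = np.column_stack((rng.choice(grid, 300), rng.integers(0, 10, 300)))

prices, demand, supply = aggregate_cumsum(buy, sell, 1, 199)
agg = np.array([[x, np.sum(buy[buy[:, 0]>=x][:, 1]), np.sum(sell[sell[:, 0]<=x][:, 1])] for x in grid])
same = np.array_equal(np.column_stack((prices, demand, supply)), agg)
print(same)  # True

# Clearing price with the original's rule: first price where demand < supply
crossing = np.flatnonzero(demand < supply)
if crossing.size:
    i = crossing[0]
    print(prices[i], min(demand[i], supply[i]))
```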
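Note: there is a subtle difference in the logic. mod stops as soon as supply == demand at any price step, whereas the original keeps going until demand is strictly less than supply. (Not sure whether that's a bug or a feature... :) but the logic is easy to adjust.) The output below captures one of these (somewhat rare) cases, together with the timing difference.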
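Here is a sketch of the tie-handling difference, written as a bounded walk over the price levels. `clear_price` and its `stop_on_tie` flag are hypothetical names. `stop_on_tie=True` behaves like the mod loop's `while demand > supply`, and `stop_on_tie=False` picks the first `demand < supply` price as the original does, which corresponds to changing the mod loop's condition to `demand >= supply`. Iterating over a fixed `range` also keeps the walk from running past the top of the grid. The toy book has a tie at price 2.

```python
def clear_price(agg_demand, sells, min_px, max_px, stop_on_tie=True):
    """Hypothetical helper: walk up the price levels with a running supply.

    stop_on_tie=True  -> stop at the first demand <= supply (the mod loop)
    stop_on_tie=False -> stop at the first demand <  supply (the original)
    Returns (price, qty), or None if the book never crosses in the range.
    """
    supply = 0
    for px in range(min_px, max_px + 1):   # bounded: never runs past max_px
        supply += sells.get(px, 0)          # offers at prices <= px
        demand = agg_demand.get(px, 0)      # pre-computed bids at prices >= px
        done = demand <= supply if stop_on_tie else demand < supply
        if done:
            return px, min(demand, supply)
    return None

# Toy book (made-up numbers) with a tie at price 2: demand 8, supply 8
agg_demand = {1: 10, 2: 8, 3: 6, 4: 3}
sells = {1: 2, 2: 6}

print(clear_price(agg_demand, sells, 1, 4))                     # (2, 8)
print(clear_price(agg_demand, sells, 1, 4, stop_on_tie=False))  # (3, 6)
```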
Edited code with timing:
```python
import numpy as np
import time

MAX_QTY = 10
MIN_QTY = 0
MIN_PX = 1
MAX_PX = 10_000
TICK_SIZE = 1
price_grid = np.arange(MIN_PX, MAX_PX, TICK_SIZE)

def gen_orders(num, price_grid):
    qty = np.random.randint(MIN_QTY, MAX_QTY, num)
    px = np.random.choice(price_grid, num)
    return np.array((px, qty)).T

buy = gen_orders(5000, price_grid)
sell = gen_orders(5000, price_grid)

tic = time.time()
agg = np.array([[x, np.sum(buy[buy[:, 0]>=x][:, 1]), np.sum(sell[sell[:, 0]<=x][:, 1])] for x in price_grid])
matched = agg[agg[:, 1]<agg[:, 2]][0, :]  # price_grid is sorted
cleared_px = matched[0]
cleared_qty = np.min(matched[1:])
toc = time.time()
print(f'ORIG: computed clear px: {cleared_px} and qty: {cleared_qty} in {toc-tic:0.6f} sec')

### ALTERNATE ###
# Start the clock again for the mod method...
tic = time.time()
buys = {}
sells = {}

# "bin" the buys by price
for b in buy:
    buys[b[0]] = buys.get(b[0], 0) + b[1]

# need to aggregate the demand...
agg_demand = {MAX_PX: buys.get(MAX_PX, 0)}  # starting point
for px in range(MAX_PX-1, MIN_PX-1, -1):  # backfill down to min px
    agg_demand[px] = agg_demand[px+1] + buys.get(px, 0)

# "bin" the sells similarly
for s in sell:
    sells[s[0]] = sells.get(s[0], 0) + s[1]

# set up the loop
selling_px = MIN_PX
supply = sells.get(selling_px, 0)
demand = agg_demand.get(selling_px, 0)
while demand > supply:
    # updates
    selling_px += 1
    demand = agg_demand.get(selling_px, 0)  # update with the pre-computed aggregate demand
    supply += sells.get(selling_px, 0)  # keep running aggregation of supply
new_cleared_px = selling_px
new_cleared_qty = min(demand, supply)
toc = time.time()
print(f'MOD: computed clear px: {new_cleared_px} and qty: {new_cleared_qty} in {toc-tic:0.6f} sec')

if cleared_px != new_cleared_px or cleared_qty != new_cleared_qty:  # something wrong...??
    # agg row i holds price MIN_PX + i, so this shows prices cleared_px-4 .. cleared_px+5
    print(agg[cleared_px-5:cleared_px+5, :])
```
Output:

```
ORIG: computed clear px: 4902 and qty: 11390 in 1.183204 sec
MOD: computed clear px: 4899 and qty: 11398 in 0.020830 sec
[[ 4898 11411 11398]
 [ 4899 11398 11398]
 [ 4900 11398 11398]
 [ 4901 11398 11398]
 [ 4902 11390 11398]
 [ 4903 11385 11398]
 [ 4904 11385 11398]
 [ 4905 11385 11398]
 [ 4906 11385 11398]
 [ 4907 11384 11398]]
[Finished in 1.3s]
```
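To check that the mismatch above is exactly the tie case, here is a small replay of both stopping rules over the ten printed rows. It relies on the printed results to know that neither rule stopped below 4898, since the window starts there. The first row satisfying each rule gives the reported price, and `min(demand, supply)` at that row gives the reported quantity.

```python
import numpy as np

# The ten rows printed above: (price, aggregate demand, aggregate supply).
rows = np.array([
    [4898, 11411, 11398],
    [4899, 11398, 11398],
    [4900, 11398, 11398],
    [4901, 11398, 11398],
    [4902, 11390, 11398],
    [4903, 11385, 11398],
    [4904, 11385, 11398],
    [4905, 11385, 11398],
    [4906, 11385, 11398],
    [4907, 11384, 11398],
])
px, demand, supply = rows.T

orig_i = np.flatnonzero(demand < supply)[0]   # original rule: first strict crossing
mod_i = np.flatnonzero(demand <= supply)[0]   # mod rule: stop at the first tie or crossing

print(px[orig_i], min(demand[orig_i], supply[orig_i]))  # 4902 11390
print(px[mod_i], min(demand[mod_i], supply[mod_i]))     # 4899 11398
```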