【问题标题】:Speed up array query in Numpy/Python加速 Numpy/Python 中的数组查询
【发布时间】:2012-03-08 06:52:49
【问题描述】:

我有一个点数组(称为 points),由约 30000 个 x、y 和 z 值组成。我还有一个单独的点数组(称为 vertices),大约 40000 个 x、y 和 z 值。后一个数组索引了一些边长为 size 的立方体的左下角。 我想找出哪些点位于哪些立方体中,以及每个立方体中有多少个点。我写了一个循环来做到这一点,它的工作原理是这样的:

for i in xrange(len(vertices)):        
    cube=((vertices[i,0]<= points[:,0]) & 
    (points[:,0]<(vertices[i,0]+size)) & 
    (vertices[i,1]<= points[:,1]) & 
    (points[:,1] < (vertices[i,1]+size)) &
    (vertices[i,2]<= points[:,2]) & 
    (points[:,2] < (vertices[i,2]+size))
    )
    numpoints[i]=len(points[cube])

(循环对单个立方体进行排序,“立方体”创建一个布尔索引数组。)然后我将 points[cube] 存储在某处,但这并不是让我慢下来的原因;这是“cube=”的创建。

我想加快这个循环(在 macbook pro 上完成需要几十秒)。我尝试重写C中的“cube=”部分,如下:

for i in xrange(len(vertices)):
    cube=zeros(pp, dtype=bool)
    code="""
            for (int j=0; j<pp; ++j){

                if (vertices(i,0)<= points(j,0))
                 if (points(j,0) < (vertices(i,0)+size))
                  if (vertices(i,1)<= points(j,1))
                   if (points(j,1) < (vertices(i,1)+size))
                    if (vertices(i,2)<= points(j,2))
                     if (points(j,2) < (vertices(i,2)+size))
                      cube(j)=1;
            }
        return_val = 1;"""

    weave.inline(code,
    ['vertices', 'points','size','pp','cube', 'i']) 
    numpoints[i]=len(points[cube])

这将其加快了两倍多一点。在 C 中重写 both 循环实际上只比原始的 numpy-only 版本快一点,因为频繁引用数组对象是跟踪哪些点在哪些立方体中所必需的。我怀疑有可能更快地做到这一点,而且我错过了一些东西。谁能建议如何加快速度?我是 numpy/python 新手,在此先感谢。

【问题讨论】:

    标签: python arrays performance numpy vectorization


    【解决方案1】:

    您可以使用 scipy.spatial.cKDTree 来加速这种计算。

    代码如下:

    import time
    import numpy as np
    
    #### create some sample data ####
    np.random.seed(1)
    
    V_NUM = 6000
    P_NUM = 8000
    
    size = 0.1
    
    vertices = np.random.rand(V_NUM, 3)
    points = np.random.rand(P_NUM, 3)
    
    numpoints = np.zeros(V_NUM, np.int32)
    
    #### brute force ####
    start = time.clock()
    for i in xrange(len(vertices)):        
        cube=((vertices[i,0]<= points[:,0]) & 
        (points[:,0]<(vertices[i,0]+size)) & 
        (vertices[i,1]<= points[:,1]) & 
        (points[:,1] < (vertices[i,1]+size)) &
        (vertices[i,2]<= points[:,2]) & 
        (points[:,2] < (vertices[i,2]+size))
        )
        numpoints[i]=len(points[cube])
    
    print time.clock() - start
    
    #### KDTree ####
    from scipy.spatial import cKDTree
    center_vertices = vertices + [[size/2, size/2, size/2]]
    start = time.clock()
    tree_points = cKDTree(points)
    _, result = tree_points.query(center_vertices, k=100, p = np.inf, distance_upper_bound=size/2)
    numpoints2 = np.zeros(V_NUM, np.int32)
    for i, neighbors in enumerate(result):
        numpoints2[i] = np.sum(neighbors!=P_NUM)
    
    print time.clock() - start
    print np.all(numpoints == numpoints2)
    
    • 首先将立体角位置更改为中心位置。

    center_vertices = vertices + [[size/2, size/2, size/2]]

    • 从点创建 cKDTree

    tree_points = cKDTree(points)

    • 做查询,k是要返回的最近邻的个数,p=np.inf表示最大坐标差距离,distance_upper_bound是最大距离。

    _, result = tree_points.query(center_vertices, k=100, p = np.inf, distance_upper_bound=size/2)

    输出是:

    2.04113164434
    0.11087783696
    True
    

    如果一个立方体中有超过 100 个点,您可以在 for 循环中通过 neighbors[-1] == P_NUM 进行检查,并对这些顶点执行 k=1000 查询。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-04-21
      • 2011-10-03
      • 1970-01-01
      • 2017-08-21
      • 1970-01-01
      • 1970-01-01
      • 2017-12-23
      • 2016-04-28
      相关资源
      最近更新 更多