根据我的经验,有两种最快的方法可以在 3D 中找到邻居列表。一种是使用用 C++ 或 Cython 编写的最简单的双循环代码(在我的例子中,两者都是)。它在 N^2 中运行,但对于小型系统来说非常快。另一种方法是使用线性时间算法。 Scipy ckdtree 是一个不错的选择,但有局限性。来自分子动力学软件的邻居列表查找器功能最强大,但很难包装,而且初始化时间可能很慢。
下面我比较四种方法:
- 朴素的 cython 代码
-
OpenMM 周围的包装器(很难安装,见下文)
-
Scipy.spatial.ckdtree
scipy.spatial.distance.pdist
测试设置:n 点分散在一个矩形框中,体积密度为 0.2。系统大小从 10 到 1000000(一百万)个粒子不等。联系半径取自0.5, 1, 2, 4, 7, 10。请注意,因为密度为 0.2,所以在接触半径为 0.5 时,每个粒子平均有大约 0.1 个接触,在 1 = 0.8,在 2 = 6.4,在 10 - 大约 800!对于小型系统,接触发现重复了几次,对于 >30k 粒子的系统进行了一次。如果每次调用的时间超过 5 秒,则运行中止。
设置:双至强 2687Wv3、128GB RAM、Ubuntu 14.04、python 2.7.11、scipy 0.16.0、numpy 1.10.1。没有任何代码使用并行优化(OpenMM 除外,尽管并行部分运行得如此之快,以至于在 CPU 图上甚至都没有注意到,但大部分时间都花在了从 OpenMM 传输数据上)。
结果:请注意,下面的图是对数尺度的,分布在 6 个数量级上。即使是很小的视觉差异实际上也可能是 10 倍。
对于少于 1000 个粒子的系统,Cython 代码总是更快。但是,1000 个粒子后的结果取决于接触半径。 pdist 实现总是比 cython 慢,并且占用更多内存,因为它显式地创建了一个距离矩阵,由于 sqrt 而速度很慢。
- 在小接触半径(每个粒子ckdtree 是所有系统尺寸的理想选择。
- 在中等接触半径下,(每个粒子 5-50 个接触)naive cython 实现是最好的,最多 10000 个粒子,然后 OpenMM 开始以大约几个数量级获胜,但
ckdtree 的性能仅差 3-10 倍
- 在高接触半径(每个粒子>200 个接触)的情况下,幼稚的方法最多可以处理 100k 或 1M 个粒子,那么 OpenMM 可能会胜出。
安装 OpenMM 非常棘手;您可以在http://bitbucket.org/mirnylab/openmm-polymer 文件“contactmaps.py”或自述文件中阅读更多内容。然而,下面的结果表明,对于 N>100k 个粒子,每个粒子只有 5-50 个接触是有利的。
Cython 代码如下:
import numpy as np
cimport numpy as np
cimport cython
cdef extern from "<vector>" namespace "std":
cdef cppclass vector[T]:
cppclass iterator:
T operator*()
iterator operator++()
bint operator==(iterator)
bint operator!=(iterator)
vector()
void push_back(T&)
T& operator[](int)
T& at(int)
iterator begin()
iterator end()
np.import_array() # initialize C API to call PyArray_SimpleNewFromData
cdef public api tonumpyarray(int* data, long long size) with gil:
if not (data and size >= 0): raise ValueError
cdef np.npy_intp dims = size
#NOTE: it doesn't take ownership of `data`. You must free `data` yourself
return np.PyArray_SimpleNewFromData(1, &dims, np.NPY_INT, <void*>data)
@cython.boundscheck(False)
@cython.wraparound(False)
def contactsCython(inArray, cutoff):
inArray = np.asarray(inArray, dtype = np.float64, order = "C")
cdef int N = len(inArray)
cdef np.ndarray[np.double_t, ndim = 2] data = inArray
cdef int j,i
cdef double curdist
cdef double cutoff2 = cutoff * cutoff # IMPORTANT to avoid slow sqrt calculation
cdef vector[int] contacts1
cdef vector[int] contacts2
for i in range(N):
for j in range(i+1, N):
curdist = (data[i,0] - data[j,0]) **2 +(data[i,1] - data[j,1]) **2 + (data[i,2] - data[j,2]) **2
if curdist < cutoff2:
contacts1.push_back(i)
contacts2.push_back(j)
cdef int M = len(contacts1)
cdef np.ndarray[np.int32_t, ndim = 2] contacts = np.zeros((M,2), dtype = np.int32)
for i in range(M):
contacts[i,0] = contacts1[i]
contacts[i,1] = contacts2[i]
return contacts
Cython 代码的编译(或 makefile):
cython --cplus fastContacts.pyx
g++ -g -march=native -Ofast -fpic -c fastContacts.cpp -o fastContacts.o `python-config --includes`
g++ -g -march=native -Ofast -shared -o fastContacts.so fastContacts.o `python-config --libs`
测试代码:
from __future__ import print_function, division
import signal
import time
from contextlib import contextmanager
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import ckdtree
from scipy.spatial.distance import pdist
from contactmaps import giveContactsOpenMM # remove this unless you have OpenMM and openmm-polymer libraries installed
from fastContacts import contactsCython
class TimeoutException(Exception): pass
@contextmanager
def time_limit(seconds):
def signal_handler(signum, frame):
raise TimeoutException("Timed out!")
signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
matplotlib.rcParams.update({'font.size': 8})
def close_pairs_ckdtree(X, max_d):
tree = ckdtree.cKDTree(X)
pairs = tree.query_pairs(max_d)
return np.array(list(pairs))
def condensed_to_pair_indices(n, k):
x = n - (4. * n ** 2 - 4 * n - 8 * k + 1) ** .5 / 2 - .5
i = x.astype(int)
j = k + i * (i + 3 - 2 * n) / 2 + 1
return np.array([i, j]).T
def close_pairs_pdist(X, max_d):
d = pdist(X)
k = (d < max_d).nonzero()[0]
return condensed_to_pair_indices(X.shape[0], k)
a = np.random.random((100, 3)) * 3 # test set
methods = {"cython": contactsCython, "ckdtree": close_pairs_ckdtree, "OpenMM": giveContactsOpenMM,
"pdist": close_pairs_pdist}
# checking that each method gives the same value
allUniqueInds = []
for ind, method in methods.items():
contacts = method(a, 1)
uniqueInds = contacts[:, 0] + 100 * contacts[:, 1] # unique index of each contacts
allUniqueInds.append(np.sort(uniqueInds)) # adding sorted unique conatcts
for j in allUniqueInds:
assert np.allclose(j, allUniqueInds[0])
# now actually doing testing
repeats = [30,30,30, 30, 30, 20, 20, 10, 5, 3, 2 , 1, 1, 1]
sizes = [10,30,100, 200, 300, 500, 1000, 2000, 3000, 10000, 30000, 100000, 300000, 1000000]
systems = [[np.random.random((n, 3)) * ((n / 0.2) ** 0.333333) for k in range(repeat)] for n, repeat in
zip(sizes, repeats)]
for j, radius in enumerate([0.5, 1, 2, 4, 7, 10]):
plt.subplot(2, 3, j + 1)
plt.title("Radius = {0}; {1:.2f} cont per particle".format(radius, 0.2 * (4 / 3 * np.pi * radius ** 3)))
times = {i: [] for i in methods}
for name, method in methods.items():
for n, system, repeat in zip(sizes, systems, repeats):
if name == "pdist" and n > 30000:
break # memory issues
st = time.time()
try:
with time_limit(5 * repeat):
for ind in range(repeat):
k = len(method(system[ind], radius))
except:
print("Run aborted")
break
end = time.time()
mytime = (end - st) / repeat
times[name].append((n, mytime))
print("{0} radius={1} n={2} time={3} repeat={4} contPerParticle={5}".format(name, radius, n, mytime,repeat, 2 * k / n))
for name in sorted(times.keys()):
plt.plot(*zip(*times[name]), label=name)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("System size")
plt.ylabel("Time (seconds)")
plt.legend(loc=0)
plt.show()