【问题标题】:How to optimize Mean Square Displacement for several particles in two dimensions in python?如何在python中优化二维中几个粒子的均方位移?
【发布时间】:2021-12-12 17:28:45
【问题描述】:

我想计算几个粒子的均方位移,定义为:

其中i是粒子的索引,Dt是时间间隔,t是时间,vec(x)是粒子在二维中的位置。我们对所有可能的时间进行平均t

我已经设法用 numpy.请注意,pos 是具有三个轴的 np.array(particles, time, coordinate)

import numpy as np
import matplotlib.pyplot as plt
import time

#Initialize data
np.random.seed(1)
nTime = 10**4
nParticles = 3
pos = np.zeros((nParticles, nTime, 2)) #Axis: particles, times, coordinates
for t in range(1, nTime):
    pos[:, t, :] = pos[:, t-1, :] + ( np.random.random((nParticles, 2)) - 0.5)

#MSD calculation
def MSD_direct(pos):
    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = np.empty((nParticles, len(Dt_r)))
    dMSD = np.empty((nParticles,len(Dt_r)))
    for k, Dt in enumerate(Dt_r):
        SD = np.sum((pos[:, Dt:,:] - pos[:, 0:-Dt,:])**2, axis = -1)
        MSD[:,k] = np.mean( SD , axis = 1)
        dMSD[:,k] = np.std( SD, axis = 1 ) / np.sqrt(SD.shape[1])

    return Dt_r, MSD, dMSD

start_time = time.time()
Dt_r, MSD_d, dMSD_d = MSD_direct(pos)
print("MSD_direct -- Time: %s s\n" % (time.time() - start_time))

#Plots
plt.figure()
for i in range(nParticles):
    plt.plot(pos[i,:,0])    
plt.xlabel('t')
plt.ylabel('x')
plt.savefig('pos_x.png', dpi = 300)

plt.figure()
for i in range(nParticles):
    plt.plot(pos[i,:,1])    
plt.xlabel('t')
plt.ylabel('y')
plt.savefig('pos_y.png', dpi = 300)

plt.figure()
for i in range(nParticles):
    plt.fill_between(Dt_r, MSD_d[i,:]+dMSD_d[i,:], MSD_d[i,:] - dMSD_d[i,:], alpha = 0.5)
    plt.plot(Dt_r, MSD_d[i,:])
plt.xlabel('Dt')
plt.ylabel('MSD')
plt.savefig('MSD.png', dpi = 300)

输出

MSD_direct -- Time: 7.793087720870972 s

但是,我想尽可能优化此代码还有Dt 的循环,我不知道如何删除它并使用 numpy 完全矢量化程序。


我还使用 numba 重写了计算,与之前的代码相比,改进了 因素二。我想知道是否还有可能进一步改进它。

import numba as nb
@nb.jit(fastmath=True,parallel=True)
def MSD_numba(pos):
    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = np.empty((nParticles, len(Dt_r)))
    dMSD = np.empty((nParticles,len(Dt_r)))
    for i in nb.prange(nParticles):  
        for Dt in Dt_r:
            SD = (pos[i, Dt:, 0] - pos[i, 0:-Dt, 0])**2 + (pos[i, Dt:, 1] - pos[i, 0:-Dt, 1])**2
            MSD[i, Dt-1] = np.mean( SD )
            dMSD[i, Dt-1] = np.std( SD ) / np.sqrt(len(SD)) 
    return Dt_r, MSD, dMSD

start_time = time.time()
Dt_r, MSD_n, dMSD_n = MSD_numba(pos)
print("MSD_numba -- Time: %s s" % (time.time() - start_time))
print("MSD_numba -- All close to MSD_direct: %r\n" %(np.allclose(MSD_n, MSD_d) )  )

输出:

MSD_numba -- Time: 4.520232915878296 s
MSD_numba -- All close to MSD_direct: True

注意:此类问题已经在多个帖子中提出,但他们使用不同的定义(Mean square displacement pythonMean squared displacementMean square displacement for n-dim matrix python),他们没有答案(Mean square displacement in Python),他们只是使用一个粒子(Computing mean square displacement using python and FFTMean square displacement of a 1d random walk in python),他们使用熊猫(Vectorized calculation of Mean Square Displacement in PythonSpeedup MSD calculation in Python)等

【问题讨论】:

  • 我很确定这条线np.linalg.norm(pos[:, Dt:] - pos[:, 0:-Dt], axis = -1)**2 有两个错误:1)为什么是平方范数?除了进行平方的 norm 之外,定义中没有平方。 2) pos[:, Dt:] - pos[:, 0:-Dt] 应该类似于 np.diff(pos[:, Dt:])
  • @dankal444 我不明白你为什么认为我计算错了 MSD。 np.linalg.norm(pos[:, Dt:] - pos[:, 0:-Dt], axis = -1)**2 这行字面上是我在帖子开头显示的定义,它在数字上也等于 np.sum((pos[:, Dt:] - pos[:, 0:-Dt])**2, axis = -1),以防你发现它更清楚。另外,为什么要使用np.diff?我们不是对数组的连续元素做差异,而是间隔Dt的差异。
  • 也许让你感到困惑的是pos 有三个轴:(粒子、时间、坐标)。当我这样做时,规范是尊重坐标轴,因为向量是二维的。
  • 我知道我错了。尽管如此,正如你所说,np.sum((pos[:, Dt:] - pos[:, 0:-Dt])**2, axis = -1) 是等价的并且 更快,我发现这很奇怪,你取平方根(按标准)并立即平方这些数字 - 这让我认为一定有一些错误.
  • @dankal444 你是对的,使用平方和会快一点,我现在更改了问题中的代码。谢谢你们的cmets。

标签: python numpy optimization multidimensional-array vectorization


【解决方案1】:

根据Computing mean square displacement using python and FFT 使用 FFT 变换的答案,我设法将这个计算速度提高了 两个数量级。请注意,pos 是具有三个轴的 np.array(particles, time, coordinate)

def MSD_fft(pos):
    nTime=pos.shape[1]        

    S2 = np.sum ( np.fft.ifft( np.abs(np.fft.fft(pos, n=2*nTime, axis = -2))**2, axis = -2  )[:,:nTime,:].real , axis = -1 ) / (nTime-np.arange(nTime)[None,:] )

    D=np.square(pos).sum(axis=-1)
    D=np.append(D, np.zeros((pos.shape[0], 1)), axis = -1)
    S1 = ( 2 * np.sum(D, axis = -1)[:,None] - np.cumsum( np.insert(D[:,0:-1], 0, 0, axis = -1) + np.flip(D, axis = -1), axis = -1 ) )[:,:-1] / (nTime - np.arange(nTime)[None,:] )

    MSD = S1-2*S2

    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = MSD[:,Dt_r]
    return Dt_r, MSD

start_time = time.time()
Dt_r, MSD_f = MSD_fft(pos)
print("MSD_fft -- Time: %s s" % (time.time() - start_time))
print("MSD_fft -- All close to MSD_direct: %r\n" %(np.allclose(MSD_f, MSD_d) )  )

输出:

MSD_direct -- Time: 2.1434285640716553 s

MSD_numba -- Time: 1.532573938369751 s
MSD_numba -- All close to MSD_direct: True

MSD_fft -- Time: 0.007384061813354492 s
MSD_fft -- All close to MSD_direct: True

虽然我无法使用这种方法计算误差。但是,如果我们有足够的统计数据,误差应该保持很小。其实在剧情里你是分不出来的。


任何n维数组的广义函数

我将前面的函数概括为任何 n 维数组给出的pos,您只需要指定时间轴和坐标:

def MSD_fft_ax(pos, axis_time, axis_coord):
    nTime=pos.shape[axis_time]        

    S2 = np.sum (  np.fft.ifft( np.abs(np.fft.fft(pos, n=2*nTime, axis = axis_time))**2, axis = axis_time ).take(range(nTime), axis = axis_time).real, axis = axis_coord )
    
    D=np.square(pos).sum(axis=axis_coord)

    if axis_coord % pos.ndim < axis_time % pos.ndim: axis_time -= 1

    shape_t = [nTime if ax==axis_time % D.ndim else 1 for ax, s in enumerate(D.shape)]
    shape_non_t = [1 if ax==axis_time % D.ndim else s for ax, s in enumerate(D.shape)]

    D=np.append(D, np.zeros( shape_non_t ), axis = axis_time)
    S1 = ( 2 * np.sum(D, axis = axis_time).reshape(shape_non_t) - np.cumsum( np.insert(D.take(np.arange(0,nTime), axis=axis_time), 0, 0, axis = axis_time) + np.flip(D, axis = axis_time), axis = axis_time ) ).take(np.arange(0,nTime), axis = axis_time) 

    MSD = ( S1-2*S2 ) / ( nTime-np.arange(nTime).reshape(shape_t) )

    Dt_r = np.arange(1, nTime-1)
    MSD = MSD.take(Dt_r, axis = axis_time)
    return Dt_r, MSD

start_time = time.time()
Dt_r, MSD_fax = MSD_fft_ax(pos, axis_time = 1, axis_coord=-1)
print("MSD_fft_ax -- Time: %s s" % (time.time() - start_time))
print("MSD_fft_ax -- All close to MSD_direct: %r\n" %(np.allclose(MSD_fax, MSD_d) )  )

输出:

MSD_fft_ax -- Time: 0.009054422378540039 s
MSD_fft_ax -- All close to MSD_direct: True

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-10-30
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多