【问题标题】:Calculate Mahalanobis distance using NumPy only仅使用 NumPy 计算马氏距离
【发布时间】:2015-02-25 11:37:12
【问题描述】:

我正在寻找计算两个 numpy 数组(x 和 y)之间马氏距离的 NumPy 方法。 以下代码可以使用 Scipy 的 cdist 函数正确计算相同的值。由于这个函数在我的例子中计算了不必要的矩阵,我想要更直接的方法来计算它只使用 NumPy。

import numpy as np
from scipy.spatial.distance import cdist

x = np.array([[[1,2,3,4,5],
               [5,6,7,8,5],
               [5,6,7,8,5]],
              [[11,22,23,24,5],
               [25,26,27,28,5],
               [5,6,7,8,5]]])
i,j,k = x.shape

xx = x.reshape(i,j*k).T


y = np.array([[[31,32,33,34,5],
               [35,36,37,38,5],
               [5,6,7,8,5]],
              [[41,42,43,44,5],
               [45,46,47,48,5],
               [5,6,7,8,5]]])


yy = y.reshape(i,j*k).T

results =  cdist(xx,yy,'mahalanobis')
results = np.diag(results)
print results



[ 2.28765854  2.75165028  2.75165028  2.75165028  0.          2.75165028
  2.75165028  2.75165028  2.75165028  0.          0.          0.          0.
  0.          0.        ]

我的审判:

VI = np.linalg.inv(np.cov(xx,yy))

print np.sqrt(np.dot(np.dot((xx-yy),VI),(xx-yy).T))

有人能纠正这个方法吗?

这是它的公式:

http://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.spatial.distance.mahalanobis.html#scipy.spatial.distance.mahalanobis

【问题讨论】:

  • 我想计算 [1,11] 和 [31,41] 之间的马氏距离; [2,22] 和 [32,42],...等等。
  • scipy中的实现是纯python代码。您可以将您的方法与他们的方法进行比较。有关两个向量之间马氏距离的计算,请参见此处:github.com/scipy/scipy/blob/… 要计算观察矩阵的距离,您可能必须遍历每个观察向量。
  • 是的,我尝试从那个来源进行计算,但是由于我对 Python 的了解很少,它还没有完成。你能看看我的审判吗?
  • 您的方法仅与您的辣味不同,包括转置的增量,而辣味源代码不会在第二次出现时转置增量...
  • @jkalden 认为这是一个错误,我在这里提交了报告mail.scipy.org/pipermail/scipy-dev/2014-December/020301.html

标签: python numpy


【解决方案1】:

另一个与 einsum 一样快的简单解决方案

e = xx-yy
X = np.vstack([xx,yy])
V = np.cov(X.T) 
p = np.linalg.inv(V)
D = np.sqrt(np.sum(np.dot(e,p) * e, axis = 1))

【讨论】:

    【解决方案2】:

    我认为您的问题在于协方差矩阵的构造。试试:

    X = np.vstack([xx,yy])
    V = np.cov(X.T)
    VI = np.linalg.inv(V)
    print np.diag(np.sqrt(np.dot(np.dot((xx-yy),VI),(xx-yy).T)))
    

    输出:

    [ 2.28765854  2.75165028  2.75165028  2.75165028  0.          2.75165028
      2.75165028  2.75165028  2.75165028  0.          0.          0.          0.
      0.          0.        ]
    

    要在此处不隐式创建中间数组的情况下执行此操作,您可能必须为 Python 循环牺牲一个 C 循环:

    A = np.dot((xx-yy),VI)
    B = (xx-yy).T
    n = A.shape[0]
    D = np.empty(n)
    for i in range(n):
        D[i] = np.sqrt(np.sum(A[i] * B[:,i]))
    

    编辑:实际上,使用np.einsum voodoo,您可以删除 Python 循环并大大加快它的速度(在我的系统上,从 84.3 µs 到 2.9 µs):

    D = np.sqrt(np.einsum('ij,ji->i', A, B))
    

    编辑:正如@Warren Weckesser 指出的那样,einsum 也可以用来取消中间的 AB 数组:

    delta = xx - yy
    D = np.sqrt(np.einsum('nj,jk,nk->n', delta, VI, delta))
    

    【讨论】:

    • 感谢您的尝试,点赞。实际上,我想避免使用 np.diag 来减少不必要的计算以加快速度。
    • 非常感谢。我可以知道你所说的 C 循环是什么意思吗?
    • NumPy 例程在“底层”使用编译的 C 代码,因此它们比字节码编译的 Python 循环快得多(但您必须预先定义数据类型和数组大小)。这就是为什么我最近编辑中的einsum 解决方案比D 的Python 循环要快得多。
    • 您可以使用einsum获取完整产品:delta = xx - yy; D = np.sqrt(np.einsum('nj,jk,nk->n', delta, VI, delta))
    • 将其分解为两个步骤:p1 = einsum('nj,jk->nk', delta, VI)delta.dot(VI) 相同。 p2 = einsum('nk,nk->n', p1, delta)p1delta 的行的成对点积。查看einsum 文档字符串以获取更多示例。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-08-01
    • 2018-06-29
    • 1970-01-01
    • 2015-06-25
    • 1970-01-01
    • 2019-07-28
    相关资源
    最近更新 更多