【问题标题】:Explanation of numpy indexing ndarray[(4, 2), (5, 3)]numpy索引ndarray[(4, 2), (5, 3)]的解释
【发布时间】:2020-12-21 08:02:13
【问题描述】:

问题

请帮助理解 Numpy 将元组 (i, j) 索引到 ndarray 的设计决策或合理性。

背景

当索引为单个元组 (4, 2) 时,则为 (i=row, j=column)。

shape = (6, 7)
X = np.zeros(shape, dtype=int)
X[(4, 2)] = 1
X[(5, 3)] = 1
print("X is :\n{}\n".format(X))
---
X is :
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0]    <--- (4, 2)
 [0 0 0 1 0 0 0]]   <--- (5, 3)

但是,当索引是多个元组 (4, 2), (5, 3) 时,则 (i=row, j=row) 为 (4, 2) 和 (i=column, j=column) (5, 3)。

shape = (6, 7)
Y = np.zeros(shape, dtype=int)
Y[(4, 2), (5, 3)] = 1
print("Y is :\n{}\n".format(Y))
---
Y is :
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0]    <--- (2, 3)
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0]    <--- (4, 5)
 [0 0 0 0 0 0 0]]

这意味着你正在构造一个二维数组R,例如R=A[B, C]。 这意味着值 rij=abijcij.

所以这意味着位于R[0,0]的项目是A中的项目 作为行索引B[0,0] 和列索引C[0,0]。项目R[0,1]A 中的项目,具有行索引 B[0,1] 并作为列索引 C[0,1]

multi_index:整数数组的元组,每个维度一个数组

为什么不总是(i=行,j=列)?如果总是 (i=row, j=column) 会发生什么?


更新

根据 Akshay 和 @DaniMesejo 的回答,明白了:

X[
  (4),    # dimension 2 indices with only 1 element
  (2)     # dimension 1 indices with only 1 element
] = 1

Y[
  (4, 2, ...), # dimension 2 indices 
  (5, 3, ...)  # dimension 1 indices (dimension 0 is e.g. np.array(3) whose shape is (), in my understanding)
] = 1

【问题讨论】:

  • 总是i = rowj = column。你的解释是错误的。在您的第一个示例中,括号是多余的(与 X[4, 2] 相同)。其中 4 是第一个轴的索引,2 是第二个轴的索引。因此,逗号之前的每个值都用于行,而逗号之后的每个值都用于列。但问题是您可以为行和列传递多个值,例如 Y[(4, 2), (5, 3)],因为逗号之前的所有值都用于第一个轴,后面的一个逗号用于第二个轴。
  • @DaniMesejo,感谢您指出“逗号之前的值用于第一个轴,逗号后面的值用于第二个轴”。

标签: python numpy matrix-indexing


【解决方案1】:

很容易理解它的工作原理(以及此设计决策背后的动机)。

Numpy 将其 ndarray 存储为连续的内存块。每个元素在前一个元素之后每隔 n 个字节按顺序存储。

(图片引用自excellent SO post

所以如果你的 3D 数组看起来像这样 -

然后在内存中存储为 -

当检索一个元素(或一个元素块)时,NumPy 计算它需要遍历多少个strides(字节)才能获取下一个元素in that direction/axis。因此,对于上面的示例,对于axis=2,它必须遍历 8 个字节(取决于datatype),但对于axis=1,它必须遍历8*4 字节,而axis=0 它需要8*8 字节。

考虑到这一点,让我们看看您要做什么。

print(X)
print(X.strides)
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0]
 [0 0 0 1 0 0 0]]

#Strides (bytes) required to traverse in each axis.
(56, 8)

对于您的数组,要获取axis=0 中的下一个元素,我们需要遍历56 bytes,而对于axis=1 中的下一个元素,我们需要8 bytes

当您索引(4,2) 时,NumPy 将在axis=0 中使用56*4 字节,在axis=1 中使用8*2 字节来访问它。同样,如果要访问(4,2)(5,3),则必须访问axis=0 中的56*(4,5)8*(2,3) 中的8*(2,3)

这就是设计的原因,因为它与 NumPy 使用 strides 实际索引元素的方式一致。

X[(axis0_indices), (axis1_indices), ..]

X[(4, 5), (2, 3)] #(row indices), (column indices)
array([1, 1])

通过这种设计,也可以轻松扩展到更高维的张量(例如 8 维数组)! 如果您分别提及每个索引元组,则需要计算元素 * 维数才能获取它们。虽然采用这种设计,但它可以将步幅值广播到每个轴的元组并更快地获取这些值!

【讨论】:

    猜你喜欢
    • 2018-10-10
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-03-13
    • 1970-01-01
    相关资源
    最近更新 更多