2.7 花哨的索引
import numpy as np
# Use a dedicated, fixed-seed generator so the sample below is reproducible.
rand = np.random.RandomState(seed=42)
# Ten integers drawn from the half-open range [0, 100).
x = rand.randint(low=0, high=100, size=10)
print(x)
[51 92 14 71 60 20 82 86 74 74]
# Pick three elements one at a time with ordinary scalar indices.
[x[3], x[7], x[2]]
[71, 86, 14]
# Fancy indexing: pass a list of indices to fetch the same three elements in one step.
ind = [3, 7, 2]
x[ind]
array([71, 86, 14])
利用花哨的索引,结果的形状与索引数组一致,而不是与被索引数组的形状一致。
# A 2x2 index array: the result has the shape of ind, not of the 1-D array x.
ind = np.array([[3, 7], [4, 5]])
x[ind]
array([[71, 86],
[60, 20]])
# 3x4 array holding 0..11, used for the multi-dimensional indexing examples.
X = np.arange(12).reshape((3, 4))
X
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
花哨的索引同样适用于多维数组:第一个索引对应行,第二个索引对应列。如果两个索引数组的形状不同,会先按广播规则配对,再进行索引。
# Index arrays are paired element-wise: this selects X[0, 2], X[1, 1] and X[2, 3].
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
X[row, col]
array([ 2, 5, 11])
X[row[:, np.newaxis], col] # row index is (3, 1), col index is (3,); they broadcast to (3, 3) before indexing
array([[ 2, 1, 3],
[ 6, 5, 7],
[10, 9, 11]])
# Show the two index arrays that were broadcast together: a (3, 1) column and a (3,) vector.
row[:, np.newaxis], col
(array([[0],
[1],
[2]]), array([2, 1, 3]))
X[2, [2, 0, 1]] # fancy index combined with a simple (scalar) index
array([10, 8, 9])
X[1:, [2, 0, 1]] # fancy index combined with a slice
array([[ 6, 4, 5],
[10, 8, 9]])
# Boolean mask that is True for columns 0 and 2.
mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask] # fancy index combined with a boolean mask
array([[ 0, 2],
[ 4, 6],
[ 8, 10]])
示例:选择随机点
花哨的索引的常见用途是从一个矩阵中选择行的子集。例如有一个 N×D 的矩阵,表示 D 维空间中的 N 个点。以下是一个由二维正态分布的点组成的数组:
# Draw 100 samples from a two-dimensional normal distribution using the
# seeded generator `rand`; each row of X is one point.
mean = [0, 0]
cov = [[1, 2], [2, 5]]
X = rand.multivariate_normal(mean, cov, 100)
X.shape
(100, 2)
该数组为100行2列的二维数组,画出散点图:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()
# Scatter plot of column 0 (horizontal axis) against column 1 (vertical axis).
plt.scatter(X[:, 0], X[:, 1]);
X # contents of the two-dimensional array
array([[-0.644508 , -0.46220608],
[ 0.7376352 , 1.21236921],
[ 0.88151763, 1.12795177],
[ 2.04998983, 5.97778598],
[-0.1711348 , -2.06258746],
[ 0.67956979, 0.83705124],
[ 1.46860232, 1.22961093],
[ 0.35282131, 1.49875397],
[-2.51552505, -5.64629995],
[ 0.0843329 , -0.3543059 ],
[ 0.19199272, 1.48901291],
[-0.02566217, -0.74987887],
[ 1.00569227, 2.25287315],
[ 0.49514263, 1.18939673],
[ 0.0629872 , 0.57349278],
[ 0.75093031, 2.99487004],
[-3.0236127 , -6.00766046],
[-0.53943081, -0.3478899 ],
[ 1.53817376, 1.99973464],
[-0.50886808, -1.81099656],
[ 1.58115602, 2.86410319],
[ 0.99305043, 2.54294059],
[-0.87753796, -1.15767204],
[-1.11518048, -1.87508012],
[ 0.4299908 , 0.36324254],
[ 0.97253528, 3.53815717],
[ 0.32124996, 0.33137032],
[-0.74618649, -2.77366681],
[-0.88473953, -1.81495444],
[ 0.98783862, 2.30280401],
[-1.2033623 , -2.04402725],
[-1.51101746, -3.2818741 ],
[-2.76337717, -7.66760648],
[ 0.39158553, 0.87949228],
[ 0.91181024, 3.32968944],
[-0.84202629, -2.01226547],
[ 1.06586877, 0.95500019],
[ 0.44457363, 1.87828298],
[ 0.35936721, 0.40554974],
[-0.90649669, -0.93486441],
[-0.35790389, -0.52363012],
[-1.33461668, -3.03203218],
[ 0.02815138, 0.79654924],
[ 0.37785618, 0.51409383],
[-1.06505097, -2.88726779],
[ 2.32083881, 5.97698647],
[ 0.47605744, 0.83634485],
[-0.35490984, -1.03657119],
[ 0.57532883, -0.79997124],
[ 0.33399913, 2.32597923],
[ 0.6575612 , -0.22389518],
[ 1.3707365 , 2.2348831 ],
[ 0.07099548, -0.29685467],
[ 0.6074983 , 1.47089233],
[-0.34226126, -1.10666237],
[ 0.69226246, 1.21504303],
[-0.31112937, -0.75912097],
[-0.26888327, -1.89366817],
[ 0.42044896, 1.85189522],
[ 0.21115245, 2.00781492],
[-1.83106042, -2.91352836],
[ 0.7841796 , 1.97640753],
[ 0.10259314, 1.24690575],
[-1.91100558, -3.66800923],
[ 0.13143756, -0.07833855],
[-0.1317045 , -1.64159158],
[-0.14547282, -1.34125678],
[-0.51172373, -1.40960773],
[ 0.69758045, 0.72563649],
[ 0.11677083, 0.88385162],
[-1.16586444, -2.24482237],
[-2.23176235, -2.63958101],
[ 0.37857234, 0.69112594],
[ 0.87475323, 3.400675 ],
[-0.86864365, -3.03568353],
[-1.03637857, -1.18469125],
[-0.53334959, -0.37039911],
[ 0.30414557, -0.5828419 ],
[-1.47656656, -2.13046298],
[-0.31332021, -1.7895623 ],
[ 1.12659538, 1.49627535],
[-1.19675798, -1.51633442],
[-0.75210154, -0.79770535],
[ 0.74577693, 1.95834451],
[ 1.56094354, 2.9330816 ],
[-0.72009966, -1.99780959],
[-1.32319163, -2.61218347],
[-2.56215914, -6.08410838],
[ 1.31256297, 3.13143269],
[ 0.51575983, 2.30284639],
[ 0.01374713, -0.11539344],
[-0.16863279, 0.39422355],
[ 0.12065651, 1.13236323],
[-0.83504984, -2.38632016],
[ 1.05185885, 1.98418223],
[-0.69144553, -1.56919875],
[-1.2567603 , -1.125898 ],
[ 0.09619333, -0.64335574],
[-0.99658689, -2.35038099],
[-1.21405259, -1.77693724]])
X[0] # row 0 of the two-dimensional array, i.e. one point
array([-0.644508 , -0.46220608])
X[0, 0] # column 0 of row 0, i.e. the horizontal coordinate of the first point
-0.6445079962363565
X[:, 0] # column 0 of every row, i.e. the horizontal coordinates of all points
array([-0.644508 , 0.7376352 , 0.88151763, 2.04998983, -0.1711348 ,
0.67956979, 1.46860232, 0.35282131, -2.51552505, 0.0843329 ,
0.19199272, -0.02566217, 1.00569227, 0.49514263, 0.0629872 ,
0.75093031, -3.0236127 , -0.53943081, 1.53817376, -0.50886808,
1.58115602, 0.99305043, -0.87753796, -1.11518048, 0.4299908 ,
0.97253528, 0.32124996, -0.74618649, -0.88473953, 0.98783862,
-1.2033623 , -1.51101746, -2.76337717, 0.39158553, 0.91181024,
-0.84202629, 1.06586877, 0.44457363, 0.35936721, -0.90649669,
-0.35790389, -1.33461668, 0.02815138, 0.37785618, -1.06505097,
2.32083881, 0.47605744, -0.35490984, 0.57532883, 0.33399913,
0.6575612 , 1.3707365 , 0.07099548, 0.6074983 , -0.34226126,
0.69226246, -0.31112937, -0.26888327, 0.42044896, 0.21115245,
-1.83106042, 0.7841796 , 0.10259314, -1.91100558, 0.13143756,
-0.1317045 , -0.14547282, -0.51172373, 0.69758045, 0.11677083,
-1.16586444, -2.23176235, 0.37857234, 0.87475323, -0.86864365,
-1.03637857, -0.53334959, 0.30414557, -1.47656656, -0.31332021,
1.12659538, -1.19675798, -0.75210154, 0.74577693, 1.56094354,
-0.72009966, -1.32319163, -2.56215914, 1.31256297, 0.51575983,
0.01374713, -0.16863279, 0.12065651, -0.83504984, 1.05185885,
-0.69144553, -1.2567603 , 0.09619333, -0.99658689, -1.21405259])
先随机选择20个不重复的索引值,再用花哨的索引从原始数组中选出这些索引对应的点:
# Draw 20 distinct row indices from range(X.shape[0]); replace=False forbids repeats.
# NOTE(review): this uses the global np.random state rather than the seeded
# `rand` generator, so unless the global seed is set elsewhere the indices
# will differ from run to run — confirm whether that is intended.
indices = np.random.choice(X.shape[0], 20, replace=False)
indices
array([94, 76, 22, 0, 77, 36, 32, 58, 54, 70, 50, 92, 44, 38, 65, 46, 79,
68, 67, 71])
selection = X[indices] # fancy indexing: pick the rows of X listed in indices
selection.shape
(20, 2)
# Plot every point semi-transparently, then draw the selected points on top
# as large markers with no fill and a blue edge.
plt.scatter(X[:, 0], X[:, 1], alpha=0.3)
plt.scatter(selection[:, 0], selection[:, 1], facecolor='none', edgecolor='b', s=200);
用花哨的索引修改值
# Fancy indexing on the left-hand side: overwrite the positions listed in i.
x = np.arange(10)
i = np.array([2, 1, 8, 4])
x[i] = 99  # the scalar is assigned to every selected position
x
array([ 0, 99, 99, 3, 99, 5, 6, 7, 99, 9])
x[i] -= 10 # augmented assignment also works: subtract 10 at the positions listed in i
x
array([ 0, 89, 89, 3, 89, 5, 6, 7, 89, 9])
x[[0, 0]] # the index is an array; it reads position 0 twice
array([0, 0])
x[[0, 0]] = [4, 6] # repeated index: the assignment of 4 is overwritten by 6
x
array([ 6, 89, 89, 3, 89, 5, 6, 7, 89, 9])