这里发生了一些事情:numpy 的向量操作、添加单例轴和广播。
首先,您应该能够看到== 是如何发挥作用的。
假设我们从一个简单的标签数组开始。 == 以向量化方式运行,这意味着我们可以将整个数组与一个标量进行比较,并获得一个由每个元素比较的值组成的数组。例如:
>>> labels = np.array([1,2,0,0,2])
>>> labels == 0
array([False, False, True, True, False], dtype=bool)
>>> (labels == 0).astype(np.float32)
array([ 0., 0., 1., 1., 0.], dtype=float32)
首先我们得到一个布尔数组,然后我们强制转换为浮点数:在 Python 中为 False==0,而 True==1。所以我们最终得到一个数组,它是 0,labels 不等于 0,而 1 是它所在的位置。
但是与 0 比较并没有什么特别之处,我们可以与 1 或 2 或 3 进行比较以获得相似的结果:
>>> (labels == 2).astype(np.float32)
array([ 0., 1., 0., 0., 1.], dtype=float32)
事实上,我们可以遍历所有可能的标签并生成这个数组。我们可以使用 listcomp:
>>> np.array([(labels == i).astype(np.float32) for i in np.arange(3)])
array([[ 0., 0., 1., 1., 0.],
[ 1., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 1.]], dtype=float32)
但这并没有真正利用 numpy.我们要做的是将每个可能的标签与每个元素进行比较,IOW 进行比较
>>> np.arange(3)
array([0, 1, 2])
与
>>> labels
array([1, 2, 0, 0, 2])
这就是 numpy 广播的魔力所在。现在,labels 是形状 (5,) 的一维对象。如果我们将其设为形状为 (5,1) 的二维对象,则操作将在最后一个轴上“广播”,我们将获得形状为 (5,3) 的输出以及比较每个条目的结果标签的每个元素的范围。
首先我们可以使用None(或np.newaxis)为labels添加一个“额外”轴,改变它的形状:
>>> labels[:,None]
array([[1],
[2],
[0],
[0],
[2]])
>>> labels[:,None].shape
(5, 1)
然后我们可以进行比较(这是我们之前看到的安排的转置,但这并不重要)。
>>> np.arange(3) == labels[:,None]
array([[False, True, False],
[False, False, True],
[ True, False, False],
[ True, False, False],
[False, False, True]], dtype=bool)
>>> (np.arange(3) == labels[:,None]).astype(np.float32)
array([[ 0., 1., 0.],
[ 0., 0., 1.],
[ 1., 0., 0.],
[ 1., 0., 0.],
[ 0., 0., 1.]], dtype=float32)
numpy 中的广播功能非常强大,值得一读。