【发布时间】:2017-07-19 19:30:57
【问题描述】:
我有一个 NumPy 标签数组:
labels = np.ndarray(10000, dtype=np.float32)
数组中的元素如下所示:
print(labels[1:5])
Output: [ 9. 9. 4. 1.]
我想将它们转换成一个热编码标签,我使用了以下代码:
one_hot_labels = np.eye(10)[labels]
我收到以下错误:
IndexError Traceback (most recent call last)
<ipython-input-21-dccf85afc031> in <module>()
1
----> 2 s=np.eye(10)[labels]
IndexError: arrays used as indices must be of integer (or boolean) type
我该如何解决这个问题?
【问题讨论】:
-
你确定标签和火车标签是一样的吗?
-
您需要使用整数值作为索引:
one_hot_labels=np.eye(10)[labels.astype(int)] -
@JohanL 谢谢。它有效
标签: python numpy one-hot-encoding