我认为您可能想要做这样的事情(这将在您将数据/标签提供给网络之前作为预处理步骤完成):
# dummy array/labels
arr = [1, 3, 6, 8, 10, 4, 7, 15, 25, 19]
print('original array =', arr)
# create mapping (range(20) in your case)
mapping = dict(zip(set(arr), range(10)))
print('mapping =', mapping)
# apply the mapping
new_arr = list(map(lambda x: mapping[x], arr))
print('new array =', new_arr)
# >> output:
# original array = [1, 3, 6, 8, 10, 4, 7, 15, 25, 19]
# mapping = {1: 0, 3: 1, 4: 2, 6: 3, 7: 4, 8: 5, 10: 6, 15: 7, 19: 8, 25: 9}
# new array = [0, 1, 3, 5, 6, 2, 4, 7, 9, 8]
所以基本上你的原始标签(这是 len(set(labels)) = 20 但值 > 20,如果我理解正确的话)被映射到最小的可能值,以便它可以与你的损失函数一起使用.如果您需要将标签映射回原始值,最好保留映射以供以后使用。