【发布时间】:2020-10-26 04:48:18
【问题描述】:
我将一个包含 5 个类别(例如汽车、公共汽车、...)的数据框传递给 nn.Embedding。
当我执行embedding.parameters() 时,我可以看到有 5 个张量,但我怎么知道哪个索引对应于原始输入(例如汽车、公共汽车……)?
【问题讨论】:
我将一个包含 5 个类别(例如汽车、公共汽车、...)的数据框传递给 nn.Embedding。
当我执行embedding.parameters() 时,我可以看到有 5 个张量,但我怎么知道哪个索引对应于原始输入(例如汽车、公共汽车……)?
【问题讨论】:
您不能因为张量未命名(只有维度可以命名,请参阅PyTorch's Named Tensors)。
您必须将名称保存在单独的数据容器中,例如(4 类别在这里):
import pandas as pd
import torch
df = pd.DataFrame(
{
"bus": [1.0, 2, 3, 4, 5],
"car": [6.0, 7, 8, 9, 10],
"bike": [11.0, 12, 13, 14, 15],
"train": [16.0, 17, 18, 19, 20],
}
)
df_data = df.to_numpy().T
df_names = list(df)
embedding = torch.nn.Embedding(df_data.shape[0], df_data.shape[1])
embedding.weight.data = torch.from_numpy(df_data)
现在您可以简单地将它与您想要的任何索引一起使用:
index = 1
embedding(torch.tensor(index)), df_names[index]
这会给你(tensor[6, 7, 8, 9, 10], "car"),所以数据和相应的列名。
【讨论】: