【问题标题】:Retrieving original data from PyTorch nn.Embedding从 PyTorch nn.Embedding 检索原始数据
【发布时间】:2020-10-26 04:48:18
【问题描述】:

我将一个包含 5 个类别(例如汽车、公共汽车、...)的数据框传递给 nn.Embedding

当我执行embedding.parameters() 时,我可以看到有 5 个张量,但我怎么知道哪个索引对应于原始输入(例如汽车、公共汽车……)?

【问题讨论】:

    标签: pytorch embedding


    【解决方案1】:

    您不能因为张量未命名(只有维度可以命名,请参阅PyTorch's Named Tensors)。 您必须将名称保存在单独的数据容器中,例如(4 类别在这里):

    import pandas as pd
    import torch
    
    df = pd.DataFrame(
        {
            "bus": [1.0, 2, 3, 4, 5],
            "car": [6.0, 7, 8, 9, 10],
            "bike": [11.0, 12, 13, 14, 15],
            "train": [16.0, 17, 18, 19, 20],
        }
    )
    
    df_data = df.to_numpy().T
    df_names = list(df)
    
    embedding = torch.nn.Embedding(df_data.shape[0], df_data.shape[1])
    embedding.weight.data = torch.from_numpy(df_data)
    

    现在您可以简单地将它与您想要的任何索引一起使用:

    index = 1
    embedding(torch.tensor(index)), df_names[index]
    

    这会给你(tensor[6, 7, 8, 9, 10], "car"),所以数据和相应的列名。

    【讨论】:

    • 我的数据列看起来更像这样:"type": ["car", "bus", "bike", "train"]。在这种情况下会发生什么?
    • @AndreaEunbeeJang 这些是列名吗?
    • 哦,实际上,nvm!我明白你在那里做了什么。谢谢!!!
    猜你喜欢
    • 1970-01-01
    • 2010-12-19
    • 2021-07-06
    • 1970-01-01
    • 2023-02-07
    • 1970-01-01
    • 1970-01-01
    • 2011-06-28
    相关资源
    最近更新 更多