【问题标题】:What is the difference between an Embedding Layer with a bias immediately afterwards and a Linear Layer in PyTorchPyTorch 中带有偏差的嵌入层和线性层有什么区别
【发布时间】:2020-12-25 04:02:03
【问题描述】:

我正在阅读“使用 fastai 和 PyTorch 为编码人员提供深度学习”一书。对于 Embedding 模块的作用,我仍然有些困惑。它似乎是一个简短而简单的网络,但我似乎无法完全理解嵌入与没有偏见的线性的不同之处。我知道它做了一些更快的点积计算版本,其中一个矩阵是单热编码矩阵,另一个是嵌入矩阵。这样做实际上是为了选择一条数据吗?请指出我错在哪里。这是书中展示的简单网络之一。

class DotProduct(Module):
    def __init__(self, n_users, n_movies, n_factors):
        self.user_factors = Embedding(n_users, n_factors)
        self.movie_factors = Embedding(n_movies, n_factors)
        
    def forward(self, x):
        users = self.user_factors(x[:,0])
        movies = self.movie_factors(x[:,1])
        return (users * movies).sum(dim=1)

【问题讨论】:

    标签: python oop deep-learning pytorch fast-ai


    【解决方案1】:

    嵌入

    [...] 嵌入与没有偏差的线性的不同之处。

    基本上一切。 torch.nn.Embedding 是一个查找表;本质上与torch.Tensor 的工作方式相同,但有一些不同之处(例如在指定索引处使用稀疏嵌入或默认值的可能性)。

    例如:

    import torch
    
    embedding = torch.nn.Embedding(3, 4)
    
    print(embedding.weight)
    
    print(embedding(torch.tensor([1])))
    

    会输出:

    Parameter containing:
    tensor([[ 0.1420, -0.1886,  0.6524,  0.3079],
            [ 0.2620,  0.4661,  0.7936, -1.6946],
            [ 0.0931,  0.3512,  0.3210, -0.5828]], requires_grad=True)
    tensor([[ 0.2620,  0.4661,  0.7936, -1.6946]], grad_fn=<EmbeddingBackward>)
    

    所以我们基本上取了嵌入的第一行。仅此而已。

    用在什么地方?

    通常当我们想为每一行编码一些含义(如 word2vec)时(例如,语义上接近的单词在欧几里得空间中接近)并可能训练它们

    线性

    torch.nn.Linear(无偏差)也是torch.Tensor(重量)它对它(和输入)进行操作,本质上是:

    output = input.matmul(weight.t())
    

    每次调用层时(参见source codefunctional definition of this layer)。

    代码 sn-p

    您的代码 sn-p 中的层基本上是这样做的:

    • __init__ 中创建两个查找表
    • 使用形状(batch_size, 2) 的输入调用层:
      • 第一列包含用户嵌入的索引
      • 第二列包含电影嵌入的索引
    • 这些嵌入被相乘和相加,返回 (batch_size,)(因此它不同于 nn.Linear,后者将返回 (batch_size, out_features) 并进行点积,而不是按元素乘法,然后像这里一样求和)

    这可能用于训练一些类似推荐系统的表示(用户和电影)。

    其他东西

    我知道它对点积进行了一些更快的计算版本 其中一个矩阵是 one-hot 编码矩阵,另一个是 嵌入矩阵。

    不,它没有。 torch.nn.Embedding 可以是一种热编码,也可能是稀疏的,但取决于算法(以及这些算法是否支持稀疏性),速度会加快,也可能不会。

    【讨论】:

      猜你喜欢
      • 2010-09-12
      • 1970-01-01
      • 1970-01-01
      • 2022-01-24
      • 2010-10-16
      • 1970-01-01
      • 2019-07-09
      • 2020-04-23
      • 2018-01-21
      相关资源
      最近更新 更多