【Posted】:2020-12-28 12:09:43
【Question】:
In machine translation, we always need to slice out the first time step (the SOS token) from both the targets and the predictions.
With batch_first=False, slicing off the first time step still keeps the tensor contiguous.
import torch
batch_size = 128
seq_len = 12
embedding = 50
# Making a dummy output that is `batch_first=False`
batch_not_first = torch.randn((seq_len,batch_size,embedding))
batch_not_first = batch_not_first[1:].view(-1, embedding) # slicing out the first time step
However, if we use batch_first=True, the tensor is no longer contiguous after slicing. We have to make it contiguous before we can perform operations such as view.
batch_first = torch.randn((batch_size,seq_len,embedding))
batch_first[:,1:].view(-1, embedding) # slicing out the first time step
output>>>
"""
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-a9bd590a1679> in <module>
----> 1 batch_first[:,1:].view(-1, embedding) # slicing out the first time step
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""
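A minimal sketch of the two standard workarounds for the error above: either copy the slice into contiguous memory with `contiguous()` before calling `view`, or use `reshape`, which performs the copy itself only when needed. The dimensions reuse the dummy sizes from the snippets above.

```python
import torch

batch_size = 128
seq_len = 12
embedding = 50

# batch_first=True layout: (batch, seq, embedding)
batch_first = torch.randn((batch_size, seq_len, embedding))

sliced = batch_first[:, 1:]        # drop the first (SOS) time step
print(sliced.is_contiguous())      # False: the slice skips over memory

# Option 1: make a contiguous copy first, then view
flat1 = sliced.contiguous().view(-1, embedding)

# Option 2: reshape copies automatically when the layout requires it
flat2 = sliced.reshape(-1, embedding)

print(flat1.shape)                 # (batch_size * (seq_len - 1), embedding)
```

Both options produce the same result; the cost of the extra copy is the same either way, so the choice is mostly stylistic.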
Does this mean that batch_first=False is better, at least for machine translation, since it saves us the contiguous() step? Are there any cases where batch_first=True works better?
【Discussion】: