【发布时间】:2021-09-21 10:55:25
【问题描述】:
编辑:我试过 PyTorch 1.6.0 和 1.7.1,都给我同样的错误。
我有一个模型,可以让用户在不同的架构 A 和 B 之间轻松切换。两种架构的转发功能也不同,所以我有以下模型类:
附注我这里只是用一个非常简单的例子来演示我的问题,实际模型要复杂得多。
class Net(nn.Module):
def __init__(self, condition):
super().__init__()
self.linear = nn.Linear(10, 1)
if condition == 'A':
self.forward = self.forward_A
elif condition == 'B':
self.linear2 = nn.Linear(10, 1)
self.forward = self.forward_B
def forward_A(self, x):
return self.linear(x)
def forward_B(self, x1, x2):
return self.linear(x1) + self.linear2(x2)
它在单个 GPU 情况下运行良好。然而,在多 GPU 的情况下,它会抛出一个错误。
device= 'cuda:0'
x = torch.randn(8,10).to(device)
model = Net('B')
model = model.to(device)
model = nn.DataParallel(model)
model(x, x)
RuntimeError: 期望所有张量都在同一个设备上,但发现 至少有两个设备,cuda:0 和 cuda:1! (在检查参数时 方法 wrapper_addmm 中的参数 mat1)
如何使这个模型类与nn.DataParallel一起工作?
【问题讨论】:
标签: python-3.x deep-learning pytorch multiple-gpu