【问题标题】:How to clear CUDA memory in PyTorch如何在 PyTorch 中清除 Cuda 内存
【发布时间】:2019-08-14 18:53:41
【问题描述】:

我正在尝试获取我已经训练过的神经网络的输出。输入是大小为 300x300 的图像。我使用的批量大小为 1,但在成功获得 25 张图像的输出后,我仍然收到 CUDA error: out of memory 错误。

我在网上搜索了一些解决方案,发现了torch.cuda.empty_cache()。但这似乎仍然不能解决问题。

这是我正在使用的代码。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_x = torch.tensor(train_x, dtype=torch.float32).view(-1, 1, 300, 300)
train_x = train_x.to(device)
dataloader = torch.utils.data.DataLoader(train_x, batch_size=1, shuffle=False)

right = []
for i, left in enumerate(dataloader):
    print(i)
    temp = model(left).view(-1, 1, 300, 300)
    right.append(temp.to('cpu'))
    del temp
    torch.cuda.empty_cache()

这个for loop每次运行25次才给出内存错误。

每次,我都会在网络中发送一个新图像进行计算。因此,在循环中的每次迭代之后,我真的不需要将先前的计算结果存储在 GPU 中。有什么方法可以实现吗?

任何帮助将不胜感激。谢谢。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    我知道我哪里出错了。我将发布解决方案作为其他可能遇到相同问题的人的答案。

    基本上,PyTorch 所做的是每当我通过网络传递数据并将计算结果存储在 GPU 内存上时,它都会创建一个计算图,以防我想在反向传播期间计算梯度。但由于我只想执行前向传播,我只需为我的模型指定torch.no_grad()

    因此,我的代码中的 for 循环可以重写为:

    for i, left in enumerate(dataloader):
        print(i)
        with torch.no_grad():
            temp = model(left).view(-1, 1, 300, 300)
        right.append(temp.to('cpu'))
        del temp
        torch.cuda.empty_cache()
    

    为我的模型指定 no_grad() 告诉 PyTorch 我不想存储任何以前的计算,从而释放我的 GPU 空间。

    【讨论】:

    • 这很有趣。改变模型的模式(从训练到评估)有帮助吗?我想知道是否有一个内部机制可以自动告诉 pytorch 模式已更改为 eval 所以不需要保存计算?这意味着如果 net.eval() 没有明确告诉 pytorch 在前向传递期间不保存计算,我可以使用“with torch.no_grad()”进行验证和推理?
    • 为了进行推理(只是前向传递),您只需要指定 net.eval() 它将禁用您的 dropout 和 batchnorm 层,将模型置于评估模式。但是,强烈建议也将它与 torch.no_grad() 一起使用,因为它会禁用 autograd 引擎(在推理过程中您可能不希望使用它),这将节省您的时间和内存。只做 net.eval() 仍然会计算梯度,使其变慢并消耗你的内存。
    • 如果我通过 .numpy().cpu() 将数据张量(比如说预测和 groundtruth)发送到 cpu(),我还需要提及“with torch.no_grad()”吗?
    • 如果你的变量有requires_grad=True,那么你不能直接调用.numpy()。您首先必须执行 .detach() 来告诉 pytorch 您不想计算该变量的梯度。接下来,如果您的变量在 GPU 上,您首先需要将其发送到 CPU 以便使用 .cpu() 转换为 numpy。因此,它将类似于var.detach().cpu().numpy()
    • 但是使用 torch.no_grad(),你不需要提及 .detach(),因为无论如何都不会计算梯度。
    猜你喜欢
    • 2022-01-07
    • 1970-01-01
    • 2021-12-16
    • 2020-03-26
    • 1970-01-01
    • 2019-01-10
    • 2012-01-27
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多