【问题标题】:PyTorch copy parameter gets stuck in multiprocessing if parameters too big如果参数太大,PyTorch 复制参数会卡在多处理中
【发布时间】:2020-05-07 10:41:56
【问题描述】:

我正在尝试基于此 repo 在 PyTorch 中编写异步 Actor Critic:https://github.com/seungeunrho/minimalRL/blob/master/a3c.py 但我正在更改 ActorCritic 类以使用我自己编写的类。

基本上我有一个 A3C 类,它的一个实例 global_model,具有共享内存,我使用 torch.multiprocessing 打开一些进程以并行训练模型。在开始的每个过程中,我必须创建一个名为 local_model 的模型的新实例,以便继续训练,但是即使全局模型中的一个在每个过程中都有效,该过程也会陷入局部模型的初始化时间。

尝试调试它,我可以看到它进入了 A3C.init 函数和 SharedActorCritic.init ,但是在我放置检查点打印后它就停止了.但是,如果我神奇地打印任何包含 list(critic_param_gen) 的表达式,一切都会正常工作。我还注意到只打印critic_param_gen 是不行的。

知道这是为什么吗?

如果我使用 local_model = copy.deepcopy(global_model) 作为函数 create_local_model,也会发生类似的事情,即仅在存在该打印时才有效。

在伪代码中:

import torch.multiprocessiA3Cng as mp
import torch.nn as nn
import itertools as it

debug = True

A3C(nn.Module):
  def __init__(self, model, n_features):
     ... 
     self.AC_architecture = SharedActorCritic(model, n_features)

class SharedActorCritic(nn.Module):
    def __init__(self, model, n_features):
        super(SharedActorCritic, self).__init__()

        self.shared_architecture = model(n_features) # inherits from nn.Module
        self.actor = SharedActor(n_features) # inherits from nn.Module
        self.critic = SharedCritic(n_features) # inherits from nn.Module

        self.critic_target = BaseCritic(model, n_features) # inherits from nn.Module

        critic_param_gen = it.chain(self.shared_architecture.parameters(), self.critic.parameters())
        print("checkpoint")
        if debug: print(list(critic_param_gen)) # this makes the whole thing work
        for trg_params, params in zip(self.critic_target.parameters(), critic_param_gen ):
            trg_params.data.copy_(params.data)

def create_local_model(model, n_features):
    local_model = A3C(model, n_features)
    print("Process ended")

# in the main
global_model = Model() # works
global_model.share_memory() # doesn't really matter

p = mp.Process(target=create_local_model, args=(model, n_features, ))
p.start()
print("Process started")
p.join()

----
# output if debug is True
Process started
checkpoint
[ ...actual list of critic_param_gen ... ]
Process ended

# output if debug is False
Process started
checkpoint
# and then runs forever

编辑:多亏了snakecharmerb,解决了打印语句的谜团。我创建了一个最小的可重现示例。似乎如果网络足够大,如果在进程中执行复制操作就会中断,而不是在进程之外(因为可以实例化全局模型)。

import torch.nn as nn
import torch.multiprocessing as mp
import copy

class Net(nn.Module):
    def __init__(self, n_features=256, n_layers=8):
        super(Net, self).__init__()
        self.net1 = nn.Sequential(*nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_layers)]))
        self.net2 = nn.Sequential(*nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_layers)]))

        for p1, p2 in zip(self.net1.parameters(), self.net2.parameters()):
            p1.data.copy_(p2.data)

    def forward(self, x):
        return self.net(x)

def create_local_model_v1(global_model):
    local_model = copy.deepcopy(global_model)
    print("Process ended")

%%time
global_model = Net(16,2)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()

# Output
Global model created
Process ended
Process started
CPU times: user 3 ms, sys: 11.9 ms, total: 14.9 ms
Wall time: 45.1 ms

%%time
global_model = Net(256,8)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()

# Output - Gets stuck
Global model created
Process started


【问题讨论】:

  • print(list(critic_param_gen)) 耗尽了迭代器,因此后面的循环将循环零次。因此 - 以及您对 copy.deepcopy(global_model) 的评论 - 表明复制存在问题。从问题中的信息中无法说出这个问题可能是什么。
  • 谢谢,我没想到。有什么我可以提供的信息来帮助调试它吗?
  • 我猜是复制代码,以及被复制对象的代码/结构?理想情况下应该可以生成minimal reproducible example,
  • @snakecharmerb 我正在尝试这样做。我注意到结果取决于架构有多大,例如如果我减少参数的数量,它开始在这两种方法中起作用。因此,当我在进程上运行该函数时,似乎可以复制的内存量是有限的。我将尝试生成一个具有更简单架构的最小可重现示例。

标签: python multiprocessing pytorch itertools


【解决方案1】:

TLDR:使用 torch.multiprocessing.spawn

我不够熟练,无法确定此错误的确切原因和解决方案,但问题出现在torch/nn/parameter.py

result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)

这在深拷贝过程中被调用。为了进行更多调查,我进行了一个更详细的实验来测试导致挂起的参数和环境。结果的要点是模型的大小不是问题,而是有多少特征/问题会导致问题。对我来说,不管有多少层,256 个功能都会导致挂起。另一个更奇怪的问题是,当我删除初始化部分,将 net1 中的参数复制到 net2 时,挂起消失了,但是如果我不向另一个进程发送任何内容,那么一切正常。最后,当使用spawn 函数时,一切正常,直到层数超过 256。

我需要警告有关挂起的一切,据我所知这是一个僵局,但这可能只是一些非常缓慢的过程。这是极不可能的,因为似乎所有活动都停止了,但是我无法确认这是一个死锁,因为当我在挂起期间回溯 C 代码时,我得到的只是内存地址(真正确认一切我想我需要用一些调试选项重建torch ...)。无论如何,我大约有 99% 的把握这是一个死锁,可能是由某处的多处理中的某些东西引起的。我的信心如此之高的原因是代码甚至不会对信号做出反应。如果一切都按预期工作,我希望程序至少允许我从信号处理程序中打印出回溯,但什么也没有。

我发现以下博客文章有些不错: The tragic tale of the deadlocking Python queue

除此之外,我在这一点上的意见是 f*** 结合了火炬和多处理。

如果有人想查看我运行的实验的代码或结果,请告诉我。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2011-02-14
    • 2018-07-05
    • 1970-01-01
    • 1970-01-01
    • 2018-03-12
    • 1970-01-01
    • 2014-12-19
    • 1970-01-01
    相关资源
    最近更新 更多