【Posted】:2020-05-07 10:41:56
【Problem description】:
I am trying to write an asynchronous Actor Critic in PyTorch based on this repo: https://github.com/seungeunrho/minimalRL/blob/master/a3c.py, but I am changing the ActorCritic class to use a class of my own.
Basically, I have an A3C class and one instance of it, global_model, with shared memory. I use torch.multiprocessing to start several processes that train the model in parallel. In each process that starts, I have to create a new instance of the model, called local_model, in order to continue training; but even though the global model is visible in every process, each process gets stuck at the initialization of the local model.
Trying to debug it, I can see that it enters A3C.__init__ and then SharedActorCritic.__init__, but it stops right after a checkpoint print I placed there. However, if I "magically" print any expression containing list(critic_param_gen), everything works. I also noticed that printing just critic_param_gen is not enough.
Any idea why this happens?
Something similar occurs if I use local_model = copy.deepcopy(global_model) as the body of create_local_model, i.e. it only works when that print is present.
In pseudocode:
import torch.multiprocessing as mp
import torch.nn as nn
import itertools as it

debug = True

class A3C(nn.Module):
    def __init__(self, model, n_features):
        ...
        self.AC_architecture = SharedActorCritic(model, n_features)

class SharedActorCritic(nn.Module):
    def __init__(self, model, n_features):
        super(SharedActorCritic, self).__init__()
        self.shared_architecture = model(n_features)  # inherits from nn.Module
        self.actor = SharedActor(n_features)          # inherits from nn.Module
        self.critic = SharedCritic(n_features)        # inherits from nn.Module
        self.critic_target = BaseCritic(model, n_features)  # inherits from nn.Module
        critic_param_gen = it.chain(self.shared_architecture.parameters(), self.critic.parameters())
        print("checkpoint")
        if debug: print(list(critic_param_gen))  # this makes the whole thing work
        for trg_params, params in zip(self.critic_target.parameters(), critic_param_gen):
            trg_params.data.copy_(params.data)

def create_local_model(model, n_features):
    local_model = A3C(model, n_features)
    print("Process ended")

# in the main
global_model = A3C(model, n_features)  # works
global_model.share_memory()  # doesn't really matter
p = mp.Process(target=create_local_model, args=(model, n_features, ))
p.start()
print("Process started")
p.join()
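As an aside, the target-copy loop can be made robust against debugging prints by materializing the chain into a list once and reusing that list; a chain iterator is one-shot, so anything that consumes it (including print(list(...))) leaves nothing for the loop. A minimal sketch of that pattern, with plain nested lists standing in for the real parameter tensors:

```python
import itertools as it

# hypothetical stand-ins for shared_architecture / critic / critic_target parameters
shared_params = [[1.0], [2.0]]
critic_params = [[3.0]]
target_params = [[0.0], [0.0], [0.0]]

# list() materializes the one-shot chain iterator, so it can be
# printed for debugging and iterated again without being exhausted
src_params = list(it.chain(shared_params, critic_params))

for trg, src in zip(target_params, src_params):
    trg[0] = src[0]  # stands in for trg_params.data.copy_(params.data)

print(target_params)  # [[1.0], [2.0], [3.0]]
```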
----
# output if debug is True
Process started
checkpoint
[ ...actual list of critic_param_gen ... ]
Process ended
# output if debug is False
Process started
checkpoint
# and then runs forever
Edit: thanks to snakecharmerb, the mystery of the print statement is solved, and I have created a minimal reproducible example. It seems that if the network is large enough, the copy operation breaks when it is executed inside the child process, but not outside it (since the global model can be instantiated fine).
import torch.nn as nn
import torch.multiprocessing as mp
import copy

class Net(nn.Module):
    def __init__(self, n_features=256, n_layers=8):
        super(Net, self).__init__()
        self.net1 = nn.Sequential(*[nn.Linear(n_features, n_features) for _ in range(n_layers)])
        self.net2 = nn.Sequential(*[nn.Linear(n_features, n_features) for _ in range(n_layers)])
        for p1, p2 in zip(self.net1.parameters(), self.net2.parameters()):
            p1.data.copy_(p2.data)

    def forward(self, x):
        return self.net2(self.net1(x))

def create_local_model_v1(global_model):
    local_model = copy.deepcopy(global_model)
    print("Process ended")
%%time
global_model = Net(16,2)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()
# Output
Global model created
Process ended
Process started
CPU times: user 3 ms, sys: 11.9 ms, total: 14.9 ms
Wall time: 45.1 ms
%%time
global_model = Net(256,8)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()
# Output - Gets stuck
Global model created
Process started
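One workaround worth trying (an assumption on my part, not verified against this exact setup): the pattern of hanging only when copying a sufficiently large model in the child is characteristic of fork-unsafe state (e.g. locks held by the parent's internal threads) being inherited by a forked child. Starting workers with the "spawn" method launches a fresh interpreter instead of forking, which sidesteps that. A sketch using the context API, which torch.multiprocessing exposes with the same interface; the dict argument here is a hypothetical stand-in for the model:

```python
import copy
import multiprocessing as mp  # torch.multiprocessing exposes the same context API

def create_local_model_v1(global_model):
    # in the real code this would deepcopy the Net instance
    local_model = copy.deepcopy(global_model)
    print("Process ended")

if __name__ == "__main__":
    # "spawn" starts a fresh interpreter instead of forking, so the child
    # does not inherit locks held by threads in the parent process
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=create_local_model_v1, args=({"w": 1.0},))
    p.start()
    print("Process started")
    p.join()
```

Note that with "spawn" the arguments must be picklable, and the worker function must be defined at module level.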
【Discussion】:
- print(list(critic_param_gen)) exhausts the iterator, so the loop that follows will iterate zero times. That, together with your remark about copy.deepcopy(global_model), suggests there is a problem with the copying. What that problem might be can't be determined from the information in the question.
- Thanks, I hadn't thought of that. Is there any information I can provide to help debug this?
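The exhaustion described in the comment above is easy to reproduce in isolation (plain lists below, standing in for the real .parameters() generators):

```python
import itertools as it

critic_param_gen = it.chain([1, 2], [3, 4])  # stand-in for the chained .parameters() calls
print(list(critic_param_gen))  # [1, 2, 3, 4]
print(list(critic_param_gen))  # [] -- a chain iterator yields nothing once consumed
```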
- The copying code, I would guess, and the code/structure of the object being copied? Ideally it should be possible to produce a minimal reproducible example.
- @snakecharmerb I am trying to do that. I noticed the result depends on how big the architecture is; for example, if I reduce the number of parameters, both approaches start working. So there seems to be a limit on how much memory can be copied when I run that function in a process. I will try to produce a minimal reproducible example with a simpler architecture.
Tags: python multiprocessing pytorch itertools