【发布时间】:2019-04-18 06:34:40
【问题描述】:
正如下面的代码,我试图获得模型的 47 个重复输出的平均值。但它总是内存不足。如果我删除z_proto_class_list.append(z_proto_class),那就没问题了。我想这是因为如果我不附加张量,内存就会被释放。我总是试图一次生成 47 输出,但它显然比我当前的选择更消耗内存。有没有办法解决我当前的问题?谢谢。
z_proto_class_list = []
for support_input_ids, support_input_mask, support_segment_ids in dataloader:
s_z, s_pooled_output = model(support_input_ids, support_input_mask, support_segment_ids, output_all_encoded_layers=False)
sz_dim = s_z.size(-1)
index = torch.LongTensor(support_idx_list).unsqueeze(1).unsqueeze(2).expand(len(support_idx_list),1,sz_dim).cuda()
z_proto_raw = torch.gather(s_z,1,index)
z_proto_class = z_proto_raw.view(1,n_support, sz_dim).mean(1)
z_proto_class_list.append(z_proto_class)
torch.cuda.empty_cache()
z_proto = torch.cat(z_proto_class_list, 0)
【问题讨论】:
-
而不是
cat然后mean- 只需保存 sum 并将其除以元素的数量。
标签: memory-management out-of-memory pytorch