【发布时间】:2022-01-13 02:08:09
【问题描述】:
我制作了一个我正在尝试实现的缩小版本的示例图:
所以顶部的两个输入节点只与顶部的三个输出节点完全连接,同样的设计适用于底部的两个节点。到目前为止,我已经提出了两种在 PyTorch 中实现这一点的方法,但都不是最优的。
首先是创建一个包含许多较小线性层的 nn.ModuleList,并在前向传递期间通过它们迭代输入。对于图表的示例,它看起来像这样:
class Module(nn.Module):
def __init__(self):
self.layers = nn.Module([nn.Linear(2, 3) for i in range(2)])
def forward(self, input):
output = torch.zeros(2, 3)
for i in range(2):
output[i, :] = self.layers[i](input.view(2, 2)[i, :])
return output.flatten()
所以这完成了图中的网络,主要问题是它非常慢。我认为这是因为 PyTorch 必须按顺序处理 for 循环,而不能并行处理输入张量。
要“矢量化”模块以便 PyTorch 可以更快地运行它,我有这个实现:
class Module(nn.Module):
def __init__(self):
self.layer = nn.Linear(4, 6)
self.mask = # create mask of ones and zeros to "block" certain layer connections
def forward(self, input):
prune.custom_from_mask(self.layer, name='weight', mask=self.mask)
return self.layer(input)
这也完成了图的网络,通过使用权重修剪来确保全连接层中的某些权重始终为零(例如,连接顶部输入节点和底部输出节点的权重将始终为零,因此它有效地“断开连接”)。这个模块比前一个模块快得多,因为没有 for 循环。现在的问题是这个模块占用了更多的内存。这可能是因为即使大多数层的权重为零,PyTorch 仍然将网络视为存在。这种实现基本上保留了比它需要的更多的权重。
以前有没有人遇到过这个问题并提出过有效的解决方案?
【问题讨论】:
标签: python machine-learning pytorch sparse-matrix