【发布时间】:2020-11-09 15:13:17
【问题描述】:
我正致力于在 PyTorch 中实现基于计算机视觉的研究论文。我通过参考论文构建了模型架构。作者已在 GitHub 上以“.pth.tar”格式上传保存的权重。我想在我的模型中加入相同的权重,这样我就可以跳过训练和优化部分,直接从神经网络获取输出。
这篇论文是学习在黑暗中看东西。
模型架构如下:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv1d(32, 12, 1)
.
.
def forward(self, x):
x = F.relu(self.conv1(x))
.
.
return x
net = Net()
然后是从谷歌驱动器/云存储导入训练的权重,并定义将训练的权重放入网络的函数。
PS:两者的模型架构完全相同
【问题讨论】:
-
您有我们的存储库链接吗?它可能是一个 state_dict,您可以使用 load_state_dict 加载它。
-
我认为是load_state_dict,仓库的链接是这样的:drive.google.com/file/d/1cY3gdAVqkNZo8GnosYGj7ciSkrwasmjP/…
标签: deep-learning neural-network computer-vision pytorch