【发布时间】:2017-07-17 17:36:14
【问题描述】:
如何像 Keras 中的 model.summary() 方法一样在 PyTorch 中打印模型的摘要:
Model Summary:
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 1, 15, 27) 0
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D) (None, 8, 15, 27) 872 input_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D) (None, 8, 7, 27) 0 convolution2d_1[0][0]
____________________________________________________________________________________________________
flatten_1 (Flatten) (None, 1512) 0 maxpooling2d_1[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 1) 1513 flatten_1[0][0]
====================================================================================================
Total params: 2,385
Trainable params: 2,385
Non-trainable params: 0
【问题讨论】:
-
你见过模块上的 state_dict() 方法吗?它为您提供模型的不同参数。没有直接的汇总方法,但可以使用 state_dict() 方法形成一个
-
所选答案现已过期,
torchsummary是更好的解决方案。 -
torchsummary已死。请使用来自 TylerYep 的torchinfo(又名torch-summary,带破折号)github.com/TylerYep/torchinfo