安装

conda install graphviz
conda install tensorwatch


载入库

import sys
import torch
import tensorwatch as tw
import torchvision.models


网络结构可视化

alexnet_model = torchvision.models.alexnet()
tw.draw_model(alexnet_model, [1, 3, 224, 224])


载入alexnet,draw_model函数需要传入三个参数,第一个为model,第二个参数为input_shape,第三个参数为orientation,可以选择'LR'或者'TB',分别代表左右布局与上下布局。
在notebook中,执行完上面的代码会显示如下的图,将网络的结构及各个层的name和shape进行了可视化。

Pytorch 网络结构可视化

统计网络参数

通过model_stats方法统计各层的参数情况。
tw.model_stats(alexnet_model, [1, 3, 224, 224])
alexnet_model.features
alexnet_model.classifier
来源:https://zhuanlan.zhihu.com/p/66320870

 

相关文章:

  • 2021-08-20
  • 2021-04-17
  • 2021-08-05
  • 2021-09-18
  • 2021-07-13
  • 2021-07-16
  • 2022-01-22
  • 2021-08-22
猜你喜欢
  • 2021-11-26
  • 2022-01-28
  • 2022-12-23
  • 2021-06-22
  • 2022-01-18
  • 2022-12-23
  • 2021-09-21
相关资源
相似解决方案