Python Pytorchviz安装和使用(附带实例)
Pytorchviz 是一个程序包,用于创建 PyTorch 执行图和轨迹的可视化。本节通过案例介绍该模型可视化工具。
Graphviz 在 Windows 中的安装需要下载 Release 包,并配置环境变量,否则会报以下错误:
1) 导入第三方库,代码如下:
2) 随机生成数据集,代码如下:
3) 设置网络模型,代码如下:
4) 可视化网络结构,代码如下:
上述 4 个步骤的代码主要涉及深度学习模型的构建、可视化和导出:
这样的操作可以帮助读者理解模型的结构和计算过程,以及进行模型的可视化和导出等相关操作。具体的应用场景可能包括模型的调试、分析和转换为其他格式等。
运行上述代码,会在当前目录下保存一个 Digraph.gv.pdf 文件,并在浏览器中默认打开,如下图所示:

图 1 模型可视化
Pytorchviz安装
在可视化之前,首先安装 graphviz 和 torchviz 第三方库,代码如下:pip install graphviz pip install tochvizGraphviz(Graph Visualization Software)是由 AT&T 实验室启动的开源工具包,它是用来处理 DOT 语言的工具,DOT 是一种图形描述语言,非常简单,只需要简单了解一下 DOT 语言,就可以用 Graphviz 绘图了,它对程序员特别有用。
Graphviz 在 Windows 中的安装需要下载 Release 包,并配置环境变量,否则会报以下错误:
graphviz.backend.ExecutableNotFound: failed to execute ['dot', '-Tpng', '-O', 'tmp'],
make sure the Graphviz executables are on your systems’ PATH
Pytorchviz建模可视化
接下来使用 Pytorchviz 工具可视化 PyTorch 模型。实现步骤如下:1) 导入第三方库,代码如下:
# 导入相关库 import torch from torch import nn from torchviz import make_dot, make_dot_from_trace
2) 随机生成数据集,代码如下:
# 生成一个形状为(1, 8)的随机张量 x x = torch.randn(1,8)
3) 设置网络模型,代码如下:
# 创建一个顺序模型 model = nn.Sequential() # 在模型中添加一个线性层,输入维度为 8,输出维度为 16 model.add_module('W0', nn.Linear(8, 16)) # 添加 tanh 激活函数 model.add_module('tanh', nn.Tanh()) # 添加一个线性层,输入维度为 16,输出维度为 1 model.add_module('W1', nn.Linear(16, 1))
4) 可视化网络结构,代码如下:
# 使用 make_dot 函数将模型对 x 的计算过程生成图 vis_graph = make_dot(model(x), params=dict(model.named_parameters())) # 查看生成的图 vis_graph.view() # 选择 onnx 导出模式 with torch.onnx.select_model_mode_for_export(model, False): # 使用 jit.trace 记录模型的计算过程 trace= torch.jit.trace(model, (x,)) # 查看记录的计算过程 torch.trace
上述 4 个步骤的代码主要涉及深度学习模型的构建、可视化和导出:
- 首先,通过 torch.randn(1, 8) 生成了一个形状为 (1, 8) 的随机张量 x。
- 其次,创建了一个 nn.Sequential 模型,并通过 add_module() 方法向模型中添加了三个模块:一个线性层 W0、tanh 激活函数和另一个线性层 W1。
- 然后,使用 make_dot 函数将模型对输入 x 的计算过程生成图形表示,并通过 vis_graph.view() 查看生成的图。
- 接着,通过 torch.onnx.select_model_mode_for_export 选择 onnx 导出模式,并使用 torch.jit.trace 记录模型的计算过程。
- 最后,通过 torch.trace 查看记录的计算过程。
这样的操作可以帮助读者理解模型的结构和计算过程,以及进行模型的可视化和导出等相关操作。具体的应用场景可能包括模型的调试、分析和转换为其他格式等。
运行上述代码,会在当前目录下保存一个 Digraph.gv.pdf 文件,并在浏览器中默认打开,如下图所示:

图 1 模型可视化