首页 > 编程笔记 > Python笔记 阅读:1

Python Pytorchviz安装和使用(附带实例)

Pytorchviz 是一个程序包,用于创建 PyTorch 执行图和轨迹的可视化。本节通过案例介绍该模型可视化工具。

Pytorchviz安装

在可视化之前,首先安装 graphviz 和 torchviz 第三方库,代码如下:
pip install graphviz
pip install tochviz
Graphviz(Graph Visualization Software)是由 AT&T 实验室启动的开源工具包,它是用来处理 DOT 语言的工具,DOT 是一种图形描述语言,非常简单,只需要简单了解一下 DOT 语言,就可以用 Graphviz 绘图了,它对程序员特别有用。

Graphviz 在 Windows 中的安装需要下载 Release 包,并配置环境变量,否则会报以下错误:

graphviz.backend.ExecutableNotFound: failed to execute ['dot', '-Tpng', '-O', 'tmp'],
make sure the Graphviz executables are on your systems’ PATH

Pytorchviz建模可视化

接下来使用 Pytorchviz 工具可视化 PyTorch 模型。实现步骤如下:
1) 导入第三方库,代码如下:
# 导入相关库
import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace

2) 随机生成数据集,代码如下:
# 生成一个形状为(1, 8)的随机张量 x
x = torch.randn(1,8)

3) 设置网络模型,代码如下:
# 创建一个顺序模型
model = nn.Sequential()
# 在模型中添加一个线性层,输入维度为 8,输出维度为 16
model.add_module('W0', nn.Linear(8, 16))
# 添加 tanh 激活函数
model.add_module('tanh', nn.Tanh())
# 添加一个线性层,输入维度为 16,输出维度为 1
model.add_module('W1', nn.Linear(16, 1))

4) 可视化网络结构,代码如下:
# 使用 make_dot 函数将模型对 x 的计算过程生成图
vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
# 查看生成的图
vis_graph.view()
# 选择 onnx 导出模式
with torch.onnx.select_model_mode_for_export(model, False):
   # 使用 jit.trace 记录模型的计算过程
   trace= torch.jit.trace(model, (x,))
# 查看记录的计算过程
torch.trace

上述 4 个步骤的代码主要涉及深度学习模型的构建、可视化和导出:
这样的操作可以帮助读者理解模型的结构和计算过程,以及进行模型的可视化和导出等相关操作。具体的应用场景可能包括模型的调试、分析和转换为其他格式等。

运行上述代码,会在当前目录下保存一个 Digraph.gv.pdf 文件,并在浏览器中默认打开,如下图所示:


图 1 模型可视化

相关文章