首页 > 编程笔记 > Python笔记 阅读:4

TensorBoard add_graph()的用法(附带实例)

add_graph() 方法可以可视化神经网络模型,语法格式如下:
add_graph(model, input_to_model=None, verbose=False)
参数说明如下:
TensorboardX 给出了一个官方样例,大家可以尝试,代码如下:
# 导入相关库
import torch
import numpy as np
from torchvision import models, transforms
from PIL import Image
from tensorboardX import SummaryWriter

vgg16 = models.vgg16()  # 这里下载预训练好的模型
print(vgg16)            # 打印这个模型
transform_2 = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # convert RGB to BGR
    # from <https://github.com/mrzhu-cool/pix2pix-pytorch/blob/master/util.py>
    transforms.Lambda(lambda x: torch.index_select(x, 0, torch.LongTensor([2, 1, 0]))),
    transforms.Lambda(lambda x: x*255),
    transforms.Normalize(mean = [103.939, 116.779, 123.68],
                         std = [ 1, 1, 1 ]),
])

cat_img = Image.open('./1.jpg')
# 因为PyTorch是分批次进行训练的,所以这里建立一个批次为1的数据集
vgg16_input=transform_2(cat_img)[np.newaxis]
print(vgg16_input.shape)

# 开始前向传播,打印输出值
raw_score = vgg16(vgg16_input)
raw_score_numpy = raw_score.data.numpy()
print(raw_score_numpy.shape, np.argmax(raw_score_numpy.ravel()))

# 将结构图在TensorBoard中展示
with SummaryWriter(log_dir='./runs/graph', comment='vgg16') as writer:
    writer.add_graph(vgg16, (vgg16_input,))
输出如下图所示:


图 1 可视化网络图

相关文章