首页 > 编程笔记 > Python笔记 阅读:7

PyTorch torch.max():求最大值(附带实例)

使用 torch.max() 函数求最大值,参数与 torch.sum() 函数类似,但是参数 dim 须为整数,也有以下两种格式:
torch.max(input, dtype=None)
torch.max(input, dim, keepdim=False, dtype=None)
我们来看以下示例。

1) 设置参数 input 和 dim,代码如下:
# 导入torch库
import torch
# 创建一个张量a
a = torch.tensor([[1, 2], [3, 4]])
# 计算张量a所有元素的最大值,并将结果赋值给a15
a15 = torch.max(a)
# 沿着第0维计算张量a的元素最大值,并将结果赋值给a16
a16 = torch.max(a, dim=0)
# 沿着第1维计算张量a的元素最大值,并将结果赋值给a17
a17 = torch.max(a, dim=1)
# 打印a15、a16和a17的值
print(a15)
print(a16)
print(a17)
输出结果如下:

tensor(4)
torch.return_types.max(
values=tensor([3, 4]),
indices=tensor([1, 1]))
torch.return_types.max(
values=tensor([2, 4]),
indices=tensor([1, 1]))


2) 设置参数 keepdim,代码如下:
# 导入torch库
import torch
# 创建一个张量a
a = torch.tensor([[1, 2], [3, 4]])
# 沿着第0维计算张量a的元素最大值,并保持原始维度,将结果赋值给a18
a18 = torch.max(a, 0, keepdim=True)
# 沿着第1维计算张量a的元素最大值,并保持原始维度,将结果赋值给a19
a19 = torch.max(a, 1, keepdim=True)
# 打印a18和a19的值
print(a18)
print(a19)
输出结果如下:

torch.return_types.max(
values=tensor([[3, 4]]),
indices=tensor([[1, 1]]))
torch.return_types.max(
values=tensor([[2],[4]]),
indices=tensor([[1],[1]]))

相关文章