PyTorch torch.max():求最大值(附带实例)
使用 torch.max() 函数求最大值,参数与 torch.sum() 函数类似,但是参数 dim 须为整数,也有以下两种格式:
1) 设置参数 input 和 dim,代码如下:
2) 设置参数 keepdim,代码如下:
torch.max(input, dtype=None) torch.max(input, dim, keepdim=False, dtype=None)我们来看以下示例。
1) 设置参数 input 和 dim,代码如下:
# 导入torch库 import torch # 创建一个张量a a = torch.tensor([[1, 2], [3, 4]]) # 计算张量a所有元素的最大值,并将结果赋值给a15 a15 = torch.max(a) # 沿着第0维计算张量a的元素最大值,并将结果赋值给a16 a16 = torch.max(a, dim=0) # 沿着第1维计算张量a的元素最大值,并将结果赋值给a17 a17 = torch.max(a, dim=1) # 打印a15、a16和a17的值 print(a15) print(a16) print(a17)输出结果如下:
tensor(4)
torch.return_types.max(
values=tensor([3, 4]),
indices=tensor([1, 1]))
torch.return_types.max(
values=tensor([2, 4]),
indices=tensor([1, 1]))
2) 设置参数 keepdim,代码如下:
# 导入torch库 import torch # 创建一个张量a a = torch.tensor([[1, 2], [3, 4]]) # 沿着第0维计算张量a的元素最大值,并保持原始维度,将结果赋值给a18 a18 = torch.max(a, 0, keepdim=True) # 沿着第1维计算张量a的元素最大值,并保持原始维度,将结果赋值给a19 a19 = torch.max(a, 1, keepdim=True) # 打印a18和a19的值 print(a18) print(a19)输出结果如下:
torch.return_types.max(
values=tensor([[3, 4]]),
indices=tensor([[1, 1]]))
torch.return_types.max(
values=tensor([[2],[4]]),
indices=tensor([[1],[1]]))