首页 > 编程笔记 > Python笔记 阅读:8

PyTorch torch.median():求中位数(附带实例)

使用函数 torch.median() 求中位数,该函数返回中位数,参数与 torch.max() 函数类似,有以下两种格式:
torch.median(input, dtype=None)
torch.median(input, dim, keepdim=False, dtype=None)
我们来看以下示例。

1) 设置参数 input 和 dim,代码如下:
# 导入torch库
import torch
# 创建一个张量a
a = torch.tensor([[1, 2], [3, 4]])
# 计算张量a所有元素的中位数,并将结果打印出来
print(torch.median(a))
# 沿着第1维计算张量a的元素中位数,并将结果打印出来
print(torch.median(a, 1))
输出结果如下:

tensor(2)
torch.return_types.median(
values=tensor([1, 3]),
indices=tensor([0, 0]))


2) 设置参数 keepdim,代码如下:
# 导入torch库
import torch
# 创建一个张量a
a = torch.tensor([[1, 2], [3, 4]])
# 沿着第1维计算张量a的元素中位数,并保持原始维度,将结果赋值给变量median_result
median_result = torch.median(a, 1, keepdim=True)
# 打印median_result的值
print(median_result)
输出结果如下:

torch.return_types.median(
values=tensor([[1],[3]]),
indices=tensor([[0],[0]]))

相关文章