PyTorch torch.median():求中位数(附带实例)
使用函数 torch.median() 求中位数,该函数返回中位数,参数与 torch.max() 函数类似,有以下两种格式:
1) 设置参数 input 和 dim,代码如下:
2) 设置参数 keepdim,代码如下:
torch.median(input, dtype=None) torch.median(input, dim, keepdim=False, dtype=None)我们来看以下示例。
1) 设置参数 input 和 dim,代码如下:
# 导入torch库 import torch # 创建一个张量a a = torch.tensor([[1, 2], [3, 4]]) # 计算张量a所有元素的中位数,并将结果打印出来 print(torch.median(a)) # 沿着第1维计算张量a的元素中位数,并将结果打印出来 print(torch.median(a, 1))输出结果如下:
tensor(2)
torch.return_types.median(
values=tensor([1, 3]),
indices=tensor([0, 0]))
2) 设置参数 keepdim,代码如下:
# 导入torch库 import torch # 创建一个张量a a = torch.tensor([[1, 2], [3, 4]]) # 沿着第1维计算张量a的元素中位数,并保持原始维度,将结果赋值给变量median_result median_result = torch.median(a, 1, keepdim=True) # 打印median_result的值 print(median_result)输出结果如下:
torch.return_types.median(
values=tensor([[1],[3]]),
indices=tensor([[0],[0]]))