首页 > 编程笔记 > Python笔记 阅读:1

PyTorch torch.prod():求所有元素的积(附带实例)

PyTorch 与其他统计软件一样,也内置了丰富的统计函数。使用函数 torch.prod() 求所有元素的积,语法如下:
torch.prod(input, dim=None, keepdim=False)
torch.prod() 是一个 PyTorch 函数,用于计算张量(Tensor)中所有元素的乘积。它接受一个张量作为输入,并返回一个标量值,表示输入张量中所有元素的乘积。

代码说明:
示例代码如下:
import torch
# 创建一个2×2的张量a
a = torch.tensor([[1, 2], [3, 4]])
# 计算张量a中所有元素的乘积,并将结果赋值给result
result = torch.prod(a)
print(result)  # 输出:24
# 沿着第0维(行)计算张量a中元素的乘积,并将结果赋值给result_dim
result_dim = torch.prod(a, dim=0)
print(result_dim)  # 输出:tensor([3, 8])
# 沿着第0维(行)计算张量a中元素的乘积,并保持原始维度,将结果赋值给result_keepdim
result_keepdim = torch.prod(a, dim=0, keepdim=True)
print(result_keepdim)  # 输出:tensor([[3, 8]])

相关文章