PyTorch torch.prod():求所有元素的积(附带实例)
PyTorch 与其他统计软件一样,也内置了丰富的统计函数。使用函数 torch.prod() 求所有元素的积,语法如下:
代码说明:
示例代码如下:
torch.prod(input, dim=None, keepdim=False)torch.prod() 是一个 PyTorch 函数,用于计算张量(Tensor)中所有元素的乘积。它接受一个张量作为输入,并返回一个标量值,表示输入张量中所有元素的乘积。
代码说明:
- input:输入张量;
- dim:指定沿哪个维度进行乘积运算。默认值为None,表示计算整个张量的乘积。如果指定了维度,那么结果将是一个降低该维度的张量;
- keepdim:布尔值,表示是否保持原始张量的维度。默认值为 False,表示不保持原始维度。如果设置为True,则结果张量的维度与输入张量相同,但指定的维度大小为 1。
示例代码如下:
import torch # 创建一个2×2的张量a a = torch.tensor([[1, 2], [3, 4]]) # 计算张量a中所有元素的乘积,并将结果赋值给result result = torch.prod(a) print(result) # 输出:24 # 沿着第0维(行)计算张量a中元素的乘积,并将结果赋值给result_dim result_dim = torch.prod(a, dim=0) print(result_dim) # 输出:tensor([3, 8]) # 沿着第0维(行)计算张量a中元素的乘积,并保持原始维度,将结果赋值给result_keepdim result_keepdim = torch.prod(a, dim=0, keepdim=True) print(result_keepdim) # 输出:tensor([[3, 8]])