电脑基础 · 2023年3月31日

torch.randn的用法

torch.randn 是一个 PyTorch 中的函数,用于生成指定大小的张量,其中每个元素都是从标准正态分布(均值为0,标准差为1)中随机抽取的。

torch.randn 的语法如下:

torch.randn(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor

其中 *size 表示张量的大小,可以是一个整数,一个元组或多个整数。例如,要创建一个大小为 3x2 的张量,可以使用以下代码:

import torch
x = torch.randn(3, 2)
print(x)

输出结果:

tensor([[ 0.4438, -0.0241],
[-0.4326, -0.8158],
[-0.3517, -1.3522]])

在上面的代码中,我们创建了一个大小为 3x2 的张量 x,其中每个元素都是从标准正态分布中随机抽取的。

out 参数可以指定一个输出张量,将生成的随机数填充到这个张量中。

dtype 参数可以指定生成的随机数的数据类型。默认情况下,它是 torch.float32

device 参数可以指定生成的张量的计算设备,例如 CPU 或 GPU。

requires_grad 参数可以指定是否需要计算梯度,默认值为 False