Skip to content

PyTorch 常用模块

random

函数功能官方文档
torch.rand()返回 [0,1) 内服从均匀分布的随机数地址
torch.randn()返回服从标准正态分布的随机数地址
torch.randint()返回 low -> high 内的随机整数地址
torch.randperm()返回 0 -> n-1 内的随机排列整数地址

GPU 加速

Tensor 位置

python
# 查看模型在哪个设备上
print(next(model.parameters()).device)

# 查看张量在哪个设备上
print(tensor.device)

编写代码

一般来讲,需要转移数据和模型两部分到 GPU 上

python
# 在训练时把张量转移到 GPU
x, y = x.to(device), y.to(device)

# 在实例化模型时把模型参数转移到 GPU
model = CustomModel().to(device)