PyTorch 常用模块
random
函数 | 功能 | 官方文档 |
---|---|---|
torch.rand() | 返回 | 地址 |
torch.randn() | 返回服从标准正态分布的随机数 | 地址 |
torch.randint() | 返回 low -> high 内的随机整数 | 地址 |
torch.randperm() | 返回 0 -> n-1 内的随机排列整数 | 地址 |
GPU 加速
Tensor 位置
python
# 查看模型在哪个设备上
print(next(model.parameters()).device)
# 查看张量在哪个设备上
print(tensor.device)
编写代码
一般来讲,需要转移数据和模型两部分到 GPU 上
python
# 在训练时把张量转移到 GPU
x, y = x.to(device), y.to(device)
# 在实例化模型时把模型参数转移到 GPU
model = CustomModel().to(device)