附录A:PyTorch 速览
仅覆盖本课程会用到的最小子集,给”几乎不熟 PyTorch”的读者一个起步台阶。深入学习请直接看官方文档。
A.1 张量 (Tensor)
import torch
a = torch.tensor([[1, 2], [3, 4]]) # 从 list 构造
b = torch.zeros(3, 4) # 全 0
c = torch.randn(2, 3) # 标准正态
d = torch.arange(10) # 0..9
print(a.shape, a.dtype, a.device)
形状变换三件套:
x = torch.randn(2, 3, 4)
x.view(6, 4) # 等价于 reshape,要求底层连续
x.transpose(0, 1) # 交换维度
x.permute(2, 0, 1) # 任意重排
广播规则与 NumPy 一致:维度从右往左对齐,缺失维补 1,1 可以和任意大小广播。
A.2 自动求导
任何 requires_grad=True 的张量参与的计算都会被记录在动态图里。
w = torch.randn(3, requires_grad=True)
x = torch.tensor([1., 2., 3.])
y = (w * x).sum()
y.backward()
print(w.grad) # tensor([1., 2., 3.])
需要”暂停求导”时用 torch.no_grad() 上下文管理器,或者 tensor.detach()。
A.3 nn.Module
模型的最小骨架:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
model = MyModel()
for name, p in model.named_parameters():
print(name, p.shape)
nn.Parameter 自动注册为可训练参数,nn.Module 内部嵌套的子模块也会被递归收集。
A.4 Dataset / DataLoader
from torch.utils.data import Dataset, DataLoader
class XYDataset(Dataset):
def __init__(self, n=1000):
self.x = torch.randn(n, 10)
self.y = torch.randint(0, 2, (n,))
def __len__(self): return len(self.x)
def __getitem__(self, i): return self.x[i], self.y[i]
loader = DataLoader(XYDataset(), batch_size=32, shuffle=True, num_workers=2)
for x, y in loader:
...
num_workers > 0 时数据加载会在子进程进行,能显著吃满 GPU。
A.5 训练四步曲
optimizer.zero_grad() # 1. 清梯度
loss = compute_loss(model, x, y) # 2. 前向 + 损失
loss.backward() # 3. 反向
optimizer.step() # 4. 更新参数
牢记这四步永远按这个顺序,且每个 batch 都执行一遍。
A.6 GPU 加速
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
x, y = x.to(device), y.to(device)
混合精度(A100/H100 上能省一半显存):
scaler = torch.cuda.amp.GradScaler()
for x, y in loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
loss = compute_loss(model, x, y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
A.7 保存与加载
只保存参数(推荐):
torch.save(model.state_dict(), "ckpt.pt")
model.load_state_dict(torch.load("ckpt.pt", map_location=device))
不要保存整个 model 对象——会和代码版本耦合。
← 第 7 章 · 返回目录 · 附录 D · 训练增强 →