深入理解Transformer的核心:从数学原理到代码实现
Query (Q):查询向量
Key (K):键向量
Value (V):值向量
Attention输出是所有位置Value的加权和,权重由Query与Key的相似度决定。
本质上是:"根据当前位置的需求(Q),从所有位置中按相关性(K)提取信息(V)"
模型 | 参数量 | 层数 | 隐藏维度 | 注意力头数 | 头维度 | 词表大小 | 最大序列长度 |
---|---|---|---|---|---|---|---|
GPT-3 | 175B | 96 | 12,288 | 96 | 128 | 50,257 | 2,048 |
Llama 2-7B | 7B | 32 | 4,096 | 32 (Q) / 8 (KV) | 128 | 32,000 | 4,096 |
Llama 2-70B | 70B | 80 | 8,192 | 64 (Q) / 8 (KV) | 128 | 32,000 | 4,096 |
Mistral-7B | 7B | 32 | 4,096 | 32 (Q) / 8 (KV) | 128 | 32,000 | 32,768 |
PaLM | 540B | 118 | 18,432 | 48 (MQA) | 256 | 256,000 | 2,048 |
Claude 3 | ~100B | ~80 | ~12,288 | ~96 | 128 | ~100,000 | 200,000 |
单层注意力矩阵规模:
参数设置:
权重矩阵大小:
运行时矩阵大小(批次=1):
# 训练时内存需求(Adam优化器) training_memory = ( model_params * 2 + # FP16模型权重 model_params * 2 + # FP16梯度 model_params * 4 * 2 # FP32 Adam状态(m和v) ) # 推理时内存需求 inference_memory = ( model_params * 2 + # FP16模型权重 kv_cache_memory # KV缓存 ) # KV Cache计算 # 标准MHA: kv_cache = 2 * n_layers * seq_len * n_heads * d_head * batch_size * 2 # FP16 # GQA (n_kv_heads < n_heads): kv_cache_gqa = 2 * n_layers * seq_len * n_kv_heads * d_head * batch_size * 2 # MQA (n_kv_heads = 1): kv_cache_mqa = 2 * n_layers * seq_len * 1 * d_head * batch_size * 2 # 示例:Llama 2-7B,seq_len=4096, batch=1 # 标准:2 * 32 * 4096 * 32 * 128 * 1 * 2 = 2GB # GQA: 2 * 32 * 4096 * 8 * 128 * 1 * 2 = 512MB(节省75%)
缩放点积注意力 (Scaled Dot-Product Attention):
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
其中:
# 步骤1: 计算注意力分数(Query与Key的相似度) scores = torch.matmul(Q, K.transpose(-2, -1)) # [batch, heads, n, d_k] × [batch, heads, d_k, n] = [batch, heads, n, n] # 步骤2: 缩放(防止softmax的梯度消失) scores = scores / math.sqrt(d_k) # 除以√d_k,使方差稳定在1附近 # 步骤3: 应用softmax得到注意力权重(每行归一化) attn_weights = F.softmax(scores, dim=-1) # [batch, heads, n, n],每行和为1 # 步骤4: 使用注意力权重对Value进行加权求和 output = torch.matmul(attn_weights, V) # [batch, heads, n, n] × [batch, heads, n, d_v] = [batch, heads, n, d_v] # 关键点解释: # - Q的每一行代表一个位置的查询向量 # - K的每一行代表一个位置的键向量 # - scores[i,j]表示位置i对位置j的注意力分数 # - softmax使得每个位置的注意力权重和为1 # - 最终输出是V的加权组合,权重由Q和K的相似度决定
完整的上下文建模能力
O(n²)复杂度,长序列计算昂贵
GPT-3, BERT, 标准Transformer
局部窗口注意力 (Local/Sliding Window):
$\text{LocalAttn}(Q, K, V)_i = \text{softmax}\left(\frac{Q_i K_{[i-w:i+w]}^T}{\sqrt{d_k}}\right)V_{[i-w:i+w]}$
其中 $w$ 是窗口半径,每个位置只关注邻近的 $2w+1$ 个位置
块稀疏注意力 (Block Sparse):
将序列分成大小为 $b$ 的块:
$\text{BlockAttn}(Q^{(i)}, K^{(i)}, V^{(i)}) = \text{softmax}\left(\frac{Q^{(i)}(K^{(i)})^T}{\sqrt{d_k}}\right)V^{(i)}$
# 局部窗口注意力:每个位置只关注窗口内的位置 def create_local_mask(seq_len, window_size): mask = torch.zeros(seq_len, seq_len) for i in range(seq_len): start = max(0, i - window_size // 2) end = min(seq_len, i + window_size // 2 + 1) mask[i, start:end] = 1 # 位置i只关注[start:end]范围 return mask # 块稀疏注意力:将序列分块,块内全连接 Q = Q.view(batch, heads, n_blocks, block_size, d_k) # 重塑为块 K = K.view(batch, heads, n_blocks, block_size, d_k) scores = torch.matmul(Q, K.transpose(-2, -1)) # 块内计算注意力 # 复杂度分析: # - 局部注意力: O(n × window_size × d) 而非 O(n² × d) # - 块稀疏: O(n × block_size × d) 而非 O(n² × d)
线性复杂度,长序列高效
可能丢失长距离依赖
Longformer, BigBird, Sparse Transformer
核技巧 (Kernel Trick):
标准注意力:$O = \text{softmax}(QK^T)V = \frac{\exp(QK^T)}{\text{rowsum}(\exp(QK^T))}V$
线性注意力(改变计算顺序):$O = \frac{\phi(Q)(\phi(K)^TV)}{\phi(Q)\text{sum}(\phi(K))}$
# 线性注意力核心:改变计算顺序 # 标准注意力(低效) attn = torch.matmul(Q, K.transpose(-2, -1)) # [n, d] × [d, n] = [n, n] 大矩阵! output = torch.matmul(attn, V) # [n, n] × [n, d] = [n, d] # 线性注意力(高效) # 步骤1: 先计算K^T V,得到小矩阵 KV = torch.matmul(K.transpose(-2, -1), V) # [d, n] × [n, d] = [d, d] 小矩阵! # 步骤2: 计算归一化因子 Z = 1 / (torch.einsum('nd,d->n', Q, K.sum(dim=0)) + eps) # [n] # 步骤3: Q与KV相乘并归一化 output = torch.matmul(Q, KV) * Z.unsqueeze(-1) # [n, d] × [d, d] = [n, d] # 关键点: # - 避免了n×n的注意力矩阵 # - 当d << n时,效率提升巨大 # - 需要特征映射φ来保证非负性
O(n)复杂度,适合超长序列(>32K)
可能损失表达能力,精度略低
Performer, Linformer, Linear Transformer, Nyström
核心思想:将查询头分成G组,每组共享键值对
$$\text{GQA}(Q^{(g)}, K^{(g/G)}, V^{(g/G)}) = \text{softmax}\left(\frac{Q^{(g)}(K^{(g/G)})^T}{\sqrt{d_k}}\right)V^{(g/G)}$$
# GQA核心实现 class GroupedQueryAttention(nn.Module): def __init__(self, d_model, n_heads=8, n_kv_heads=2): # n_heads: Query头数量(如8) # n_kv_heads: KV头数量(如2) # n_groups: 每个KV头服务的Q头数(8/2=4) self.n_groups = n_heads // n_kv_heads # Q使用全部维度,KV使用更少维度 self.W_q = nn.Linear(d_model, d_model) # 8个头 self.W_k = nn.Linear(d_model, n_kv_heads * d_k) # 2个头 self.W_v = nn.Linear(d_model, n_kv_heads * d_k) # 2个头 def forward(self, x): # 计算Q, K, V Q = self.W_q(x).view(batch, seq, n_heads, d_k) # [B, N, 8, d] K = self.W_k(x).view(batch, seq, n_kv_heads, d_k) # [B, N, 2, d] V = self.W_v(x).view(batch, seq, n_kv_heads, d_k) # [B, N, 2, d] # 关键步骤:重复KV以匹配Q的头数 K = K.repeat_interleave(self.n_groups, dim=2) # [B, N, 8, d] V = V.repeat_interleave(self.n_groups, dim=2) # [B, N, 8, d] # 标准注意力计算 attn = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k) attn = F.softmax(attn, dim=-1) output = torch.matmul(attn, V)
减少KV缓存50-75%,保持模型质量
需要调整KV头数量的超参数
Llama 2, Llama 3, Mistral, Mixtral
核心思想:所有查询头共享单一的键值对
$$\text{MQA}(Q^{(h)}, K, V) = \text{softmax}\left(\frac{Q^{(h)}K^T}{\sqrt{d_k}}\right)V$$
KV缓存从 $O(n \cdot h \cdot d)$ 降到 $O(n \cdot d)$
# MQA核心实现 class MultiQueryAttention(nn.Module): def __init__(self, d_model, n_heads=8): # Q有多个头,K和V只有单头 self.W_q = nn.Linear(d_model, d_model) # 输出8个头的Q self.W_k = nn.Linear(d_model, d_k) # 输出1个头的K self.W_v = nn.Linear(d_model, d_k) # 输出1个头的V def forward(self, x): # 计算Q(多头)和KV(单头) Q = self.W_q(x).view(batch, seq, n_heads, d_k) # [B, N, 8, d] K = self.W_k(x).view(batch, seq, 1, d_k) # [B, N, 1, d] V = self.W_v(x).view(batch, seq, 1, d_k) # [B, N, 1, d] # 关键:扩展KV到所有头(广播) K = K.expand(-1, -1, n_heads, -1) # [B, N, 8, d] V = V.expand(-1, -1, n_heads, -1) # [B, N, 8, d] # 标准注意力计算 attn = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k) # 内存分析(seq_len=2048, d=64, 8头): # MHA KV缓存: 2 × 8 × 2048 × 64 × 4B = 8MB # MQA KV缓存: 2 × 1 × 2048 × 64 × 4B = 1MB # 节省: 87.5%!
最小KV缓存(节省87.5%),推理最快
可能降低模型性能,训练不稳定
PaLM, Falcon, StarCoder, SantaCoder
在线Softmax (Online Softmax):
数值稳定的增量计算:
$$m^{new} = \max(m^{old}, \text{rowmax}(S))$$
$$l^{new} = e^{m^{old}-m^{new}}l^{old} + \text{rowsum}(e^{S-m^{new}})$$
$$O^{new} = \frac{e^{m^{old}-m^{new}}l^{old}O^{old} + e^{S-m^{new}}V}{l^{new}}$$
# Flash Attention核心循环 def flash_attention_forward(Q, K, V): # Q: [batch, heads, N, d] # 分块大小(适配SRAM大小) Br = 64 # Q块大小 Bc = 64 # KV块大小 N = Q.shape[2] O = torch.zeros_like(Q) # 输出 L = torch.zeros(batch, heads, N) # 归一化因子 # 外循环:遍历Q的块 for i in range(0, N, Br): Qi = Q[:, :, i:i+Br, :] # 加载Q块到SRAM Oi = torch.zeros_like(Qi) Li = torch.zeros(batch, heads, Br) - float('inf') Mi = torch.full_like(Li, -float('inf')) # 最大值 # 内循环:遍历KV的块 for j in range(0, N, Bc): Kj = K[:, :, j:j+Bc, :] # 加载KV块到SRAM Vj = V[:, :, j:j+Bc, :] # 在SRAM中计算注意力 Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) / sqrt(d) # 在线softmax(数值稳定) Mi_new = torch.max(Mi, Sij.max(dim=-1)[0]) Pij = torch.exp(Sij - Mi_new.unsqueeze(-1)) # 增量更新输出 Li = torch.exp(Mi - Mi_new) * Li + Pij.sum(dim=-1) Oi = torch.exp(Mi - Mi_new).unsqueeze(-1) * Oi + torch.matmul(Pij, Vj) Mi = Mi_new # 归一化并写回HBM O[:, :, i:i+Br, :] = Oi / Li.unsqueeze(-1) return O # 关键优化: # 1. 避免存储n×n注意力矩阵 # 2. 分块计算,每块都在SRAM中 # 3. IO复杂度从O(N²) → O(N²/M),M是SRAM大小
2-4倍速度提升,O(n)内存,数值稳定
需要CUDA kernel优化,实现复杂
GPT-4, Claude, Gemini, 现代大模型标配
注意力机制 | 时间复杂度 | 内存复杂度 | KV缓存 | 最适合场景 |
---|---|---|---|---|
Full Attention | O(n²d) | O(n²) | O(nhd) | 短序列,完整建模 |
Local Attention | O(nwd) | O(nw) | O(nhd) | 长序列,局部依赖 |
Linear Attention | O(nd²) | O(nd) | O(nhd) | 超长序列 |
GQA | O(n²d) | O(n²) | O(n(h/G)d) | 推理优化 |
MQA | O(n²d) | O(n²) | O(nd) | 极限内存优化 |
Flash Attention | O(n²d) | O(n) | O(nhd) | 训练加速 |
特性 | Full | Sparse | Linear | GQA | MQA | Flash |
---|---|---|---|---|---|---|
模型质量 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
推理速度 | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
内存效率 | ⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
长序列支持 | ⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐ | ⭐⭐⭐ |
实现难度 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐ |
主流采用度 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |