🚀 大语言模型注意力机制详解

深入理解Transformer的核心:从数学原理到代码实现

🔄 训练与推理详解

📚 QKV的本质含义

Query (Q):查询向量

  • 代表"我在找什么信息"
  • 每个位置的token生成一个Query,用于查询其他位置的信息

Key (K):键向量

  • 代表"我能提供什么信息"
  • 每个位置的token生成一个Key,供其他位置查询

Value (V):值向量

  • 代表"我的实际信息内容"
  • 根据Query-Key的匹配度,Value被加权聚合
注意力输出的含义:

Attention输出是所有位置Value的加权和,权重由Query与Key的相似度决定。

本质上是:"根据当前位置的需求(Q),从所有位置中按相关性(K)提取信息(V)"

🚀 训练:并行计算所有位置

训练时的并行计算(Teacher Forcing) 输入序列:"我爱北京天安门" 北京 天安 [EOS] 所有位置同时生成Q、K、V Q: [6×d_model] 并行计算 K: [6×d_model] 并行计算 V: [6×d_model] 并行计算 因果掩码(Causal Mask) 绿色=可见,白色=被掩码(防止看到未来) 为什么训练可以并行? 1. Teacher Forcing:训练时知道完整答案 2. 因果掩码:确保每个位置只看到之前的内容 3. 矩阵运算:所有位置的注意力一次计算 4. GPU并行:充分利用GPU的并行计算能力 💡 一次前向传播处理整个序列, 而不是逐个token计算 训练目标(右移一位) 北京 天安 [EOS] - Loss = CrossEntropy(预测输出, 目标输出)

🔮 推理:逐个生成Token

推理时的自回归生成(Auto-regressive) Step 1: 输入"我" 生成Q₁,K₁,V₁ Attn → FFN 输出: "爱" KV Cache: [K₁,V₁] Step 2: 输入"我爱" 复用K₁,V₁ 只算Q₂,K₂,V₂ 北京 输出: "北京" KV Cache: [K₁,V₁,K₂,V₂] Step 3: 输入"我爱北京" 北京 复用K₁,V₁,K₂,V₂ 只算Q₃,K₃,V₃ 天安 输出: "天安" KV Cache持续增长... KV Cache机制 • 每步只需计算新位置的Q、K、V • 之前位置的K、V保存在缓存中复用 • 注意力计算: Q_new × [K_cache; K_new]ᵀ • 内存需求: 2 × n_layers × seq_len × n_heads × d_head 计算示例(单个token): 1. 输入embedding: [1 × d_model] 2. 生成Q_new: [1 × n_heads × d_head] 3. Attention: Q_new × K_cache^T → [1 × seq_len] 推理瓶颈 ❌ 顺序依赖:必须等前一个token生成完 ❌ 内存瓶颈:KV Cache随序列长度线性增长 ❌ 低GPU利用率:每步只处理1个token 这就是为什么需要GQA/MQA优化! 优化方案 ✓ 批处理:同时处理多个序列 ✓ GQA/MQA:减少KV Cache内存 ✓ 投机解码:小模型预测,大模型验证 Flash Attention也能加速推理!

📏 主流模型规模详解

🎯 典型模型参数配置

模型 参数量 层数 隐藏维度 注意力头数 头维度 词表大小 最大序列长度
GPT-3 175B 96 12,288 96 128 50,257 2,048
Llama 2-7B 7B 32 4,096 32 (Q) / 8 (KV) 128 32,000 4,096
Llama 2-70B 70B 80 8,192 64 (Q) / 8 (KV) 128 32,000 4,096
Mistral-7B 7B 32 4,096 32 (Q) / 8 (KV) 128 32,000 32,768
PaLM 540B 118 18,432 48 (MQA) 256 256,000 2,048
Claude 3 ~100B ~80 ~12,288 ~96 128 ~100,000 200,000

📊 矩阵规模计算示例(以Llama 2-7B为例)

单层注意力矩阵规模:

参数设置:

  • d_model = 4,096(隐藏维度)
  • n_heads = 32(查询头数)
  • n_kv_heads = 8(KV头数,GQA)
  • d_head = 128(每个头的维度)
  • seq_len = 4,096(序列长度)

权重矩阵大小:

  • W_q: [4096 × 4096] = 16.8M参数
  • W_k: [4096 × 1024] = 4.2M参数(GQA优化)
  • W_v: [4096 × 1024] = 4.2M参数(GQA优化)
  • W_o: [4096 × 4096] = 16.8M参数
  • 单层注意力总计:42M参数

运行时矩阵大小(批次=1):

  • Q: [4096 × 32 × 128] = 16MB(FP32)
  • K: [4096 × 8 × 128] = 4MB(GQA优化)
  • V: [4096 × 8 × 128] = 4MB(GQA优化)
  • 注意力分数: [32 × 4096 × 4096] = 2.1GB(如不用Flash)

💾 内存需求分析

不同模型的显存需求对比 训练时显存需求(混合精度) 7B模型 ~56GB(参数14GB + 梯度14GB + 优化器28GB) 13B模型 ~104GB(参数26GB + 梯度26GB + 优化器52GB) 70B模型 ~560GB(需要多卡并行) 规则:训练显存 ≈ 模型参数 × 8(混合精度) 推理时显存需求(FP16) 模型权重 7B: 14GB KV Cache(seq=4096) 标准: 16GB GQA 4GB(节省75%) MQA 2GB(节省87.5%) GPU配置建议 推理部署: • 7B模型:RTX 4090 (24GB) • 13B模型:A100 40GB 或 2×RTX 4090 • 70B模型:4×A100 80GB 训练微调: • 7B模型:A100 80GB • 13B模型:2×A100 80GB • 70B模型:8×A100 80GB 优化技巧: • 使用GQA/MQA减少KV Cache • 量化(INT8/INT4)减少权重内存 • Flash Attention减少激活内存

📝 内存计算公式

# 训练时内存需求(Adam优化器)
training_memory = (
    model_params * 2 +     # FP16模型权重
    model_params * 2 +     # FP16梯度  
    model_params * 4 * 2   # FP32 Adam状态(m和v)
)

# 推理时内存需求
inference_memory = (
    model_params * 2 +     # FP16模型权重
    kv_cache_memory        # KV缓存
)

# KV Cache计算
# 标准MHA:
kv_cache = 2 * n_layers * seq_len * n_heads * d_head * batch_size * 2  # FP16

# GQA (n_kv_heads < n_heads):
kv_cache_gqa = 2 * n_layers * seq_len * n_kv_heads * d_head * batch_size * 2

# MQA (n_kv_heads = 1):
kv_cache_mqa = 2 * n_layers * seq_len * 1 * d_head * batch_size * 2

# 示例:Llama 2-7B,seq_len=4096, batch=1
# 标准:2 * 32 * 4096 * 32 * 128 * 1 * 2 = 2GB
# GQA: 2 * 32 * 4096 * 8 * 128 * 1 * 2 = 512MB(节省75%)

1 全注意力 (Full/Dense Attention)

📐 数学原理

缩放点积注意力 (Scaled Dot-Product Attention):

$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

计算步骤分解:
  1. 步骤1: 计算注意力分数 $S = QK^T$ (Query与Key的相似度)
  2. 步骤2: 缩放 $S_{scaled} = \frac{S}{\sqrt{d_k}}$ (防止梯度消失)
  3. 步骤3: 归一化 $A = \text{softmax}(S_{scaled})$ (得到注意力权重)
  4. 步骤4: 加权求和 $O = AV$ (根据权重聚合Value)

其中:

  • $Q \in \mathbb{R}^{n \times d_k}$:查询矩阵 (Queries)
  • $K \in \mathbb{R}^{n \times d_k}$:键矩阵 (Keys)
  • $V \in \mathbb{R}^{n \times d_v}$:值矩阵 (Values)
  • $n$:序列长度,$d_k$:键/查询维度,$d_v$:值维度

🎨 图形化解释 - 分步骤展示

注意力计算详细步骤 步骤1: 计算相似度 Q [n×d_k] × K^T [d_k×n] = S [n×n] 每个Query与所有Key的点积 步骤2: 缩放 S ÷ √d_k = S' [n×n] 防止softmax梯度消失 步骤3: 归一化 S' softmax (按行) A [n×n] 每行和为1的权重矩阵 步骤4: 加权聚合 A [n×n] × V [n×d_v] = O [n×d_v] 根据注意力权重组合Value 维度示例(n=4, d_k=64, d_v=64) Q: [4×64] × K^T: [64×4] = S: [4×4] S: [4×4] ÷ √64 = S': [4×4] softmax(S'): [4×4] = A: [4×4] A: [4×4] × V: [4×64] = O: [4×64] 💡 输出维度与输入Q相同,但内容是V的加权组合

📝 代码实现与步骤对应

# 步骤1: 计算注意力分数(Query与Key的相似度)
scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch, heads, n, d_k] × [batch, heads, d_k, n] = [batch, heads, n, n]

# 步骤2: 缩放(防止softmax的梯度消失)
scores = scores / math.sqrt(d_k)  # 除以√d_k,使方差稳定在1附近

# 步骤3: 应用softmax得到注意力权重(每行归一化)
attn_weights = F.softmax(scores, dim=-1)  # [batch, heads, n, n],每行和为1

# 步骤4: 使用注意力权重对Value进行加权求和
output = torch.matmul(attn_weights, V)  # [batch, heads, n, n] × [batch, heads, n, d_v] = [batch, heads, n, d_v]

# 关键点解释:
# - Q的每一行代表一个位置的查询向量
# - K的每一行代表一个位置的键向量
# - scores[i,j]表示位置i对位置j的注意力分数
# - softmax使得每个位置的注意力权重和为1
# - 最终输出是V的加权组合,权重由Q和K的相似度决定

✅ 优势

完整的上下文建模能力

⚠️ 劣势

O(n²)复杂度,长序列计算昂贵

💡 应用

GPT-3, BERT, 标准Transformer

🏆 主流模型采用情况

GPT系列

  • GPT-3: Full Attention
  • GPT-4: Flash Attention + Sparse

Llama系列

  • Llama 1: Full Attention
  • Llama 2/3: GQA + RoPE

Google系列

  • PaLM: MQA
  • Gemini: Flash + GQA混合

其他模型

  • Mistral/Mixtral: GQA (4头)
  • Falcon: MQA
  • Claude: Flash Attention

2 稀疏注意力 (Sparse Attention)

📐 数学原理

局部窗口注意力 (Local/Sliding Window):

$\text{LocalAttn}(Q, K, V)_i = \text{softmax}\left(\frac{Q_i K_{[i-w:i+w]}^T}{\sqrt{d_k}}\right)V_{[i-w:i+w]}$

其中 $w$ 是窗口半径,每个位置只关注邻近的 $2w+1$ 个位置

块稀疏注意力 (Block Sparse):

将序列分成大小为 $b$ 的块:

$\text{BlockAttn}(Q^{(i)}, K^{(i)}, V^{(i)}) = \text{softmax}\left(\frac{Q^{(i)}(K^{(i)})^T}{\sqrt{d_k}}\right)V^{(i)}$

🎨 图形化解释

局部窗口注意力 位置i 位置j 有注意力连接 无注意力连接 块稀疏注意力 块1 块2

📝 代码实现

# 局部窗口注意力:每个位置只关注窗口内的位置
def create_local_mask(seq_len, window_size):
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = 1  # 位置i只关注[start:end]范围
    return mask

# 块稀疏注意力:将序列分块,块内全连接
Q = Q.view(batch, heads, n_blocks, block_size, d_k)  # 重塑为块
K = K.view(batch, heads, n_blocks, block_size, d_k)
scores = torch.matmul(Q, K.transpose(-2, -1))  # 块内计算注意力

# 复杂度分析:
# - 局部注意力: O(n × window_size × d) 而非 O(n² × d)
# - 块稀疏: O(n × block_size × d) 而非 O(n² × d)

✅ 优势

线性复杂度,长序列高效

⚠️ 劣势

可能丢失长距离依赖

💡 应用

Longformer, BigBird, Sparse Transformer

3 线性注意力 (Linear Attention)

📐 数学原理

核技巧 (Kernel Trick):

标准注意力:$O = \text{softmax}(QK^T)V = \frac{\exp(QK^T)}{\text{rowsum}(\exp(QK^T))}V$

线性注意力(改变计算顺序):$O = \frac{\phi(Q)(\phi(K)^TV)}{\phi(Q)\text{sum}(\phi(K))}$

关键优化:改变矩阵乘法顺序
  • 标准方式:$(QK^T)V$ - 先算$QK^T$得到$n×n$矩阵,复杂度$O(n^2d)$
  • 线性方式:$Q(K^TV)$ - 先算$K^TV$得到$d×d$矩阵,复杂度$O(nd^2)$
  • 当$d \ll n$时,显著降低计算复杂度

🎨 图形化解释

标准注意力 O(n²d) 第1步: Q [n×d] × K^T [d×n] = Attn [n×n] 大矩阵! 第2步: Attn × V [n×d] = O [n×d] 总复杂度: O(n²d) - 需要存储n×n矩阵 线性注意力 O(nd²) 第1步: K^T [d×n] × V [n×d] = KV [d×d] 小矩阵! 第2步: Q [n×d] × KV [d×d] = O [n×d] 总复杂度: O(nd²) - 只需d×d矩阵 效率对比(假设n=2048, d=64) 标准注意力: 2048² × 64 = 268M 操作 线性注意力: 2048 × 64² = 8.4M 操作 速度提升 32倍!

📝 代码实现

# 线性注意力核心:改变计算顺序

# 标准注意力(低效)
attn = torch.matmul(Q, K.transpose(-2, -1))  # [n, d] × [d, n] = [n, n] 大矩阵!
output = torch.matmul(attn, V)               # [n, n] × [n, d] = [n, d]

# 线性注意力(高效)
# 步骤1: 先计算K^T V,得到小矩阵
KV = torch.matmul(K.transpose(-2, -1), V)    # [d, n] × [n, d] = [d, d] 小矩阵!

# 步骤2: 计算归一化因子
Z = 1 / (torch.einsum('nd,d->n', Q, K.sum(dim=0)) + eps)  # [n]

# 步骤3: Q与KV相乘并归一化
output = torch.matmul(Q, KV) * Z.unsqueeze(-1)  # [n, d] × [d, d] = [n, d]

# 关键点:
# - 避免了n×n的注意力矩阵
# - 当d << n时,效率提升巨大
# - 需要特征映射φ来保证非负性

✅ 优势

O(n)复杂度,适合超长序列(>32K)

⚠️ 劣势

可能损失表达能力,精度略低

💡 应用

Performer, Linformer, Linear Transformer, Nyström

4 分组查询注意力 (GQA)

📐 数学原理

核心思想:将查询头分成G组,每组共享键值对

$$\text{GQA}(Q^{(g)}, K^{(g/G)}, V^{(g/G)}) = \text{softmax}\left(\frac{Q^{(g)}(K^{(g/G)})^T}{\sqrt{d_k}}\right)V^{(g/G)}$$

内存节省计算:
  • MHA KV缓存: $2 × n_{heads} × seq\_len × d_k$
  • GQA KV缓存: $2 × n_{kv\_heads} × seq\_len × d_k$
  • 节省比例: $1 - \frac{n_{kv\_heads}}{n_{heads}}$

🎨 图形化解释

GQA: 8个Q头,2个KV头(4:1分组) Query Heads (8个) Q₁ Q₂ Q₃ Q₄ Q₅ Q₆ Q₇ Q₈ KV Pairs (2个) K₁,V₁ K₂,V₂ KV缓存内存对比 MHA: 8个KV头 (100%) GQA: 2个KV头 (25%) 节省75%内存! 示例(seq_len=2048, d_k=64): MHA: 2 × 8 × 2048 × 64 × 4B = 8MB GQA: 2 × 2 × 2048 × 64 × 4B = 2MB 每层节省6MB,12层模型节省72MB!

📝 代码实现

# GQA核心实现
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads=8, n_kv_heads=2):
        # n_heads: Query头数量(如8)
        # n_kv_heads: KV头数量(如2)
        # n_groups: 每个KV头服务的Q头数(8/2=4)
        self.n_groups = n_heads // n_kv_heads
        
        # Q使用全部维度,KV使用更少维度
        self.W_q = nn.Linear(d_model, d_model)              # 8个头
        self.W_k = nn.Linear(d_model, n_kv_heads * d_k)     # 2个头
        self.W_v = nn.Linear(d_model, n_kv_heads * d_k)     # 2个头
    
    def forward(self, x):
        # 计算Q, K, V
        Q = self.W_q(x).view(batch, seq, n_heads, d_k)      # [B, N, 8, d]
        K = self.W_k(x).view(batch, seq, n_kv_heads, d_k)   # [B, N, 2, d]
        V = self.W_v(x).view(batch, seq, n_kv_heads, d_k)   # [B, N, 2, d]
        
        # 关键步骤:重复KV以匹配Q的头数
        K = K.repeat_interleave(self.n_groups, dim=2)       # [B, N, 8, d]
        V = V.repeat_interleave(self.n_groups, dim=2)       # [B, N, 8, d]
        
        # 标准注意力计算
        attn = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, V)

✅ 优势

减少KV缓存50-75%,保持模型质量

⚠️ 劣势

需要调整KV头数量的超参数

💡 应用

Llama 2, Llama 3, Mistral, Mixtral

5 多查询注意力 (MQA)

📐 数学原理

核心思想:所有查询头共享单一的键值对

$$\text{MQA}(Q^{(h)}, K, V) = \text{softmax}\left(\frac{Q^{(h)}K^T}{\sqrt{d_k}}\right)V$$

KV缓存从 $O(n \cdot h \cdot d)$ 降到 $O(n \cdot d)$

🎨 图形化解释

MQA: 所有查询头共享单个KV对 8个Query Heads Q₁ Q₂ Q₃ Q₄ Q₅ Q₆ Q₇ Q₈ 单个KV对 K, V 共享给所有Q头 KV缓存对比 MHA: 8×n×d MQA: 1×n×d 节省87.5%内存! 💡 极限内存优化: • 推理时KV缓存最小化 • 适合边缘设备部署

📝 代码实现

# MQA核心实现
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads=8):
        # Q有多个头,K和V只有单头
        self.W_q = nn.Linear(d_model, d_model)     # 输出8个头的Q
        self.W_k = nn.Linear(d_model, d_k)         # 输出1个头的K
        self.W_v = nn.Linear(d_model, d_k)         # 输出1个头的V
    
    def forward(self, x):
        # 计算Q(多头)和KV(单头)
        Q = self.W_q(x).view(batch, seq, n_heads, d_k)  # [B, N, 8, d]
        K = self.W_k(x).view(batch, seq, 1, d_k)        # [B, N, 1, d]
        V = self.W_v(x).view(batch, seq, 1, d_k)        # [B, N, 1, d]
        
        # 关键:扩展KV到所有头(广播)
        K = K.expand(-1, -1, n_heads, -1)               # [B, N, 8, d]
        V = V.expand(-1, -1, n_heads, -1)               # [B, N, 8, d]
        
        # 标准注意力计算
        attn = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)
        
# 内存分析(seq_len=2048, d=64, 8头):
# MHA KV缓存: 2 × 8 × 2048 × 64 × 4B = 8MB
# MQA KV缓存: 2 × 1 × 2048 × 64 × 4B = 1MB
# 节省: 87.5%!

✅ 优势

最小KV缓存(节省87.5%),推理最快

⚠️ 劣势

可能降低模型性能,训练不稳定

💡 应用

PaLM, Falcon, StarCoder, SantaCoder

6 Flash Attention

📐 数学原理

在线Softmax (Online Softmax):

数值稳定的增量计算:

$$m^{new} = \max(m^{old}, \text{rowmax}(S))$$

$$l^{new} = e^{m^{old}-m^{new}}l^{old} + \text{rowsum}(e^{S-m^{new}})$$

$$O^{new} = \frac{e^{m^{old}-m^{new}}l^{old}O^{old} + e^{S-m^{new}}V}{l^{new}}$$

分块策略:
  • 将Q分成大小为B_r的块
  • 将K,V分成大小为B_c的块
  • 逐块在SRAM中计算,减少HBM访问

🎨 图形化解释

Flash Attention: 分块计算与内存优化 GPU内存层次 SRAM (20KB, 19TB/s) HBM (40GB, 1.5TB/s) 写回 加载 ❌ 标准注意力问题: n×n矩阵太大,无法放入SRAM 分块计算策略 Q块 Q₁ K,V块 K₁V₁ 输出 O₁ ✓ 当前块完全在SRAM中计算 Flash Attention算法流程 for i in range(0, N, Br): # 遍历Q块 for j in range(0, N, Bc): # 遍历KV块 在SRAM中计算 Q[i]×K[j]^T×V[j],增量更新O[i]

📝 代码实现

# Flash Attention核心循环
def flash_attention_forward(Q, K, V):
    # Q: [batch, heads, N, d]
    # 分块大小(适配SRAM大小)
    Br = 64  # Q块大小
    Bc = 64  # KV块大小
    
    N = Q.shape[2]
    O = torch.zeros_like(Q)  # 输出
    L = torch.zeros(batch, heads, N)  # 归一化因子
    
    # 外循环:遍历Q的块
    for i in range(0, N, Br):
        Qi = Q[:, :, i:i+Br, :]  # 加载Q块到SRAM
        Oi = torch.zeros_like(Qi)
        Li = torch.zeros(batch, heads, Br) - float('inf')
        Mi = torch.full_like(Li, -float('inf'))  # 最大值
        
        # 内循环:遍历KV的块
        for j in range(0, N, Bc):
            Kj = K[:, :, j:j+Bc, :]  # 加载KV块到SRAM
            Vj = V[:, :, j:j+Bc, :]
            
            # 在SRAM中计算注意力
            Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) / sqrt(d)
            
            # 在线softmax(数值稳定)
            Mi_new = torch.max(Mi, Sij.max(dim=-1)[0])
            Pij = torch.exp(Sij - Mi_new.unsqueeze(-1))
            
            # 增量更新输出
            Li = torch.exp(Mi - Mi_new) * Li + Pij.sum(dim=-1)
            Oi = torch.exp(Mi - Mi_new).unsqueeze(-1) * Oi + torch.matmul(Pij, Vj)
            Mi = Mi_new
        
        # 归一化并写回HBM
        O[:, :, i:i+Br, :] = Oi / Li.unsqueeze(-1)
    
    return O

# 关键优化:
# 1. 避免存储n×n注意力矩阵
# 2. 分块计算,每块都在SRAM中
# 3. IO复杂度从O(N²) → O(N²/M),M是SRAM大小

✅ 优势

2-4倍速度提升,O(n)内存,数值稳定

⚠️ 劣势

需要CUDA kernel优化,实现复杂

💡 应用

GPT-4, Claude, Gemini, 现代大模型标配

📊 性能对比总结

注意力机制 时间复杂度 内存复杂度 KV缓存 最适合场景
Full Attention O(n²d) O(n²) O(nhd) 短序列,完整建模
Local Attention O(nwd) O(nw) O(nhd) 长序列,局部依赖
Linear Attention O(nd²) O(nd) O(nhd) 超长序列
GQA O(n²d) O(n²) O(n(h/G)d) 推理优化
MQA O(n²d) O(n²) O(nd) 极限内存优化
Flash Attention O(n²d) O(n) O(nhd) 训练加速

🔍 特性对比

特性 Full Sparse Linear GQA MQA Flash
模型质量 ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐
推理速度 ⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐
内存效率 ⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐⭐
长序列支持 ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐ ⭐⭐ ⭐⭐⭐
实现难度 ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐
主流采用度 ⭐⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐

💡 选择建议

  • 训练阶段:Flash Attention + Full Attention(速度与精度并重)
  • 推理优化:GQA(Llama系列)或MQA(PaLM系列)减少KV缓存
  • 中等文本(1K-8K):Local Attention(Longformer)或稀疏模式
  • 长文本(8K-32K):块稀疏(BigBird)或混合注意力
  • 超长文本(>32K):Linear Attention(Performer)或Linformer
  • 边缘设备:MQA + 量化 + Flash Attention组合
  • 实时服务:GQA(4个KV头)+ Flash Attention v2

🎯 实践经验

  • 不同层可以使用不同注意力:底层Local,中层Sparse,顶层Full
  • Flash Attention已成为训练标配,几乎无需权衡
  • GQA是目前最平衡的选择(Llama 2/3采用)
  • 线性注意力虽然快但精度损失明显,谨慎使用
  • 根据硬件选择:A100/H100用Flash,消费级GPU用GQA/MQA