注意力机制(Attention)详解
约 1431 字大约 5 分钟
attentiontransformer
2025-08-29
注意力机制是自然语言处理和计算机视觉领域的核心技术,它允许模型在处理输入时动态地"关注"最相关的部分。本文从 Seq2Seq 的局限性出发,逐步深入到 Self-Attention 和 Multi-Head Attention。
Seq2Seq 的局限性
传统的编码器-解码器(Seq2Seq)架构将整个输入序列压缩为一个固定长度的上下文向量,这带来了两个根本问题:
- 信息瓶颈:长序列的所有信息被压缩到一个固定维度的向量中,导致信息丢失
- 长距离依赖:RNN 编码器对早期输入的记忆逐渐衰减
上下文向量 c 是信息瓶颈所在——所有序列信息都被压缩到这个单一向量中。
Bahdanau Attention(加性注意力)
Bahdanau(2015)提出让解码器在每个时间步都能访问编码器的所有隐藏状态,通过学习一个对齐函数来计算注意力权重。
对齐分数计算:
eij=vTtanh(W1si−1+W2hj)
其中 si−1 是解码器上一步的隐藏状态,hj 是编码器第 j 步的隐藏状态。
import torch
import torch.nn as nn
import torch.nn.functional as F
class BahdanauAttention(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.W1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.v = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, decoder_state, encoder_outputs):
# decoder_state: (batch, hidden)
# encoder_outputs: (batch, seq_len, hidden)
score = self.v(torch.tanh(
self.W1(decoder_state.unsqueeze(1)) + self.W2(encoder_outputs)
)) # (batch, seq_len, 1)
attn_weights = F.softmax(score, dim=1)
context = (attn_weights * encoder_outputs).sum(dim=1)
return context, attn_weights.squeeze(-1)Luong Attention(乘性注意力)
Luong(2015)提出了更简洁的注意力变体,使用点积或双线性形式计算对齐分数:
| 变体 | 公式 | 特点 |
|---|---|---|
| Dot | siThj | 最简单,要求维度相同 |
| General | siTWhj | 引入可学习矩阵 |
| Concat | vTtanh(W[si;hj]) | 类似 Bahdanau |
class LuongAttention(nn.Module):
def __init__(self, hidden_dim, method='dot'):
super().__init__()
self.method = method
if method == 'general':
self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
def forward(self, decoder_state, encoder_outputs):
if self.method == 'dot':
score = torch.bmm(encoder_outputs, decoder_state.unsqueeze(2))
elif self.method == 'general':
score = torch.bmm(self.W(encoder_outputs), decoder_state.unsqueeze(2))
attn_weights = F.softmax(score, dim=1)
context = torch.bmm(attn_weights.transpose(1, 2), encoder_outputs).squeeze(1)
return context, attn_weights.squeeze(2)Self-Attention(自注意力)
Self-Attention 是 Transformer 的核心,它让序列中的每个位置都能直接关注序列中的其他所有位置,彻底解决了长距离依赖问题。
Q/K/V 机制
每个输入 token 的嵌入向量通过三个线性变换生成 Query(查询)、Key(键)和 Value(值):
Scaled Dot-Product Attention
Attention(Q,K,V)=softmax(dkQKT)V
除以 dk 是为了防止点积值过大导致 softmax 进入梯度极小的饱和区域。
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q, K, V: (batch, heads, seq_len, d_k)
mask: (batch, 1, 1, seq_len) 或 (batch, 1, seq_len, seq_len)
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output, attn_weightsMulti-Head Attention
多头注意力将 Q、K、V 投影到多个子空间,每个头独立计算注意力,然后拼接结果。这使模型能同时关注不同位置的不同类型信息。
MultiHead(Q,K,V)=Concat(head1,...,headh)WO
headi=Attention(QWiQ,KWiK,VWiV)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.d_k = d_model // n_heads
self.n_heads = n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性投影并拆分为多头
Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# Scaled Dot-Product Attention
output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 拼接多头结果
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
return self.W_o(output), attn_weightsCross-Attention(交叉注意力)
Cross-Attention 用于连接两个不同的序列:Query 来自一个序列(如解码器),而 Key 和 Value 来自另一个序列(如编码器)。
典型应用场景:
- 机器翻译:解码器关注源语言编码器输出
- 多模态模型:文本关注图像特征(如 CLIP、Flamingo)
- 检索增强生成(RAG):生成模型关注检索到的文档
注意力变体与优化
| 方法 | 复杂度 | 核心思想 |
|---|---|---|
| 标准 Attention | O(n²) | 全连接注意力 |
| Sparse Attention | O(n√n) | 固定稀疏模式 |
| Linear Attention | O(n) | 核近似避免 softmax |
| FlashAttention | O(n²) 但更快 | IO-aware 分块计算 |
| Grouped Query Attention (GQA) | O(n²) 但更省内存 | 多 Query 头共享 KV 头 |
| Multi-Query Attention (MQA) | O(n²) 但最省内存 | 所有 Query 头共享一组 KV |
注意力可视化
注意力权重矩阵可以可视化,帮助理解模型在生成每个 token 时关注了输入的哪些部分。
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attn_weights, src_tokens, tgt_tokens):
"""
attn_weights: (tgt_len, src_len) numpy array
"""
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(attn_weights, xticklabels=src_tokens,
yticklabels=tgt_tokens, cmap='viridis', ax=ax)
ax.set_xlabel('Source')
ax.set_ylabel('Target')
plt.tight_layout()
plt.show()总结
注意力机制的演进路线:Bahdanau/Luong Attention 解决了 Seq2Seq 的信息瓶颈问题;Self-Attention 消除了序列处理对循环结构的依赖;Multi-Head Attention 赋予模型多维度的关注能力。这些技术共同构成了 Transformer 架构的基础,推动了 GPT、BERT、ViT 等模型的发展。
贡献者
更新日志
9f6c2-feat: organize wiki content and refresh site setup于