0. 最小心智模型: Attention = 两次矩阵乘

  1. 打分(谁看谁)
    S=QKTS = QK^T

  2. 加权求和(读出内容)
    O=softmax(S)VO = \text{softmax}(S)\,V


1. Anchor 1: Q / K / V 的形状(只记这一组)

单个 batch, 单个 head:

  • QRN×dQ \in \mathbb{R}^{N \times d}
  • KRN×dK \in \mathbb{R}^{N \times d}
  • VRN×dV \in \mathbb{R}^{N \times d}

含义:

  • NN: 序列长度(token 数)
  • dd: 每个 head 的维度(head_dim)

口诀: Q/K/V 都是"每个 token 一行, 每行一个 d 维向量";


2. Anchor 2: 为什么注意力矩阵是 N×NN \times N

  • KTRd×NK^T \in \mathbb{R}^{d \times N}
  • S=QKTRN×NS = QK^T \in \mathbb{R}^{N \times N}

直觉:

  • NN 个 query ×\times NN 个 key \rightarrow 一张 N×NN \times N 的"匹配分数表";

口诀: N 个 query 看 N 个 key \rightarrow 一张 N×\timesN 表;


3. Anchor 3: 为什么输出还是 N×dN \times d

P=softmax(S)P = \text{softmax}(S), 则:

  • PRN×NP \in \mathbb{R}^{N \times N}(按行 softmax)
  • O=PVRN×dO = PV \in \mathbb{R}^{N \times d}

直觉:

  • 每个 query 输出一个 d 维向量(不改变向量维度, 只是混合 token);

口诀: 注意力"混合 token", 不改变每个 token 的向量维度;


4. 防忘公式: 写成下标版, 永远不会错

4.1 打分(标量)

Sij=Qi, KjS_{ij} = \langle Q_i,\ K_j \rangle

  • QiQ_i: 第 i 个 token 的 query(长度 d)
  • KjK_j: 第 j 个 token 的 key(长度 d)
  • 点积是标量 \Rightarrow SijS_{ij} 是标量
    i/j 各跑 NN \Rightarrow SSN×NN \times N

4.2 输出(向量)

Oi=j=1Nsoftmax(Si:)jVjO_i = \sum_{j=1}^{N} \text{softmax}(S_{i:})_j \cdot V_j

  • VjV_j 是 d 维向量
  • 加权和仍是 d 维向量 \Rightarrow OiO_i 是 d 维
    所有 i 组成 ORN×dO \in \mathbb{R}^{N \times d}

5. Multi-Head Attention(最不容易乱的记法)

一句话: 多头 = 多套 Q/K/V 并行算, 最后把 head 的 d 拼回去;

常见形状(单 batch):

  • 输入 hidden: XRN×DmodelX \in \mathbb{R}^{N \times D_{\text{model}}}
  • 三个投影:
    • Q=XWQQ = XW_Q
    • K=XWKK = XW_K
    • V=XWVV = XW_V

其中:

  • WQ,WK,WVRDmodel×(Hd)W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{model}} \times (H\cdot d)}

所以:

  • Q,K,VRN×(Hd)Q, K, V \in \mathbb{R}^{N \times (H\cdot d)}

reshape 后(两种写法等价, 只是布局不同):

  • QRH×N×dQ \in \mathbb{R}^{H \times N \times d}QRN×H×dQ \in \mathbb{R}^{N \times H \times d}
  • 同理 K,VK, V

每个 head hh 独立:

  • ShRN×NS_h \in \mathbb{R}^{N \times N}
  • OhRN×dO_h \in \mathbb{R}^{N \times d}

拼接回去:

  • ORN×(Hd)O \in \mathbb{R}^{N \times (H\cdot d)}

最后输出投影回模型维度(常见):

  • Ofinal=OWOO_{\text{final}} = OW_O, WOR(Hd)×DmodelW_O \in \mathbb{R}^{(H\cdot d) \times D_{\text{model}}}

关键等式:
Dmodel=HdD_{\text{model}} = H \cdot d


6. 30 秒自检法(不靠记忆, 靠推理)

当你写出 attention 公式后, 检查两条就够:

  1. QKTQK^T 的"内维"必须一致(都是 d)
  2. 输出形状必须是 N×dN \times d(或多头拼接后的 N×DmodelN \times D_{\text{model}})

7. 类比记忆: Attention = 数据库检索

  • QQ: 查询(query)
  • KK: 索引(key)
  • VV: 内容(value)

流程:

  • QKTQK^T: 算"查询对每个 key 的匹配分数"
  • softmax: 把分数变成概率/权重
  • VV: 按权重把内容加权读出来

8. 一行速记卡片(随手贴)

  • Q/K/V: N×dN \times d
  • Scores: S=QKTN×NS = QK^T \Rightarrow N \times N
  • Weights: P=softmax(S)N×NP=\text{softmax}(S) \Rightarrow N \times N(按行)
  • Output: O=PVN×dO = PV \Rightarrow N \times d
  • Multi-head: Dmodel=HdD_{\text{model}} = H\cdot d, 每个 head 都是同一套