Untitled

开头的胡言乱语

大家好,趁着除夕放假攒了一篇长文来解释一下我们的新工作:StableMask: Refining Causal Masking in Decoder-only Transformer。这篇文章题目取得比较奇特(causal mask竟然也能改?),当然并不存在什么夸大的成分,我们确实把causal mask改掉了。

Preliminary

废话不多说,我们直接进入正题。可能大部分人看到这个题目都会直接疑惑,causal mask不就是一堆负无穷值吗,这也能改?改了不就泄漏信息了,还怎么自回归训练?首先先回答这个问题,causal mask当然能改,考虑传统的causal mask实现:

$$ A' = A + M \in \mathbb{R}^{n \times n} $$

其中$A$是注意力矩阵,来自于$A = \frac{QK^T}{\sqrt{d_k}}$,相信读者们都清楚transformer的具体原理,这里就不细细解释了。简单来说,注意力矩阵表示了整个长度为$n$的序列中,每条长度为$d$的向量和另一条长度为$d$的向量之间的余弦相似度关系。因为 $A$矩阵随后要承担$O = \mathrm{Softmax}(A)V \in \mathbb{R}^{n \times n}$的重任,而我们的LLM大多又采用decoder架构——这意味着每个token在预测下一个位置的时候不能看到自己后面的信息(否则就偷看答案了),因此,很有必要把那些未来信息藏起来,我们的 $A$矩阵长这个样:

$$ A = \begin{pmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \\ \end{pmatrix}. $$

我们需要遮盖的是$a_{ij} (i < j)$的部分,这样保证对于第$i$个位置,其不会获得自身位置以后的信息。

对于一个输出向量$o_i$,其由以下操作得到:

$$ o_i = \sum_{j=1}^{n}(a_{ij}v_j) $$

这里的“遮盖”实际上就是令$j > i$时$a_{ij}=0$,即:

$$ o_i = \sum_{j=1}^{i}(a_{ij}v_j)+\sum_{j=i+1}^{n}(0 \times v_j) $$

具体怎么遮盖呢?简单来说,是通过Softmax操作的特殊性实现的,其公式表示为:

$$ a_i' = \mathrm{Softmax}(a_i) = \frac{e^{a_i}}{\sum_{j=1}^{n}e^{a_j}} $$

我们发现,首先,如果一个 $a_i$是个负无穷,那它算完Softmax就是0了:

$$ a_i \to -\infty \\ \mathrm{Softmax}(a_i) = 0 $$

因此,让注意力矩阵变成以下形式即可实现遮盖:

$$ A = \begin{pmatrix} a_{11} & -\infty & \cdots & -\infty \\ a_{21} & a_{22} & \cdots & -\infty\\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \\ \end{pmatrix}. $$

Two Transformer Issues