商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

基于2大理念改进

萧箫 转载自 商汤AI
量子位 | 公众号 QbitAI

Transformer很受欢迎,但它架构上的不少问题依旧令人头疼。

典例之一就是其中的Softmax Attention模块,虽然能捕捉长距离依赖关系,但由于Softmax算子关于序列长度的二次空间和时间复杂性,导致难以扩展。

虽然也有用核方法、稀疏注意力机制等来近似Softmax算子,以降低时间空间复杂度,但近似操作本身存在的误差使得其效果很难超越Softmax Attention。

为此,商汤多模态研究组想到,与其近似Softmax,为何不重新设计一种方式“平替”Softmax?

他们提出了一种叫做cosFormer的新方法,论文目前已经登上ICLR 2022。

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

一方面,cosFormer在时间空间复杂度关于序列长度为线性复杂度的同时,其性能接近或者超越Softmax Attention;

另一方面,它也在LRA benchmark上取得了SOTA,其中y轴表示性能,x轴表示速度,圆圈大小表示内存。

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

一起来看看。

此前的Softmax有什么问题?

Softmax Attention

Softmax Attention的计算方式是这样的:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

线性Attention

通过分析我们发现,性能瓶颈的主要原因是exp操作,如果相似度函数可以表示为

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

那么

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

根据矩阵运算的结合律:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

上式可以变换为:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

经过计算后可以得到该方法的时间复杂度为,即关于序列长度是一次的。

Softmax Attention和线性Attention的计算方式可以用下图概括:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

所以接下来的会介绍的选择,以及核心的reweighting操作。

Softmax的两大性质

我们经过分析以及实验,归纳出Softmax Attention中比较重要的性质,这两个性质可以指导我们的模型设计:

  1. 注意力矩阵的非负性
  2. 局部注意力的放大(非极大值抑制)

对于第一点,我们有如下实验进行验证(模型结构为RoBERTa):

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

这里Loss表示验证集损失(越低越好),其余指标均为准确率(越高越好),可以看到,当保证了注意力矩阵的非负性之后,可以达到较好的效果。基于该实验,我们选择为ReLU函数。

对于第二点,我们的方式是在注意力矩阵中引入先验locality信息,观察Softmax注意力矩阵,如下图所示,我们发现其注意力矩阵的权重在对角线附近很集中:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

所以我们的方法需要在加了reweighting操作后也更加集中在对角线附近。注意并非所有的有类似权重的函数均适用,这个reweighting的函数需要跟前面的QK一样可以拆分成两个矩阵的乘法的形式。

至此,就可以引入我们的cosFormer了。

cosFormer如何超越Softmax?

我们的方法基于线性Attention,首先给出符号定义:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

根据之前的分析,我们选择了:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

可得:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

为了进行reweighting操作,并且同时保证线性Attention的计算方式依然成立,我们选择了cos函数:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

展开可得:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

为了便于展示,我们把它记作:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

最终得到:

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

上式和线性Attention的计算方式一致,经过分析不难得出时间复杂度依然是O(N)。

具体性能究竟有多好?

我们在单向模型,双向模型以及LRA benchmark上测试了我们的方法,均取得了非常不错的效果。

单向语言模型,指标表示困惑度(越低越好):

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

双向语言模型,指标表示准确率(越高越好):

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

LRA benchmark:

性能实验,指标表示准确率(越高越好):

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

内存速度实验,指标表示速度(越高越好,如果内存溢出,则标记为叉):

商汤最新论文登上ICLR 2022:给注意力机制Softmax找个“平替”

目前代码已开源,感兴趣的小伙伴们可以戳下方地址了解了~

论文地址:
https://arxiv.org/abs/2202.08791

部分开源代码:
https://github.com/OpenNLPLab/cosFormer

版权所有,未经授权不得以任何形式转载及使用,违者必究。