xLSTM-改进长短期记忆网络

Github: https://github.com/AI-Guru/xlstm-resources

LSTM(长短期记忆网络)已经存在很长时间了。它们已被应用于相当多与序列相关的任务,例如文本生成和翻译,甚至生成图像字幕。

它们的缺点是无法并行化以利用强大的现代 GPU。这一限制为利用 GPU 进行大规模并行训练和推理的 Transformer 的出现铺平了道路。

如果我们现在尝试改进和并行化 LSTM,它们能成为构建下一代LLM的工具吗?

这正是论文“ XLSM——扩展长短期记忆网络”所回答的问题, XLSM 代表“扩展”长短期记忆。他们通过在架构中提出两个新模块,即 sLSTM 和 mLSTM 来实现这一点。

xLSTM Figure

一、LSTM 回顾

1、一个生动的例子

原始 LSTM 主要是为了解决 RNN 时序反向传播中的梯度消失和爆炸问题而提出的。为了方便大家看清楚,我们来看一个生动的例子。

在这样一个时序模型中,输入为x,隐层变量为s,输出为y,LSTM 相比 RNN 增加了条时间链条c,用来保存长期记忆。

LSTM 的核心原理就在于设计了多个门控机制协调短期记忆和长期记忆。其中f1为遗忘门,如同橡皮擦,根据昨天的记忆st-1和今天输入xt决定删除哪些旧记忆sigmoid 函数取值为0时相当于制除操作;f2 使用 tanh 函数取值在(-1,1)之间,作用不是遗忘,而是把这两天发生的事情进行梳理和归纳,然后像铅笔一样增加记忆,因此称为输入门。右边显示了它们的各自计算公式。同时保持长短期记忆链,并相互更新,这就是 LSTM 成功的秘密了。

2、记忆的原理和公式

其实静下心来看,st改用 ht表示。ft就是遗忘门,只是进行了展开;it就是输入门,zt对应输入门中 tanh 函数部分。三者分别都加了非线性激活函数,共同作用生成新的 cell 迭代,也就是公式(2)。这和刚才我们的介绍都是一致的。另外还加了输出门,长短期记忆链条之间的第三种连接,o对应前面例子中的y。公式(3)是h链条的更新公式。

所谓的门控机制,其实就是一种时序上的注意力机制,相当于把不同时间信息进行“掺和”,是对时序信息的一种选择性控制。从这个视角看,与transformer和 Mamba 都异曲同工之妙。核心思想都是选择性控制信息流动,更好地处理时序数据或序列信息。门控机制通过固定的结构和参数来控制信息流,而注意力机制通过动态计算权重来控制信息流。因此,门控机制可以看作是一种特定形式的时序注意力机制,对不同时间步的信息进行选择性控制和“掺和。可以认为是一种约束版或者简化版的注意力机制。

3、为啥歇菜了?

尽管曾经取得了巨大的成功,LSTM 有三个主要局限性:
1 在处理长序列时效率低:
2 记忆容量有限;
3 不能并行处理数据。
这也是为什么能让 transformer后来者居上的原因,因为借助网络模块堆叠、参数规模扩充和 GPU 并行处理拼算力,有针对性的借鉴了上述问题。但显然不止transformer这一条路。原有的门控机制还有很大的潜力可挖,本文就是有针对性的一条条进行了创新和优化。先来看初级改造版本。

二、初级版:sLSTM 改进注意力机制

针对上述问题的第一个改进版本叫 sLSTM,目的是改善决策能力。改动不大,主要有三点:

1.输入门和遗忘门的激活函数从 sigmoid 改成了指数函数(红色部分)。

2.引入了归一化状态 nt(公式9),相应的隐层 h_t的计算方式变了,改成了c_t/n_t也就是公式(10)

3.还引入了一个额外状态 mt来进一步稳定门控,这个稍后讲。

你肯定好奇,这么做的原因是什么啊?原文没有细讲,而是直接给出了选择。我猜也是试出来的,但是不是瞎试。首先,如下图所示,指数函数相比于sigmoid 函数,具有更大的输出范围和更大的梯度(右图黄色,左图红色),可以减轻梯度消失问题使得梯度在反向传播过程中不会迅速减小,从而使得模型在训练时能够更有效地更新权重。其次,指数函数的增长速度比 sigmoid 函数快,对输入变化更加敏感。因此,可以更迅速地强烈的调整输入和遗忘门的输出,使得模型能够更快地捕捉到输入信息的变化,更加选择性地记住或忘记信息,从而提高模型的记忆和遗忘能力。第三,这种强烈的选择性,让模型能够更准确地保留重要信息和丢弃不重要的信息。在特定任务(如长序列的最近邻搜索或稀有事件预测)中表现得尤为显著,能够显著提升模型性能。

引入归一化和状态 mt都是为了稳定,因为指数激活函数可能导致数值过大而溢出前者相当于搞了个大分母。后者通过下面的公式进行:

第一个式子使用了log,指数函数的逆运算,相当于降一级运算,然后取最大值,意思就是输入门和遗忘门都别太猛。类比生活中,无论是新鲜记忆,还是想遗忘的事情,情绪太激动了都不好,一定要心平气和,基本上就是这个意思。

然后根据 m_t再调整输入门和遗忘门,相当于设置了一个缓冲区。隐射到生活中,正应了那句话:忍一时风平浪静,退一步海阔天空。很多事,稍微放放,别那么激动,从容淡定,反倒更理智,处理起来更有效。落实到公式上,甭管f还是i,先找到log最大值,然后在指数上剪掉,相当于避免了溢出。
附录 A中数学进一步证明在前向传播中用f’_t和i’_t替换 f_t和 i_t不会改变整个网络的输出,也不会改变参数损失的导数。这部分推导不是人看的,猫一眼知道就行了,当然这也是人家这个团队牛逼之处,数学玩的贼溜,或者说LSTM 比起transformer 更高级的地方,理论基础扎实,而不只是拼工程拼资源。

增加了这么些公式相当于增加了新的记忆单元,它们之间通过连接从长短期记忆状态,借助门控(阀门)i,f,o进行记忆混合。门控就是选择,也是一种时序注意力机制的体现。
讲完了初级版改进,咱们来看看中级版。

三、中级版:mLSTM 改进内存处理

解决了敏感度,某种程度上也是长序列处理效率问题,为了增强LSTM 的存储能力文章将 LSTM 的记忆单元从一个标量 c增加到短阵C。而且在这里引入了 transformer键值对的概念,更新规则如下:
Ct=Ct-1+vtktT
这就有点意思了哈,“千古文章一大抄,你抄我也超”,互相借鉴形成你中有我我中有你的态势。在将输入投影到键和值之前,mLSTM 进行层归一化,使得均值为零。同时,将协方差更新规则,也就是优化器(比如adam)整合到LSTM 框架中,遗忘门对应于衰减率,输入门对应于学习率,而输出门则缩放检索到的向量。最终形成了下面的选代公式:

与前面 SLSTM 对比,最大的区别之一就是状态和权重参数都变成了矩阵形式,对应的运算变成了向量矩阵乘法和哈达玛积,公式(21)。区别之二是增加了q_t,k_t,v_t这种键值对的计算公式(22-24),优化了自注意力机制,多了好几个权重矩阵增强了模型表达能力。其他的公式基本没变,也就是说记忆单元没变,只是每个单元相当于扩容了记忆的容量。
此外,需要注意的是,这种框架可以使用多头模式,头与头之间没有记忆混合,因此可以充分并行,无形中提升了并行能力。到此,针对传统LSTM 三大弱点的改进都已经实现。

小结一下,似乎影影绰绰能看到两个思路:一是固本守住传统,在原有框架下优化提升挖掘潜力,强化门控机制的有效性,无论是修改激活函数、稳定状态,还是记忆单元矩阵化提升容量;二是开源拿来主义,引入自注意力机制中 QKV的计算模式,增强模型的记忆和检索能力。
不过这还没完,咱们来看看高级版有什么重要的发现和设计。

四、高级版:xLSTM 大模型

既然 transformer 牛通,通过简单堆叠形成的大模型效果好,为什么不把这种思想贯彻到 LSTM 中,形成 LSTM 结构的模块堆叠呢,是不是效果也会不错?这就涉及传说中的 Cover 定理啦。它及其行生的高维空间中非线性映射理论确实是现代大模型设计的重要理论依据之一。尤其是在深度学习和大规模神经网络的设计中,这些理论起到了关键作用。

1、cover定理-大模型设计理论基础

Cover定理可以定性的描述为:当空间的维数D越大时,在该空间的N个数据点间的线性可分的概率就越大

在大模型中,激活函数(如 ReLU、Sigmoid、Tanh等)通过非线性变换将数据映射到高维空间,使得模型可以捕捉复杂的模式和特征,增强模型的表达能力。深度网络的权重矩阵和激活函数共同作用,将输入数据逐步映射到越来越高的维度。这使得在低维空间中难以分离的模式在高维空间中变得线性可分。Transformer模型就是通过多头注意力机制在高维空间中进行并行处理,使得不同位置的特征可以相互影响和结合从而提高了模型的性能。
Cover 定理为这些设计提供了理论支持,解释了为什么通过高维空间中的非线性映射可以提高模型的性能。现代大模型的设计,如 BERT、GPT等,都在不同程度上利用了这些理论基础。

2、核心模块和工作原理

既然你们都能这么干,xLSTM 想我为什么不能啊!因此它干了下面两件事:
1.非线性总结(压缩信息)【左图】:通过残差块在高维空间中对历史信息进行非线性总结使得不同的历史或上下文信息更容易分离。
2.线性映射回原始空间【右图】:完成高维空间中的处理后,再将数据线性映射回原始空间这一过程利用了高维空间中的优势,使得模型能够更好地分离和记忆历史信息。
具体到怎么升维呢,设计了下面两种结构:

左边是先在原始空间中总结信息,然后映射到高维空间,再返回原始空间。看图从下往上输入 sLSTM,然后向上投影,也就是用一个倒着的梯形矩阵升维,处理后再降维。右边是先映射到高维空间,总结信息后再返回原始空间。也就是输入直接上投影,再用 mLSTM 处理,然后再降维。
先干后变,还是先变后干这个好理解,但你肯定好奇为啥左边适合sLSTM,右边适合 mLSTM模型呢?主要原因是在高维空间中的记忆容量更大,因此用有矩阵化记忆单元的mLSTM更合适,而在低维空间处理 sLSTM 更合适。
想了解关于这两个基础模块的更多细节,可以到附录图 9/10 中看到,我给你列到这里了,咱们一个个详细解释它们的细节和用处.

PF=3/4 和 PF=4/3:投影因子(Projection Factor),分别将输入维度缩小为原来的3/4,将输入维度扩大 4/3 倍。
GN(GroupNorm):组归一化(Group Normalization)。在每一组内进行归一化有助于加速训练和提高模型稳定性,特别是在小批量(batch)训练时。

Swish 一种平滑的非线性激活函数,可以帮助模型学习到更复杂的模式。

Conv4:卷积层,卷积核大小为 4。提取局部特征,
LN(LayerNorm):层归一化,帮助稳定和加速训练过程。SLSTM 单元中i,f,z,0:分别表示输入门(input gate)、遗忘门 (forget gate)细胞状态更新(cell update)和输出门(output gate)。NH=4:表示有 4 个头(heads)。此外,将输入分成块,使用块对角线结构进行线性变换,有助于捕捉局部相关性。这些结构与从上一个隐藏状态中得到的递归门预激活(circular arrows)一致。

整个逻辑过程为:输入先LN整理,然后一分为二。一部分卷积提取特征,激活非线性变换,另一部分直接输入sLSTM。这里所有运算都采用了4个头的多头并行进制,每个头可以专注于捕捉输入数据的不同特征或模式,从而使模型能够更全面地理解数据。
内部采用块对角线结构,在计算时可以并行处理,从而显著降低计算复杂度和内存需求;每个子矩阵(块)主要关注输入数据的一部分能够更好地捕捉局部特征;结构化的稀疏性,这有助于减少过拟合。
在sLSTM图中的箭头表示信息在不同时间步之间的流动和处理,代表的是与先前时刻状态的混合计算。这部分相当于记忆的重新组合。然后组内归一化、降维、再激活、再降维,然后与残差相加再输出。
类似的,我们看看另一种基础模块:

PF=1/2 和 PF=2:投影因子(Projection Factor)。前者将输入维度缩小一半,后者将输入维度扩大两倍。
LSkip 是个跳线,类似于残差连接,可以帮助梯度更好地传递,防止梯度消失和爆炸。这里相当于冇两种跳线残差
mLSTM 单元中的 q、k、v分别表示査询(query)、键(key)和值(value),我们刚讲过,都是从输入中生成的,用于计算注意力权重和进行信息检索。
BS=4:块大小为 4 的块对角投影矩阵。

整体逻辑上与前面刚讲的模块大差不差,咱们就不一一过了。整体上都是充分利用了残差堆叠结构,层归一化技术等稳定网络,通过升降维度实现空间变换,激活函数非线性变换,然后利用 LSTM 进行记忆混合,主或者说时序上的选择性自注意力机制计算,采用多头和块对角模式实现并行处理,当然也没少了用卷积提取特征。

3、与Transformer 的对比

有了这两种基本构建模块,通过堆叠增加模型的深度,能够逐层提取更高层次的特征。最终,整个堆叠结构作为一个端到端的模型进行训练,通过反向传播优化所有层的参数。使用这些模块的堆叠设计是现代深度学习型中常见且有效的做法。换句话说,你 transformer 能干的我xLSTM 现在都能干了,有啥嘛?!而且老子内部有清晰而明确的逻辑结构,有数学公式的严密推导,效率更高,而你transformer 内部就是个乱七八糟的自注意力机制和交叉注意力机制,黑盒子,大量的参数是浪费掉的,低效,训练难推理难,从效率和准确率上都不如我也就make sense了。与Transformer不同,xLSTM 网络在计算复杂度和内存复杂度上随着序列长度呈线性关系。由于xLSTM 的记忆压缩性,它非常适合在工业应用和边绿设备上实现。

4、适用场景

对比 mLSTM 和 SLSTM 两种模块,前者方便并行化,后者由于记忆混合(隐藏状态之间的连接),无法并行化。论文开发了一种快速的CUDA实现,通过 GPU 内存优化到寄存器级别,这种实现通常比 mLSTM 慢不到两倍。

那你肯定会问,现在有这两种基础结构,分别什么时候用呢?给你几个原则:

sLSTM:需要高精度和复杂特征提取的任务,计算资源充足且不需要并行化的应用对延迟敏感但不受并行化限制的场景,例如,实时语音识别系统,因为它有记忆混合

mLSTM:图像识别、视频处理等需要高效并行计算的任务,计算资源有限且需要高效利用内存的应用,例如,嵌入式系统、移动设备上:需要在工业环境或边缘设备上部署的任务,例如,工业自动化、物联网设备上的智能应用。因为它并行化好。

五、 实验论证

实验详实是这类大牛文章的最大特点,本文集中在 NLP 任务上,与大量模型进行了对比。主要包括四大类。我们直接看结论。

1、合成任务和长程任务

每行表示一种模型,包括Lama、Mamba等7种模型的 12 中变体,xLSTM[0:1]:主要是SLSTM 块,xLSTM[1:0]:主要是 mLSTM块,xLSTM[1:1]:均衡使用 mLSTM 和SLSTM 块。每列表示一种任务,包括上下文敏感、确定性上下文无关、正则,最后是多数任务,也是正则。
使用 SLSTM 和 mLSTM 的组合(如xLSTM[1:1])在大多数任务上表现出色,特别是在复杂和状态跟踪任务上。

再来看不同模型在多查询联想记忆任务中的性能对比。横轴模型的尺寸,纵轴验证准确率,xLSTM1:1表现最佳,越难越好,Lama等Transformer 模型在较小和中等难度任务中表现优越。Mamba 略强。真是“长江后浪推前浪,一浪更比一浪强”

2、验证集困惑度比较

这个图展示了在使用 158 个Token 训练的 SlimPajama 数据集上,下一词预测性能比较。横轴为模型参数量,纵轴为验证困惑度,总体趋势都差不多,但xLSTM 明显更好。说明其在语言建模任务中的优势。

3、大规模语言建模实验

这个图展示了在使用 300B 个 Token 训练的 SlimPajama 数据集上,不同模型在下一词预测任务中的验证困惑度(Validation Perplexity)比较,特别是对长序列的外推性能。横轴为 token 数量,也就是序列长度。

4、语言基准测试

在使用 300B 个 Token 训练的 SlimPajama数据集上,不同模型在下一词预测任务中的验证困感度(Validation Perplexity)随参数数量变化的情况。验证困感度越低,表示模型的预测性能越好。

·所有模型的验证困感度随着参数数量的增加而下降,说明更大参数的模型在下一词预测任务上表现更好。
xLSTM 的优势:xLSTM 模型(特别是xLSTM[7:1]和xLSTM[1:0])在所有参数数量下都表现出色,验证困惑度较低,说明其在语言建模任务中的性能优越。
模型对比:xLSTM 模型比 Mamba 表现好,而 Mamba比Lama 表现好。这表明xLSTM 在处理大规模语言建模任务时,具有明显的优势。

七、小结

1.LSTM 的缺陷:作为一种时序建模的思想,它通过常量sigmoid 门控机制实现了对记忆的重组,循环训练和推理。但传统架构面临长期记忆处理效率低、记忆存储量小、并行化困难三大硬伤。
2.xLSTM 的原理:借助指数门控混合记忆和新内存结构,LSTM增强为 sLSTM和mLSTM。二者的结合构成了xLSTM 模块,进一步堆叠可以实现大模型化

3.实验对比:xLSTM 在语吉建模上相比于诸如Transformers和State Space Models等最新方法表现良好。扩展法则表明,更大的xLSTM 模型将是当前基于Transformer 技术的大型语言模型的有力竞争者。

4.未来发展:俗话说“以史为鉴,可以知兴替”LSTM 辉煌的过去证明了它在时序建模领域的王者地位,借助 xLSTM 的再度起,它很有可能深度影响其他深度学习领域,如强化学习、时间序列预测或物理系统建模等领域。

当 xLSTM 也能扩展到数十亿参数时,为我们展示了大模型发展的更多可能。一种架构可能过时,可能被不断超越,但是一种思想,一种理论却能不断推陈出新,与时俱进。正如文章最后的小结和预言。LSTM 能走多远,到目前为止,我们可以清晰的回答:”至少可以与当前的 SOTA技术(如Transformers或State Space Models)一样远

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注