Mamba
Mamba 是一种新的深度学习模型,主要应用于高效的长序列建模,旨在解决现有 Transformer 和其他序列模型在处理长序列任务时的效率瓶颈。它基于改进的选择性状态空间模型(Selective State Space Model, SSM),通过对输入内容的上下文感知动态调整模型行为,在计算效率和性能上都取得了显著进展。
具体来说,Mamba 通过使关键参数(如状态转移矩阵和激活函数)成为输入的函数,使得模型能够动态调整对输入的响应。这种特性允许 Mamba 更好地处理复杂的上下文依赖任务,比如语言建模和序列记忆。
与 Transformer 的二次时间复杂度相比,Mamba 的选择性状态空间方法能达到线性时间复杂度,同时在处理百万长度的序列时表现优异。
此外,Mamba 采用了 GPU 硬件友好的特别设计,优化了并行计算流程,比如引入了并行关联扫描和块矩阵分解算法,以提高训练和推理速度。这使得它的推理吞吐量达到 Transformer 的 5 倍。
Mamba 目前包括两个版本。第一个版本首次引入选择性状态空间模型,结合硬件感知算法提升了效率。Mamba 2 通过结构化状态空间对偶(SSD)框架,将 Transformer 的注意力机制和状态空间模型结合,进一步提升了性能和硬件适配性。
Mamba 是深度学习序列建模的一个有力新方向,尤其在需要长序列建模和高效推理的应用场景中展示了巨大的潜力。目前,Mamba 在语言、音频、基因组数据等领域的任务中均表现出领先性能,并在某些任务中超越了同等规模甚至更大规模的 Transformer 模型。
结构
下面介绍 Mamba 模型的核心工作原理,尤其是它如何通过选择性状态空间和输入依赖门控机制,实现高效的序列建模。
1. 输入扩展
给定输入序列表示 \(X \in \mathbb{R}^{n \times d_m}\),其中 \(n\) 是序列长度,\(d_m\) 是隐藏层维度。Mamba 将输入 \(X\) 投影到更高维的空间 \(d_e\),生成矩阵 \(H\):
\[H = X W_{\text{in}} \in \mathbb{R}^{n \times d_e}\]其中,\(W_{\text{in}} \in \mathbb{R}^{d_m \times d_e}\) 是一个可学习的投影矩阵。这一步的作用是增强模型的表示能力,为后续步骤提供更丰富的特征。
参数选择上,\(d_e = 2d_m\),以扩展状态空间的维度。
2. 短卷积平滑操作
模型接着对扩展后的输入 \(H\) 进行短卷积(Short Convolution, SC)操作。短卷积是一种轻量、高效的卷积操作,通常用于信号处理或序列建模中。它能够对输入信号做平滑处理,减少噪声的影响,使得后续模块能更清晰地建模关键特征。公式如下:
\[U = \text{SC}(H) = \text{SiLU}(\text{DepthwiseConv}(H, W_{\text{conv}}))\]其中,\(H\) 是输入特征矩阵,Depthwise Convolution 对输入特征矩阵,在每个通道内独立地进行卷积(而不是跨通道操作),因此计算量小,非常高效。\(W_{\text{conv}}\) 是卷积核,定义卷积操作的权重。在 Mamba 中,卷积核大小被设置为 \(4\),是经过优化的选择,确保高效的 GPU/TPU 加速性能,也能在计算效率和上下文捕获能力之间取得平衡。
SiLU (Sigmoid-Weighted Linear Unit)激活函数是一种平滑的非线性激活函数,数学公式定义为 \(\text{SiLU}(x) = x / (1 + e^{-x})\)。它在卷积后引入非线性变换,增强模型的表达能力。
在实际应用中,SC 能很好地结合并行化计算方法(如 Mamba 的硬件感知算法),进一步提高性能,因此被广泛应用于序列建模,如语言建模、语音处理和基因组数据分析;也因其计算开销低,特别适合实时信号处理应用,或作为特征增强模块,用于复杂深度学习模型的初始阶段。
3. 选择性门控机制
模型然后使用低秩投影和 Softplus 激活函数计算选择性门控值 \(\Delta\):
\[\Delta = \text{Softplus}(U W_r W_q + b) \in \mathbb{R}^{n \times d_e}\]其中,\(W_r \in \mathbb{R}^{d_e \times d_r}, W_q \in \mathbb{R}^{d_r \times d_e}\) 是低秩矩阵;\(b\) 是偏置,初始化时使 \(\Delta\) 的范围限定在 [0.001, 0.1],以优化初始学习表现。
参数选择上,\(d_r = d_m / 16\),作为低秩投影的维度。
这个门控机制动态选择重要的序列元素,为后续状态更新提供更精确的输入。
Softplus 是一种常用的激活函数,定义如下:
\[\text{Softplus}(x) = \ln(1 + e^x)\]它是一个平滑的非线性函数,与 ReLU (Rectified Linear Unit) 有相似的特性,但它对小于零的输入不会完全归零,而是接近零的值,这使得它更加平滑和连续,因此能够避免 ReLU 的一些缺点,如“神经元死亡问题”。
具体来说,Softplus 的特点是:
- 平滑性:Softplus 是连续可导的,与 ReLU 不同,它在零点处没有断点。
- 非线性:它可以为神经网络引入非线性,帮助模型学习复杂的模式。
- 值域范围:Softplus 的输出为正值,且随着输入变大,函数输出逐渐趋近于线性。
它有以下好处:
- 避免 ReLU 的“神经元死亡”问题:在 ReLU 中,输入为负时输出完全为零,可能导致某些神经元在训练中完全不更新。而 Softplus 在负输入下仍有非零输出,避免了这一问题,提高了训练的鲁棒性。
- 梯度稳定性:Softplus 是连续可导的,其梯度平滑,避免了 ReLU 在零点处的梯度不连续问题。这在一些需要梯度较高精度的优化任务中(如强化学习、物理建模)尤为重要。
- 数值稳定性:由于 Softplus 输出为正值,它在某些对非负输出有需求的任务中非常适用,例如概率建模中需要正数的分布参数(如方差)。
- 更适合低范围输入的场景:对于负值输入,Softplus 的增长比 ReLU 更缓慢,这可以帮助模型在处理较小输入范围时更稳定地学习。
Softplus 与其他激活函数的对比如下表所示:
激活函数 | 定义 | 优点 | 缺点 |
---|---|---|---|
ReLU | \(\max(0, x)\) | 简单高效,适合稀疏表示 | 梯度断点,神经元死亡问题 |
Sigmoid | \(\frac{1}{1 + e^{-x}}\) | 平滑,值域有限 | 梯度消失问题,数值不稳定 |
Softplus | \(\ln(1 + e^x)\) | 平滑,无神经元死亡问题 | 计算比 ReLU 更耗时 |
4. 输入依赖的状态空间参数
通过输入依赖机制计算状态空间模型(SSM)的参数 \(B\) 和 \(C\) \(B = U W_b \in \mathbb{R}^{n \times d_s}, \quad C = U W_c \in \mathbb{R}^{n \times d_s}\)
其中,\(W_b, W_c\) 是映射矩阵。参数选择上,\(d_s = 16\),是状态维度。
通过这样的输入依赖设计,模型能自适应地调整状态空间参数,从而提高长序列建模的表达能力和效率。
下面介绍 Mamba 模型的核心递归推断和输出生成机制。
1. 扩展状态空间的递归更新
在时间步 \(t\),Mamba 的选择性状态空间模型(Selective SSM,简称 S6)通过以下递归公式计算状态 \(Z_t\): \(Z_t = \exp(-\Delta_t \odot \exp(A)) \odot Z_{t-1} + \Delta_t \odot (B_t \otimes U_t)\)
其中,
- \(\Delta_t\):选择门控值(与输入 \(U_t\) 相关)。
- \(\odot\):按位乘积(点积)。
- \(\exp\):自然指数函数,逐元素计算。
- 矩阵 \(A\):可学习
- \(\otimes\):外积,生成更高维的张量。
- \(Z_t \in \mathbb{R}^{d_e \times d_s}\):扩展的状态空间表示,其中 \(d_s\) 是状态维度。
- \(B_t, U_t\):分别为输入相关的状态参数和卷积后的输入信号。
上式的第一个项 \(\exp(-\Delta_t \odot \exp(A)) \odot Z_{t-1}\) 是对前一时间步的状态 \(Z_{t-1}\) 进行指数衰减,模拟递归过程中的记忆消散。第二项 \(\Delta_t \odot (B_t \otimes U_t)\):引入当前时间步的新信息,权重由选择性门控 \(\Delta_t\) 动态调整。
初始状态 \(Z_0 = 0\),初始矩阵 \(A_{ij} = \log(j)\)(参考 S4D-Real 方法),以确保状态更新具有合理的时间动态特性。
2. 输出生成
模型通过以下公式计算输出 \(Y_t\): \(Y_t = Z_t C_t + D \odot U_t\)
其中,\(C_t\) 是与输入相关的输出投影矩阵,\(D \in \mathbb{R}^{d_e}\) 是可学习向量,初始化为 \(D_i = 1\)。
上式中,\(Z_t C_t\) 代表输出中状态空间的贡献,\(D \odot U_t\) 代表输入信号的直接贡献。
\(D\) 的引入确保了模型能够在状态空间贡献不足时仍然对输入信号做出反应,增强模型的鲁棒性。
3. 并行化和硬件优化
传统递归模型(如 RNN、LSTM)在处理长序列时,存在两个瓶颈。首先,顺序依赖:每个时间步的计算依赖于前一时间步的结果,因此无法并行。其次,低硬件利用率:顺序计算无法充分利用 GPU/TPU 的并行能力,导致性能低下。
为了提升效率,Mamba 使用硬件感知的并行扫描算法来加速递归计算。并行扫描算法针对现代硬件环境(如 GPU 和 TPU),通过将递归结构转化为更高效的并行操作来解决传统递归神经网络(如 RNN)在长序列推断中的低效问题。
具体来说,并行扫描算法(Parallel Scan Algorithm)包括以下步骤:
(1) 递归公式的分解
递归模型的核心通常是状态更新公式: \(Z_t = f(Z_{t-1}, X_t)\) 其中,\(Z_t\) 是时间步 \(t\) 的状态,\(X_t\) 是当前的输入。该公式的顺序依赖性是并行化的主要障碍。
Mamba 通过状态空间模型(SSM)的数学分解,借助矩阵运算和并行前缀扫描,将递归公式转化为硬件友好的形式。以下是其具体的数学描述:
在 SSM 中,时间步 \(t\) 的递归公式通常写为: \(Z_t = A Z_{t-1} + B X_t,\) 其中:
- \(Z_t \in \mathbb{R}^d\) 是时间步 \(t\) 的状态,
- \(A \in \mathbb{R}^{d \times d}\) 是状态转移矩阵,
- \(B \in \mathbb{R}^{d \times m}\) 是输入投影矩阵,
- \(X_t \in \mathbb{R}^m\) 是时间步 \(t\) 的输入。
这种递归计算存在顺序依赖,难以并行。
Mamba 使用状态空间的特性,将递归公式转化为并行计算形式。它首先将递归公式展开为矩阵形式,通过累积计算得到: \(Z_t = A^t Z_0 + \sum_{k=0}^{t-1} A^k B X_{t-k}.\)
在此公式中,\(A^t\) 和 \(A^k\) 的幂次可以通过矩阵运算并行计算,累积求和可通过前缀扫描高效实现。
(2) 前缀扫描分解
前缀扫描(prefix scan)算法将上述累积操作转化为并行化计算。令 \(U_k = A^k B X_{t-k},\)
它将输入信号通过 \(A^k B\) 投影到状态空间。
我们有 \(Z_t = \sum_{k=0}^t U_k,\)
这可以在并行硬件中通过分治方法高效计算。
(3) 状态分块并行
对于长序列,Mamba 将长序列分解为多个子块(blocks),每个子块的计算在并行线程中独立进行。
\[Z^{(i)} = \sum_{j \in \text{block}_i} U_j,\]子块的结果再通过全局归约(reduction)步骤合并,完成全序列的状态计算。
这使得训练和推断时间随序列长度线性增长,而非传统 RNN 的二次增长。
Mamba 也基于 GPU 的高吞吐量矩阵计算,结合 CUDA 优化库如 FlashAttention,能够充分利用现代硬件的并行能力,大幅缩短训练和推断时间,可处理更长的序列。
4. 最终输出与门控机制
最终输出 \(O\) 通过门控机制生成: \(O = Y \odot \text{SiLU}(XW_g)W_{\text{out}} \in \mathbb{R}^{n \times d_m}\)
其中 \(W_g, W_{\text{out}}\) 为可学习参数,分别用于门控权重计算和输出投影。这一 Gated Linear Unit (GLU) 通过 \(\text{SiLU}(XW_g)\) 实现非线性门控,动态调整每个输入序列元素对最终输出的贡献。门控机制确保了输出对输入序列的语义保持敏感,同时利用 \(\text{SiLU}\) 的非线性增强了表达能力。
总之,Mamba 通过递归结构捕获时间序列的依赖性,并结合门控和扩展状态空间设计,实现了动态、高效的长序列建模。这种架构特别适合需要细粒度上下文建模的任务,例如语言理解和基因组序列分析。
混合模型
Mamba 层主要通过递归状态捕获长时依赖关系,但可能对序列中的中短期上下文建模不足,因此,微软提出的 Simple Mamba 模型设计了 Sliding Window Attention (SWA) Layer,通过滑动窗口注意力机制来补充 Mamba 层。
Sliding Window Attention 的工作原理如下。
首先,它对输入序列进行窗口滑动注意力操作。以窗口大小 \(w = 2048\) 为例,SWA 在输入序列上以固定大小的滑动窗口进行操作。在每个窗口内计算自注意力,让模型专注于当前上下文的子集。这个滑动窗口逐步覆盖整个序列,保持时间复杂度对序列长度 \(n\) 为线性(即 \(O(n)\))。
其次,它引入相对位置编码 (RoPE),使用 RoPE (Rotary Position Embedding) 提供相对位置信息,使得注意力机制能感知序列中元素之间的相对顺序。RoPE 优于绝对位置编码,特别是在滑动窗口的应用中,能更自然地处理局部关系。
微软用 FlashAttention 2 对 SWA 进行了高效实现,达到了与 Mamba 层相当的训练速度。FlashAttention 2 是一种优化的注意力计算方法,能够高效利用 GPU 硬件,通过块级并行化和减少内存访问实现极快的训练速度。
通过实验,微软最终把窗口大小设置为 2048。这是一个在效率和准确性之间的折中,能够捕获足够多的局部上下文信息。
总之,Mamba 的递归状态适合长时序建模,但对短期复杂依赖的捕获有限。SWA 能直接访问滑动窗口内的上下文,通过注意力机制提取清晰的中短期历史信号,因此和 Mamba 构成有益的补充。SWA 的实现计算效率高,其滑动窗口注意力保持了对序列长度 \(n\) 的线性复杂度,使其适合处理超长序列,而它基于 FlashAttention 的高效实现进一步加速了模型训练。
双向 Mamba
Bidirectional Mamba 是一种扩展 Mamba 模型的架构。它借鉴了双向 LSTM 的设计,包括两个 Mamba,其中一个从输入序列的开始往后,逐步积累信息,另一个从序列的结尾向前,捕获影响当前时刻的信息。
为了减少计算负担,它的部分模型参数(如卷积核和投影矩阵)在前后向流中共享,以减少计算负担。但每个方向维护独立的递归状态,避免上下文信息干扰。
最终的模型输出结合前向流和后向流的信息,通过加权或其他融合策略整合双向信息。
通过这种方法,它能够同时利用序列的前向(past context)和后向(future context)信息,提升对双向依赖的建模能力。比如,在 NLP 任务中,理解句子的前后文是高质量语义建模的关键。当前词可能同时依赖于前文和后文,因此双向机制能更全面地捕捉语言特征。此外,双向机制也可以提高模型鲁棒性:在某些任务(如噪声数据处理)中,双向机制提供了冗余信息来源,使模型对局部错误的敏感性降低。
Index | Previous | Next |