pytorch分布式 训练参数设置

# 自己的数据获取
dataset = MyDataset(input_size, data_size)
 
# 使用 DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
 
trainloader = DataLoader(dataset=dataset,
                         pin_memory=true,
                         shuffle=(train_sampler is None),   # 使用分布式训练 shuffle 应该设置为 False
                         batch_size=args.batch_size,
                         num_workers=args.workers,
                         sampler=train_sampler)

需要注意的几个参数:batch_size、num_workers、shuffle、pin_memory在进行多机多卡以及单机多卡的设置。

1、 Batch_size设置:

Dataparallel : 设置 batch_size 是指总多卡的Batch size,数据被直接划分到多个 GPU 上

DistributedDataParallel batch size 设置成单卡一样即可,因为各个GPU对应的进程独立从磁盘中加载数据这里的 Batch_size指的是单卡的。

2、shuffle设置:

shuffle:

Dataparallel  :设置 ‘shuffle’: True

DistributedDataParallel  :为了能够按顺序划分数据子集,拿到不同部分数据,所以数据集不能够进行随机打散,所以用了参数 ‘shuffle’: False

3、 pin_memory 设置:

是否提前申请CUDA内存(默认为False,但有说法除非数据集很小,否则在N卡上推荐总是打开)。

如果开了pin memory:
每个worker都需要缓存一个batch的数据.
batch size和num_workers都大, 显存会炸

为什么 设置 pip_memory=true, 看解释:
多GPU训练的时候注意机器的内存是否足够(一般内存为显卡显存x2),如果不够,建议关闭pin_memory(锁页内存)选项。
采用DistributedDataParallel多GPUs训练的方式比DataParallel更快一些,如果你的Pytorch编译时有nccl的支持,那么最好使用DistributedDataParallel方式。
关于什么是锁页内存:
pin_memory就是锁页内存,创建DataLoader时,设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
主机中的内存,有两种存在方式,一是锁页,二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。显卡中的显存全部是锁页内存,当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。因为pin_memory与电脑硬件性能有关,pytorch开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory默认为False。

当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。pin_memory默认为False。

4、 num_workers 设置:num_worker的设置值一般是所运行机子上的CPU核心数

可以设置set num_workers =4 x number of available GPUs 

um_worker大: 下一轮迭代的batch可能在上一轮/上上一轮…迭代时已经加载好了。 坏处是GPU memory开销大 (这是开了pin memory的情况吧) ,也加重了CPU负担。

CPU的物理个数:grep ‘physical id’ /proc/cpuinfo | sort | uniq | wc -l 结果为2,说明CPU有两个。 每个CPU的核数:cat /proc/cpuinfo |grep “cores”|uniq 10,说明每个10核。 cpu核数 = 2×10

1、cpu个数

grep ‘physical id’ /proc/cpuinfo | sort -u

2、核心数【当数据集较大时建议采用,num_works一般设置为(CPU 核心数 +-1)为最佳】

grep ‘core id’ /proc/cpuinfo | sort -u | wc -l

3、线程数

grep ‘processor’ /proc/cpuinfo | sort -u | wc -l

一般建议 num_workers 的值接近 CPU 核心数,但不要超过,以免导致过多的上下文切换。

如果数据集较大且预处理复杂,较高的 num_workers 值可能会更有效。反之,如果数据集较小或者预处理简单,则可能不需要太多的工作线程。

Num workers:只要你的 GPU 计算占用没有用满,说明 GPU 要等数据准备。可以试着增加进程数目,同时观察是否是硬盘 IO 瓶颈,如果是多机训练,还要注意网络瓶颈。不过,最大也不能超过核心数,一般还要减一点,因为主进程,多卡多进程训练,都会占用核心。

num_worker通过影响数据加载速度,从而影响训练速度。 每轮dataloader加载数据时:dataloader一次性创建num_worker个worker,worker就是普通的工作进程。并用batch_sampler将指定batch分配给指定的worker,worker将它负责的batch加载进RAM。然后,dataloader从RAM中找本轮迭代要用的batch,如果找到了,就使用;如果没找到,就用num_worker个worker继续加载batch到RAM,直到dataloader在RAM中找到目标batch。

pytorch单机多卡训练【分布式数据并行 和 数据并行方案】

https://github.com/KaiiZhang/DDP-Tutorial/blob/main/DDP-Tutorial.md

数据并行和分布式数据并行方案:

第一: 数据并行 , 开一个进程(process),该进程下每个线程(threading)负责一部分数据,分别跑在不同卡上,前向传播,devices各玩各的,计算loss时候需要所有devices的输出输送到主GPU【默认device0】上计算梯度均值,并更新device0上的参数,然后将参数广播到其他device上。总结:单机-多线程,通过torch.nn.DataParallel 实现。
第二: 分布式数据并行,开多个进程,一个进程运行在一张卡上,每个进程负责一部分数据。在各进程梯度计算完成之后,各进程需要将梯度进行汇总平均,然后再由 rank=0 的进程,将其 broadcast 到所有进程。各进程用该梯度来更新参数。由于各进程中的模型,初始参数一致 (初始时刻进行一次 broadcast),而每次用于更新参数的梯度也一致,因此,各进程的模型参数始终保持一致。

总结:单机/多机-多进程,通过torch.nn.parallel.DistributedDataParallel 实现。

毫无疑问,第一种简单,第二种复杂,毕竟 进程间 通信比较复杂。

torch.nn.DataParallel 和 torch.nn.parallel.DistributedDataParallel,下面简称为DPDDP

总结: 两个函数主要用于在多张显卡上训练模型,也就是所谓的分布式训练

数据并行 torch.nn.DataParallel  :

原理:

  • 网络前向传播前,输入数据被分成几份送到不同显卡上,网络模型每个显卡上拷贝一份。
  • 前向传播时,devices各玩各的。
  • 前向传播完成后,每张显卡上的网络输出会送到主device上(默认第一张卡),在主device上计算loss。然后,loss送给每个device,每个device计算得到梯度,再把梯度送到主device上,主device对汇总得到的梯度求均值后,更新主device上的网络参数。最后,将更新后的网络权重广播(broadcast)到其它device上,实现所有device网络权重同步。
  • torch.nn.DataParallel是把每张卡的输出聚合到GPU0上,然后在GPU0上与label计算loss,根据计算图反向传播,让每张卡上获得自己的梯度。优化器则对梯度进行聚合,在主GPU更新模型参数,再把新的参数分发到每个GPU。

从上面介绍可知,DataParallel 对主device依赖较高,会造成负载不均衡,限制模型训练速度。

DP使用教程:

主程序DP_main.py中,下面这行代码实现数据并行化分布式训练。

相比单卡单机代码:只需要修改以下代码:

model_train = torch.nn.DataParallel(model)	

通过终端运行命令,

CUDA_VISIBLE_DEVICES=0,1 python3 DP_main.py

DP_main.py代码:

import torch
import torchvision
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from net import ToyModel
import torch.optim as optim


#---------------------------#
#   获得学习率
#---------------------------#
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

#---------------------------#
#   获得数据集
#---------------------------#
def get_dataset():
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    CIFAR10_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
        download=True, transform=transform_train)
    
    # ----------------------------------------------------------#
    #   num_workers:加载数据集使用的线程数
    #   pin_memory=True:锁页内存, 可以加速数据读取. (可能会导致Bug)
    # ----------------------------------------------------------#
    trainloader = torch.utils.data.DataLoader(CIFAR10_trainset, 
        batch_size=16, num_workers=2, pin_memory=True)
    return trainloader

#---------------------------#
#   训练
#---------------------------#
def train(model, device, trainloader, optimizer, loss_func, print_frequence, epoch):
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, targets)
        loss.backward()
        optimizer.step()

        # loss.item()把其中的梯度信息去掉,没.item()可能会导致程序所占内存一直增长,然后被计算机killed
        train_loss += loss.item()       
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % print_frequence == print_frequence - 1 or print_frequence == trainloader.__len__() - 1:
            print('epoch: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                epoch, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    torch.save(model.state_dict(), "%d.ckpt" % epoch)	
    # torch.save(model.module.state_dict(), "%d.ckpt" % epoch)	用双卡训练保存权重,重新加载时,也需要这样保存,否则,权重前面会多module
    
    # -------------------------------------#
    #   只是想看看lr有没有衰减
    # -------------------------------------#
    lr = get_lr(optimizer)
    print("lr:", lr)
    lr_scheduler.step()


if __name__ == '__main__':
    trainloader = get_dataset()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ToyModel()
    print(model)

    model_train = model.train()
    if torch.cuda.is_available():   
        model_train = torch.nn.DataParallel(model)  # 单GPU跑套DP的话,指标可能会降
        cudnn.benchmark = True
        model_train = model_train.cuda()            # 等效于model_train = model_train.to(device)

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model_train.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # -------------------------------------#
    #   step_size控制多少个epoch衰减一次学习率
    # -------------------------------------#
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)   
    
    print_frequence = 500
    epochs = 100
    for epoch in range(0, epochs):
        train(model_train, device, trainloader, optimizer, loss_func, print_frequence, epoch)

分布式并行DistributedDataParallel 

  • 更快的训练速度
  • 多进程的运行方式
  • 支持单机多卡和多机多卡
  • 平衡的GPU使用

DDP原理:

先说分布式几个名词:
一个world里进程个数为world_size,全局看,每个进程都有一个序号rank;分开看,一个进程在每台机器里面也有序号local_rank。

  • group:进程组,默认一个组,即一个world
  • world_size:全局进程个数
  • rank:进程序号,用于进程间通信。rank=0为GPU主卡,主要用于多机多卡。本文中仅涉及到一台机器内多张卡。
  • locak_rank:进程(一台机器)内的GPU编号,通过指令torch.distributed.run自动指定,不需要用户输入该参数。

DDP 在每次迭代中,操作系统会为每个GPU创建一个进程,每个进程具有自己的 optimizer ,并独立完成所有的优化步骤,进程内与一般的训练无异。在各进程梯度计算完成之后,各进程需要将梯度进行汇总平均,然后再由 rank=0 的进程,将其 broadcast 到所有进程。各进程用该梯度来更新参数。由于各进程中的模型,初始参数一致 (初始时刻进行一次 broadcast),而每次用于更新参数的梯度也一致,因此,各进程的模型参数始终保持一致。

而在 DataParallel 中,全程维护一个 optimizer,对各 GPU 上梯度进行求和,在主 GPU 进行参数更新,之后再将模型参数 broadcast 到其他 GPU。相较于 DP,DDP传输的数据量更少,速度更快,效率更高。

DDP的流程示意图如上图所示,DDP需要额外的建立进程组阶段(Construction)。在Construction阶段需要首先明确通信协议和总进程数。通信协议是实现DDP的底层基础,我们在之后单独介绍。总进程数就是指有多少个独立的并行进程,被称为worldsize。根据需求每个进程可以占用一个或多个GPU,但并不推荐多个进程共享一个GPU,这会造成潜在的性能损失。为了便于理解,在本文的所有示例中我们假定每个进程只占用1个GPU,占用多个GPU的情况只需要简单的调整GPU映射关系就好。

并行组建立之后,每个GPU上会独立的构建模型,然后GPU-1中模型的状态会被广播到其它所有进程中以保证所有模型都具有相同的初始状态。值得注意的是Construction只在训练开始前执行,在训练中只会不断迭代前向和后向过程,因此不会带来额外的延迟。

相比于DataParallel,DDP的前向后向过程更加简洁。推理、损失函数计算,梯度计算都是并行独立完成的。DDP实现并行训练的核心在于梯度同步。梯度在模型间的同步使用的是allreduce通信操作,每个GPU会得到完全相同的梯度。如图中后向过程的步骤2,GPU间的通信在梯度计算完成后被触发(hook函数)。图中没有画出的是,通常每个GPU也会建立独立的优化器。由于模型具有同样的初始状态和后续相同的梯度,因此每轮迭代后不同进程间的模型是完全相同的,这保证了DDP的数理一致性。

为了优化性能,DDP中针对allreduce操作进行了更深入的设计。梯度的计算过程和进程间的通信过程分别需要消耗一定量的时间。等待模型所有的参数都计算完梯度再进行通信显然不是最优的。如下图所示,DDP中的设计是通过将全部模型参数划分为无数个小的bucket,在bucket级别建立allreduce。当所有进程中bucket0的梯度计算完成后就立刻开始通信,此时bucket1中梯度还在计算。这样可以实现计算和通信过程的时间重叠。这种设计能够使得DDP的训练更高效。

在最后我们对DDP的通信部分进行介绍。DDP后端的通信由多种CPP编写的协议支持,不同协议具有不同的通信算子的支持,在开发中可以根据需求选择。

对于CV和NLP常用GPU训练的任务而言,选择Gloo或NCCL协议即可。一个决定因素是你使用的计算机集群的网络环境:

  • 当使用的是Ethernet(以太网,大部分机器都是这个环境):那么优先选择NCCL,具有更好的性能;如果在使用中遇到了NCCL通信的问题,那么就选择Gloo作为备用。(经验:单机多卡直接NCCL;多机多卡先尝试NCCL,如果通信有问题,而且自己解决不了,那就Gloo。
  • 当使用的是InfiniBand:只支持NCCL。

另一个决定性因素是二者支持的算子范围不同,因此在使用时还需要结合代码里的功能来确定。下图记录了每种通信协议能够支持的算子,Gloo能够实现GPU中最基本的DDP训练,而NCCL能够支持更加多样的算子.

不同Backend的算子支持情况

DDP使用:

  • 设备间通信
    为了保证不同卡上的模型参数同步,设备间需要通讯。
    设备间通讯通过后端backend实现,GPU上用nccl,CPU上用gloo
torch.distributed.init_process_group('nccl')
  • 指定GPU
    指定使用哪些GPU,作用相当于CUDA_VISIBLE_DEVICES命令。
torch.cuda.set_device(args.local_rank)   
  • 构造模型
    构造DDP model,[args.local_rank]是一个list
model = DistributedDataParallel(model, device_ids=[args.local_rank], 
   										output_device=args.local_rank)
  • 构建数据集
    构建数据集中需要用到train_sampler来shuffle数据,继而实现把trainset中的样本随机分配到不同的GPU上,
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
# ---------------------------------------------------------------#
#   sampler参数和shuffle参数是互斥的,两个传一个就好,都用于数据打乱。
# ----------------------------------------------------------------#
trainloader = torch.utils.data.DataLoader(trainset, 
        batch_size=16, num_workers=2, sampler=train_sampler)
  • 数据放到多卡上
    模型、损失函数、输入数据要放到多卡上,代码例如:
data = data.to(args.local_rank)		# 等效于data.cuda(args.local_rank)

通过终端运行命令,

# CUDA_VISIBLE_DEVICES="gpu_0, gpu1,..." python -m torch.distributed.launch --nproc_per_node n_gpus DDP_main.py
CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node=2 DDP_main.py # 因为是单机多卡,所以只需要指定nproc_per_node【GPU数量】即可。local_rank不需要设置。
大概内容就是,这个命令行参数“–loacl_rank”是必须声明的,但它不是由用户填写的,而是由pytorch为用户填写,也就是说这个值是会被自动赋值为当前进程在本机上的rank

DDP_main.py中内容如下:

import argparse         # 从命令行接受参数
from tqdm import tqdm   # 用于进度条
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from net import ToyModel
import torchvision.transforms as transforms
# ---------------------------#
#   下面两个包用于分布式训练
# ---------------------------#
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ---------------------------#
#   获得数据集
# ---------------------------#
def get_dataset():
    transform = torchvision.transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
        download=True, transform=transform)
    # -----------------------------------------------#
    #   train_sampler主要用于DataLoader中shuffle数据
    #       把trainset中的样本随机分配到不同的GPU上
    # -----------------------------------------------#
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    # ---------------------------------------------------------------#
    #   batch_size:每个进程(GPU/卡)下的batch_size。
    #       总batch_size = 这里的batch_size * 进程并行数
    #       全局进程个数world_size = 节点数量 * 每个节点上process数量
    #       总卡数                =  电脑数  * 每台电脑上有多少张卡
    #   sampler参数和shuffle参数是互斥的,两个传一个就好,都用于数据打乱。
    #   在DDP中,用sampler参数
    # ----------------------------------------------------------------#
    trainloader = torch.utils.data.DataLoader(trainset, 
        batch_size=16, num_workers=2, sampler=train_sampler)
    return trainloader

#---------------------------#
#   训练
#---------------------------#
def train(model, trainloader, optimizer, loss_func, lr_scheduler, epoch):
    model.train()
    iterator = tqdm(range(epoch))       # 为了进度条显示而已
    for epoch in iterator:
        # ------------------------------------------------------------------#
        #   设置sampler的epoch,DistributedSampler需要这个来指定shuffle方式,
        #   通过维持各个进程之间的相同随机数种子使不同进程能获得同样的shuffle效果。
        #   这一步是必须的,让数据充分打乱,训练效果更好
        # ------------------------------------------------------------------#
        trainloader.sampler.set_epoch(epoch)

        for data, label in trainloader:
            data, label = data.to(args.local_rank), label.to(args.local_rank)
            optimizer.zero_grad()
            prediction = model(data)
            loss = loss_func(prediction, label)
            loss.backward()
            iterator.desc = "loss = %0.3f" % loss
            optimizer.step()
        # ------------------------------------------------------------------#
        #   save模型的时候:保存的是model.module而不是model,
        #       因为model其实是DDP model,参数是被`model=DDP(model)`包起来的。
        #   只需要在进程0(local_rank=0)上保存一次就行了,避免多次重复保存。
        # ------------------------------------------------------------------#
        if dist.get_rank() == 0:        # 等效于 if local_rank == 0:
            torch.save(model.module.state_dict(), "%d.ckpt" % epoch)
        
        lr_scheduler.step()

# -----------------------------------------------#
# 初始化配置local_rank配置
# -----------------------------------------------#
parser = argparse.ArgumentParser()
# local_rank:当前这个节点上的第几张卡,从外部传入
#   该步骤必须有,launch会自动传入这个参数
parser.add_argument("--local_rank",help="local device id on current node", type=int)
args = parser.parse_args()
local_rank = args.local_rank        # 纯属想写代码时用local_rank还是args.local_rank都行
print('local_rank:', args.local_rank)
"""
local_rank: 0
local_rank: 1
"""


if __name__ == "__main__":
    # DDP 初始化
    torch.cuda.set_device(args.local_rank)   # 作用相当于CUDA_VISIBLE_DEVICES命令,修改环境变量
    dist.init_process_group(backend='nccl')  # 设备间通讯通过后端backend实现,GPU上用nccl,CPU上用gloo

    # 准备数据,要在DDP初始化之后进行
    trainloader = get_dataset()

    # 初始化model
    model = ToyModel().to(args.local_rank)    # 等效于model = ToyModel().cuda(args.local_rank)

    # Load模型参数要在构造DDP model之前,且只需要在 master卡 上加载即可
    ckpt_path = None
    if dist.get_rank() == 0 and ckpt_path is not None:
        model.load_state_dict(torch.load(ckpt_path))

    # 构造DDP model
    model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    # 初始化optimizer,要在构造DDP model之后
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # 学习率衰减方式
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)   

    # 初始化loss
    loss_func = nn.CrossEntropyLoss().to(args.local_rank)

    # 模型训练
    train(model, trainloader, optimizer, loss_func, lr_scheduler, epoch=100)
# ----------------------------------------------------------------------------------#
#   CUDA_VISIBLE_DEVICES:来决定使用哪些GPU,个数和后面n_gpus相同
#   torch.distributed.launch:启动DDP模式,构建多个进程,也会向代码中传入local_rank参数,
#       没有CUDA_VISIBLE_DEVICES限制的话,传入为从 0 到 n_gpus-1 的索引
#   --nproc_per_node=n_gpus:单机多卡,用几个gpu
# -----------------------------------------------------------------------------------#
# 用 2 张卡跑
CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node 2 DDP_main.py
# 用 3 张卡跑     
CUDA_VISIBLE_DEVICES="1,2,3" python -m torch.distributed.launch --nproc_per_node 3 DDP_main.py  

pytorch多机多卡训练【DistributedDataParallel】

https://github.com/KaiiZhang/DDP-Tutorial/blob/main/DDP-Tutorial.md#distributeddataparallel

原理

DDP的流程示意图如上图所示,DDP需要额外的建立进程组阶段(Construction)。在Construction阶段需要首先明确通信协议和总进程数。通信协议是实现DDP的底层基础,我们在之后单独介绍。总进程数就是指有多少个独立的并行进程,被称为worldsize。根据需求每个进程可以占用一个或多个GPU,但并不推荐多个进程共享一个GPU,这会造成潜在的性能损失。为了便于理解,在本文的所有示例中我们假定每个进程只占用1个GPU,占用多个GPU的情况只需要简单的调整GPU映射关系就好。

并行组建立之后,每个GPU上会独立的构建模型,然后GPU-1中模型的状态会被广播到其它所有进程中以保证所有模型都具有相同的初始状态。值得注意的是Construction只在训练开始前执行,在训练中只会不断迭代前向和后向过程,因此不会带来额外的延迟。

相比于DataParallel,DDP的前向后向过程更加简洁。推理、损失函数计算,梯度计算都是并行独立完成的。DDP实现并行训练的核心在于梯度同步。梯度在模型间的同步使用的是allreduce通信操作,每个GPU会得到完全相同的梯度。如图中后向过程的步骤2,GPU间的通信在梯度计算完成后被触发(hook函数)。图中没有画出的是,通常每个GPU也会建立独立的优化器。由于模型具有同样的初始状态和后续相同的梯度,因此每轮迭代后不同进程间的模型是完全相同的,这保证了DDP的数理一致性。

为了优化性能,DDP中针对allreduce操作进行了更深入的设计。梯度的计算过程和进程间的通信过程分别需要消耗一定量的时间。等待模型所有的参数都计算完梯度再进行通信显然不是最优的。如下图所示,DDP中的设计是通过将全部模型参数划分为无数个小的bucket,在bucket级别建立allreduce。当所有进程中bucket0的梯度计算完成后就立刻开始通信,此时bucket1中梯度还在计算。这样可以实现计算和通信过程的时间重叠。这种设计能够使得DDP的训练更高效。

在最后我们对DDP的通信部分进行介绍。DDP后端的通信由多种CPP编写的协议支持,不同协议具有不同的通信算子的支持,在开发中可以根据需求选择。

对于CV和NLP常用GPU训练的任务而言,选择Gloo或NCCL协议即可。一个决定因素是你使用的计算机集群的网络环境:

  • 当使用的是Ethernet(以太网,大部分机器都是这个环境):那么优先选择NCCL,具有更好的性能;如果在使用中遇到了NCCL通信的问题,那么就选择Gloo作为备用。(经验:单机多卡直接NCCL;多机多卡先尝试NCCL,如果通信有问题,而且自己解决不了,那就Gloo。)
  • 当使用的是InfiniBand:只支持NCCL。

另一个决定性因素是二者支持的算子范围不同,因此在使用时还需要结合代码里的功能来确定。下图记录了每种通信协议能够支持的算子,Gloo能够实现GPU中最基本的DDP训练,而NCCL能够支持更加多样的算子

综上,得益于DDP的分布式并行设计,DDP并不受PythonGIL争用的影响,是以多进程的方式运行的。这也使得DDP可以支持多机多卡的训练。我们将DDP的优缺点概括如下:

不同Backend的算子支持情况

优点

  • 更快的训练速度
  • 多进程的运行方式
  • 支持单机多卡和多机多卡
  • 平衡的GPU使用

缺点

  • 需要更多的代码书写和设计

代码实现和参数讲解:

本文首先会基于MNIST图像分类建立一个最小原型,然后逐步改进它以实现多机多卡的训练和混合精度的支持。在讲述的思路上本文借鉴了Kevin Kaichuang Yang的教程,但在实现细节上有较大的差异。特别的是本文增加了对DDP启动方式的探讨,并且介绍了多进程通信操作的使用样例。

名词解释:一个world里进程个数为world_size【对于2卡2GPU, world_size =4】,全局看,每个进程都有一个序号rank【0为主机GPU主卡】;分开看,一个进程在每台机器里面也有序号local_rank。

  • group:进程组,默认一个组,即一个world
  • world_size:全局进程个数【对于2卡2GPU, world_size =4】
  • rank:进程序号,用于进程间通信。rank=0为GPU主卡,主要用于多机多卡。本文中仅涉及到一台机器内多张卡。
  • locak_rank:进程内的GPU编号,通过指令torch.distributed.run自动指定,不需要认为设置。

非多进程示例

首先引入了所有用到的库。

from datetime import datetime
import argparse
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm

定义一个简单的卷积神经网络模型。

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

定义主函数,添加一些启动脚本的可选参数。

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpuid', default=0, type=int,
                        help="which gpu to use")
    parser.add_argument('-e', '--epochs', default=2, type=int, 
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=4, type=int, 
                        metavar='N',
                        help='number of batchsize')         

    args = parser.parse_args()
    train(args.gpuid, args)

然后给出训练函数的详细内容。

def train(gpu, args):
    model = ConvNet()
    model.cuda(gpu)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    # Data loading code
    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=None)

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        model.train()
        for i, (images, labels) in enumerate(tqdm(train_loader)):
            images = images.to(gpu)
            labels = labels.to(gpu)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step,
                                                                   loss.item()))
    print("Training complete in: " + str(datetime.now() - start))

最后确保主函数被启动。

if __name__ == '__main__':
    main()

以上是我们的MNIST图像分类最小原型,可以通过如下命令启动在指定单个GPU上的训练:

python train.py -g 0

多进程示例

在开始对最小原型的改造之前,我们还需要交代一些事情。在DDP的代码实现中,最重要的步骤之一就是初始化。所谓初始化对应于上文介绍的Construction阶段,每个进程中需要指明几个关键的参数:

  • backend:明确后端通信方式,NCCL还是Gloo
  • init_method:初始化方式,TCP还是Environment variable(Env),可以简单理解为进程获取关键参数的地址和方式
  • world_size:总的进程数有多少
  • rank:当前进程是总进程中的第几个

初始化方式不同会影响代码的启动部分。本文会分别给出TCP和ENV模式的样例。TCP模式

让我们先从TCP开始,注意那些标记被更改的代码部分:

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpuid', default=0, type=int,
                        help="which gpu to use")
    parser.add_argument('-e', '--epochs', default=1, type=int, 
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=4, type=int, 
                        metavar='N',
                        help='number of batchsize')   
    ##################################################################################
    parser.add_argument('--init_method', default='tcp://localhost:18888',            #
                        help="init-method")                                          #
    parser.add_argument('-r', '--rank', default=0, type=int,                         #
                    help='rank of current process')                                  #
    parser.add_argument('--world_size', default=2, type=int,                         #
                        help="world size")                                           #
    parser.add_argument('--use_mix_precision', default=False,                        #
                        action='store_true', help="whether to use mix precision")    #
    ##################################################################################                  
    args = parser.parse_args()
    train(args.gpuid, args)

在main函数中需要增加了以下参数:

  • args.init_method:url地址,用来指明的初始化方法。在tcp初始化方法中,其格式应为:tcp:[ IP ]:[ Port ] 。IP为rank=0进程所在的机器IP地址,Port为任意一个空闲的端口号。当采用的是单机多卡模式时,IP可以默认为//localhost
  • args.rank:当前进程在所有进程中的序号
  • args.world_size:进程总数【一共几块GPU】
  • args.use_mix_precision:布尔变量,控制是否使用混合精度
def train(gpu, args):
    ########################################    N1    ####################################################################
    dist.init_process_group(backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.world_size)    #
    ######################################################################################################################
    model = ConvNet()
    model.cuda(gpu)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    # Wrap the model
    #######################################    N2    ########################
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)                  #
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])    #
    scaler = GradScaler(enabled=args.use_mix_precision)                   #
    #########################################################################
    # Data loading code
    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    ####################################    N3    #######################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)      #
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,                   #
                                               batch_size=args.batch_size,              #
                                               shuffle=False,                           #
                                               num_workers=0,                           #
                                               pin_memory=True,                         #
                                               sampler=train_sampler)                   #
    #####################################################################################
    start = datetime.now()
    total_step = len(train_loader) # The number changes to orignal_length // args.world_size
    for epoch in range(args.epochs):
        ################    N4    ################
        train_loader.sampler.set_epoch(epoch)    #
        ##########################################
        model.train()
        for i, (images, labels) in enumerate(tqdm(train_loader)):
            images = images.to(gpu)
            labels = labels.to(gpu)
            # Forward pass
            ########################    N5    ################################
            with torch.cuda.amp.autocast(enabled=args.use_mix_precision):    #
                outputs = model(images)                                      #
                loss = criterion(outputs, labels)                            #
            ##################################################################  
            # Backward and optimize
            optimizer.zero_grad()
            ##############    N6    ##########
            scaler.scale(loss).backward()    #
            scaler.step(optimizer)           #
            scaler.update()                  #
            ##################################
            ################    N7    ####################
            if (i + 1) % 100 == 0 and args.rank == 0:    #
            ##############################################   
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step,
                                                                   loss.item()))            
    ############    N8    ###########
    dist.destroy_process_group()    #                                       
    if args.rank == 0:              #
    #################################
        print("Training complete in: " + str(datetime.now() - start))

在训练函数中增加/修改了以下内容:

  • N1:增加了DDP初始化的代码,需要指明backend、init_method、rank和world_size。其含义在前文都有介绍。
  • N2:在并行环境下,对于用到BN层的模型需要转换为同步BN层;其次,用DistributedDataParallel将模型封装为一个DDP模型,并复制到指定的GPU上。封装时不需要更改模型内部的代码;设置混合精度中的scaler,通过设置enabled参数控制是否生效。
  • N3:DDP要求定义distributed.DistributedSampler,通过封装train_dataset实现;在建立DataLoader时指定sampler。此外还要注意:shuffle=False。DDP的数据打乱需要通过设置sampler,参考N4。
  • N4:在每个epoch开始前打乱数据顺序。(注意total_step已经变为orignal_length // args.world_size。)
  • N5:利用torch.cuda.amp.autocast控制前向过程中是否使用半精度计算。
  • N6: 当使用混合精度时,scaler会缩放loss来避免由于精度变化导致梯度为0的情况。
  • N7:为了避免log信息的重复打印,可以只允许rank0号进程打印。
  • N8: 清理进程;然后,同上。

假设服务器环境为2台服务器(也称为2个node),每台服务器两块GPU。启动方式为:

# Node 0 : ip 192.168.1.201  port : 12345
# terminal-0
python mnist-tcp.py --init_method tcp://192.168.1.201:12345 -g 0 --rank 0 --world_size 4 --use_mix_precision
# terminal-1
python mnist-tcp.py --init_method tcp://192.168.1.201:12345 -g 1 --rank 1 --world_size 4 --use_mix_precision

# Node 1 : 
# terminal-0
python tcp_init.py --init_method tcp://192.168.1.201:12345 -g 0 --rank 2 --world_size 4 --use_mix_precision
# terminal-1
python tcp_init.py --init_method tcp://192.168.1.201:12345 -g 1 --rank 3 --world_size 4 --use_mix_precision

TCP模式启动很好理解,需要在bash中独立的启动每一个进程,并为每个进程分配好其rank序号。缺点是当进程数多的时候启动比较麻烦。完整的脚本文件见这里


ENV模式

ENV模式启动会更简洁,对于每个进程并不需要在dist.init_process_group中手动的指定其rank、world_size和url。程序会在环境变量中去寻找这些值。代码如下:

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpuid', default=0, type=int,
                        help="which gpu to use")
    parser.add_argument('-e', '--epochs', default=1, type=int, 
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=4, type=int, 
                        metavar='N',
                        help='number of batchsize')   
    ##################################################################################
    parser.add_argument("--local_rank", type=int,                                    #
                        help='rank in current node')                                 #
    parser.add_argument('--use_mix_precision', default=False,                        #
                        action='store_true', help="whether to use mix precision")    #
    ##################################################################################                  
    args = parser.parse_args()
    #################################
    train(args.local_rank, args)    #
    #################################
  • args.local_rank:这里指的是当前进程在当前机器中的序号,注意和在全部进程中序号的区别。在ENV模式中,这个参数是必须的,由启动脚本自动划分,不需要手动指定。要善用local_rank来分配GPU_ID。
  • train(args.local_rank, args):一般情况下保持local_rank与进程所用GPU_ID一致。
def train(gpu, args):
    ##################################################################
    dist.init_process_group(backend='nccl', init_method='env://')    #
    args.rank = dist.get_rank()                                      #
    ##################################################################
    model = ConvNet()
    ...
  • 训练函数中仅需要更改初始化方式即可。在ENV中只需要指定init_method='env://'。TCP所需的关键参数模型会从环境变量中自动获取,环境变量可以在程序外部启动时设定,参考启动方式。
  • 当前进程的rank值可以通过dist.get_rank()得到
  • 之后的代码与TCP完全相同

假设服务器环境为2台服务器(也称为2个node),每台服务器两块GPU。ENV模式的启动方式为:

# Node 0 : ip 192.168.1.201  port : 12345
# terminal-0
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr="192.168.1.201" --master_port=12345 mnist-env.py --use_mix_precision

# Node 1 : 
# terminal-0
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr="192.168.1.201" --master_port=12345 mnist-env.py --use_mix_precision

ENV模式可以使用pytorch中的启动脚本torch.distributed.launch启动。在启动命令中需要指明多个参数:

  • nproc_per_node: 每台机器中运行几个进程【每台机器几个GPU】
  • nnodes:一共使用多少台机器
  • node_rank:当前机器的序号【非GPU序号】
  • master_addr:0号机器的IP
  • master_port:0号机器的可用端口

可以看到无论一台机器中的进程数为多少,只需要一行命令就可以启动,相比于TCP模式启动方式更加简洁。

训练中对模型在验证集上进行验证也是必不可少的步骤之一,那么如何在上述demo中增加模型验证的代码呢?如何实现模型的并行验证?

####################################    N11    ##################################
def evaluate(model, gpu, test_loader, rank):
    model.eval()
    size = torch.tensor(0.).to(gpu)
    correct = torch.tensor(0.).to(gpu)
    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(test_loader)):
            images = images.to(gpu)
            labels = labels.to(gpu)
            outputs = model(images)
            size += images.shape[0]
            correct += (outputs.argmax(1) == labels).type(torch.float).sum() 
    dist.reduce(size, 0, op=dist.ReduceOp.SUM) # 群体通信 reduce 操作 change to allreduce if Gloo
    dist.reduce(correct, 0, op=dist.ReduceOp.SUM) # 群体通信 reduce 操作 change to allreduce if Gloo
    if rank==0:
        print('Evaluate accuracy is {:.2f}'.format(correct / size))
 #################################################################################

def train(gpu, args):
    ...
    ####################################    N9    ###################################
    test_dataset = torchvision.datasets.MNIST(root='./data',                        #
                                               train=False,                         #
                                               transform=transforms.ToTensor(),     #
                                               download=True)                       #
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)    #
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,                 #
                                               batch_size=args.batch_size,               #
                                               shuffle=False,                       #
                                               num_workers=0,                       #
                                               pin_memory=True,                     #
                                               sampler=test_sampler)                #
    #################################################################################
    start = datetime.now()
    total_step = len(train_loader) # The number changes to orignal_length // args.world_size
    for epoch in range(args.epochs):
        ...
        #####################    N10    #################
        evaluate(model, gpu, test_loader, args.rank)    #
        #################################################
    ...        

省略了代码不变的部分,完整的程序见脚本

  • N9:增加验证集的DataLoader,设置sampler实现数据的并行切分
  • N10:在每个epoch结束前验证模型
  • N11: 利用群体通信Reduce操作,将计算准确率所需的正确预测数和全局样本数收集到rank0进程中

只需要利用群体通信将验证集样本数和预测正确的样本数汇集在rank0中即可实现并行的模型验证,对于其它任务也可以参考这个思路实现。例如图像语义分割中计算mIoU只需要将每个进程的混淆矩阵汇总相加到rank0即可。

一些可能遇到的问题

网络防火墙有可能在首次多机多卡训练时造成计算节点间的通信失败。单机多卡成功运行的代码在扩展至多机多卡遇到问题后可以首先尝试将init_method切换为Gloo,能够回避掉一些潜在的问题。记录一下本人在实践中遇到的问题和解决方法。

address family mismatch 错误

解决方案是手动设置通信的网络端口。机器的网络端口通过ifconfig命令查询,有多个网口时可以都尝试一下。

当backend==NCCL

# Node 0 
# terminal-0
export NCCL_SOCKET_IFNAME=eth0
python ...

# Node 1 : 
# terminal-0
export NCCL_SOCKET_IFNAME=eth0
python ...

当backend==Gloo

# Node 0 
# terminal-0
export GLOO_SOCKET_IFNAME=eth0
python ...

# Node 1 : 
# terminal-0
export GLOO_SOCKET_IFNAME=eth0
python ...

参考

  1. https://pytorch.org/docs/stable/distributed.html#choosing-the-network-interface-to-use
  2. https://pytorch.org/tutorials/beginner/dist_overview.html
  3. Li, S., Zhao, Y., Varma, R., Salpekar, O., Noordhuis, P., Li, T., … & Chintala, S. (2020). Pytorch distributed: Experiences on accelerating data parallel training. arXiv preprint arXiv:2006.15704.
  4. https://zhuanlan.zhihu.com/p/76638962
  5. https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
  6. https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255

大模型系列教程

https://github.com/liguodongiot/llm-action?tab=readme-ov-file

目录

LLM训练

LLM训练实战

下面汇总了我在大模型实践中训练相关的所有教程。从6B到65B,从全量微调到高效微调(LoRA,QLoRA,P-Tuning v2),再到RLHF(基于人工反馈的强化学习)。

LLM预训练/SFT/RLHF…参数教程代码
Alpacafull fine-turning7B从0到1复现斯坦福羊驼(Stanford Alpaca 7B)配套代码
Alpaca(LLaMA)LoRA7B~65B1.足够惊艳,使用Alpaca-Lora基于LLaMA(7B)二十分钟完成微调,效果比肩斯坦福羊驼
2. 使用 LoRA 技术对 LLaMA 65B 大模型进行微调及推理
配套代码
BELLE(LLaMA/Bloom)full fine-turning7B1.基于LLaMA-7B/Bloomz-7B1-mt复现开源中文对话大模型BELLE及GPTQ量化
2. BELLE(LLaMA-7B/Bloomz-7B1-mt)大模型使用GPTQ量化后推理性能测试
N/A
ChatGLMLoRA6B从0到1基于ChatGLM-6B使用LoRA进行参数高效微调配套代码
ChatGLMfull fine-turning/P-Tuning v26B使用DeepSpeed/P-Tuning v2对ChatGLM-6B进行微调配套代码
Vicuna(LLaMA)full fine-turning7B大模型也内卷,Vicuna训练及推理指南,效果碾压斯坦福羊驼N/A
OPTRLHF0.1B~66B1.一键式 RLHF 训练 DeepSpeed Chat(一):理论篇 
2. 一键式 RLHF 训练 DeepSpeed Chat(二):实践篇
配套代码
MiniGPT-4(LLaMA)full fine-turning7B大杀器,多模态大模型MiniGPT-4入坑指南N/A
Chinese-LLaMA-Alpaca(LLaMA)LoRA(预训练+微调)7B中文LLaMA&Alpaca大语言模型词表扩充+预训练+指令精调配套代码
LLaMAQLoRA7B/65B高效微调技术QLoRA实战,基于LLaMA-65B微调仅需48G显存,真香配套代码
LLaMAGaLore60M/7B突破内存瓶颈,使用 GaLore 一张4090消费级显卡也能预训练LLaMA-7B配套代码

⬆ 一键返回目录

LLM微调技术原理

对于普通大众来说,进行大模型的预训练或者全量微调遥不可及。由此,催生了各种参数高效微调技术,让科研人员或者普通开发者有机会尝试微调大模型。

因此,该技术值得我们进行深入分析其背后的机理,本系列大体分七篇文章进行讲解。

peft方法

LLM微调实战

下面给大家分享大模型参数高效微调技术实战,该系列主要针对 HuggingFace PEFT 框架支持的一些高效微调技术进行讲解。

教程代码框架
大模型参数高效微调技术实战(一)-PEFT概述及环境搭建N/AHuggingFace PEFT
大模型参数高效微调技术实战(二)-Prompt Tuning配套代码HuggingFace PEFT
大模型参数高效微调技术实战(三)-P-Tuning配套代码HuggingFace PEFT
大模型参数高效微调技术实战(四)-Prefix Tuning / P-Tuning v2配套代码HuggingFace PEFT
大模型参数高效微调技术实战(五)-LoRA配套代码HuggingFace PEFT
大模型参数高效微调技术实战(六)-IA3配套代码HuggingFace PEFT
大模型微调实战(七)-基于LoRA微调多模态大模型配套代码HuggingFace PEFT
大模型微调实战(八)-使用INT8/FP4/NF4微调大模型配套代码PEFT、bitsandbytes

⬆ 一键返回目录

LLM分布式训练并行技术

近年来,随着Transformer、MOE架构的提出,使得深度学习模型轻松突破上万亿规模参数,传统的单机单卡模式已经无法满足超大模型进行训练的要求。因此,我们需要基于单机多卡、甚至是多机多卡进行分布式大模型的训练。

而利用AI集群,使深度学习算法更好地从大量数据中高效地训练出性能优良的大模型是分布式机器学习的首要目标。为了实现该目标,一般需要根据硬件资源与数据/模型规模的匹配情况,考虑对计算任务、训练数据和模型进行划分,从而进行分布式训练。因此,分布式训练相关技术值得我们进行深入分析其背后的机理。

下面主要对大模型进行分布式训练的并行技术进行讲解,本系列大体分九篇文章进行讲解。

⬆ 一键返回目录

分布式AI框架

分布式训练网络通信

待更新…

LLM训练优化技术

  • FlashAttention V1、V2
  • 混合精度训练
  • 重计算
  • MQA / GQA
  • 梯度累积

LLM对齐技术

  • PPO(近端策略优化)
  • DPO
  • ORPO

⬆ 一键返回目录

LLM推理

LLM推理框架

LLM推理优化技术

LLM压缩

近年来,随着Transformer、MOE架构的提出,使得深度学习模型轻松突破上万亿规模参数,从而导致模型变得越来越大,因此,我们需要一些大模型压缩技术来降低模型部署的成本,并提升模型的推理性能。 模型压缩主要分为如下几类:

  • 剪枝(Pruning)
  • 知识蒸馏(Knowledge Distillation)
  • 量化

LLM量化

本系列将针对一些常见大模型量化方案(GPTQ、LLM.int8()、SmoothQuant、AWQ等)进行讲述。

LLM剪枝

结构化剪枝

  • LLM-Pruner(LLM-Pruner: On the Structural Pruning of Large Language Models)
  • LLM-Shearing(Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning)

非结构化剪枝

  • SparseGPT(SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot)
  • LoRAPrune(LoRAPrune: Pruning Meets Low-Rank Parameter-Efficient Fine-Tuning)
  • Wanda(A Simple and Effective Pruning Approach for Large Language Models)
  • Flash-LLM(Flash-LLM: Enabling Cost-Effective and Highly-Efficient Large Generative Model Inference with Unstructured Sparsity)

LLM知识蒸馏

Standard KD:

使学生模型学习教师模型(LLM)所拥有的常见知识,如输出分布和特征信息,这种方法类似于传统的KD。

  • MINILLM
  • GKD

EA-based KD:

不仅仅是将LLM的常见知识转移到学生模型中,还涵盖了蒸馏它们独特的涌现能力。具体来说,EA-based KD又分为了上下文学习(ICL)、思维链(CoT)和指令跟随(IF)。

In-Context Learning:

  • In-Context Learning distillation

Chain-of-Thought:

  • MT-COT
  • Fine-tune-CoT
  • DISCO
  • SCOTT
  • SOCRATIC CoT

Instruction Following:

  • Lion

低秩分解

低秩分解旨在通过将给定的权重矩阵分解成两个或多个较小维度的矩阵,从而对其进行近似。低秩分解背后的核心思想是找到一个大的权重矩阵W的分解,得到两个矩阵U和V,使得W≈U V,其中U是一个m×k矩阵,V是一个k×n矩阵,其中k远小于m和n。U和V的乘积近似于原始的权重矩阵,从而大幅减少了参数数量和计算开销。

在LLM研究的模型压缩领域,研究人员通常将多种技术与低秩分解相结合,包括修剪、量化等。

  • ZeroQuant-FP(低秩分解+量化)
  • LoRAPrune(低秩分解+剪枝)

LLM数据工程

LLM Data Engineering

预训练语料处理技术

llm-pretrain-pipeline
  • 数据收集
  • 数据处理
    • 去重
    • 过滤
    • 选择
    • 组合

LLM微调高效数据筛选技术

提示工程

  • Zero-Shot Prompting
  • Few-Shot Prompting
  • Chain-of-Thought (CoT) Prompting
  • Automatic Chain-of-Thought (Auto-CoT) Prompting
  • Tree-of-Thoughts (ToT) Prompting

LLM算法架构

llm-famliy
llm-famliy

LLM应用开发

大模型是基座,要想让其变成一款产品,我们还需要一些其他相关的技术,比如:向量数据库(Pinecone、Milvus、Vespa、Weaviate),LangChain等。

LLM国产化适配

随着 ChatGPT 的现象级走红,引领了AI大模型时代的变革,从而导致 AI 算力日益紧缺。与此同时,中美贸易战以及美国对华进行AI芯片相关的制裁导致 AI 算力的国产化适配势在必行。本系列将对一些国产化 AI 加速卡进行讲解。

⬆ 一键返回目录

AI编译器

AI编译器是指将机器学习算法从开发阶段,通过变换和优化算法,使其变成部署状态。

框架:

  • MLIR
  • XLA
  • TVM

AI基础设施

AI加速卡

AI集群

待更新…

AI集群网络通信

待更新…

  • 分布式训练网络通讯原语
  • AI 集群通信软硬件

LLMOps

LLM生态相关技术

LLM面试题

正在收集中…

⬆ 一键返回目录

服务器基础环境安装及常用工具

基础环境安装:

常用工具:

多模态视觉-语言大模型的架构演进

https://zhuanlan.zhihu.com/p/693885420

A Survey on Multimodal Large Language Models

https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models

多模态视觉-语言大模型的架构演进

本文回顾了多模态LLM (视觉-语言模型) 近一年来的模型架构演进,对其中有代表性的工作进行了精炼总结.这篇综述一张图总结了多模态LLM的典型架构:

BLIP

【2022.01发布】https://arxiv.org/abs/2201.12086

统一视觉-语言理解和生成,使用captioner+filter高效利用互联网有噪数据

Refer to caption
我们使用Captioner(Cap)为Web图像生成合成标题,并使用Filter(Filt)删除嘈杂的标题。

模型架构:

  • Image/text encoder: ITC loss对齐视觉和语言表征,基于ALBEF提出的momentum distillation
  • Image-grounded text encoder: ITM loss建模视觉-语言交互,区分positive/negative图文对,使用hard negative mining挖掘更高相似度的负例优化模型
  • Image-grounded text decoder: LM loss实现基于图像的文本解码,将双向self-attention替换为causal self-attention
Refer to caption

BLIP-2

【2023.01发布】https://arxiv.org/abs/2301.12597

使用相对轻量的Q-Former连接视觉-语言模态,通过两阶段训练:第1阶段基于冻住的视觉编码器,第2阶段基于冻住的LLM

Refer to caption
BLIP-2的框架概述。我们按照两阶段策略预训练轻量级Querying Transformer,以弥补模态差距。第一阶段从冻结图像编码器引导视觉语言表示学习。第二阶段从冻结的LLM引导视觉到语言的生成学习,这使得零拍摄指令的图像到文本生成成为可能

第1阶段:同样优化ITC/ITM/LM loss,使用不同的self-attention mask,query和text端共享self-attention参数,使得可学习的query embedding提取与text语义最相关的视觉表征;使用BERT-base初始化,32个768维的query作为信息瓶颈

  • ITC:计算每个query与text的相似度,取最大的;使用batch内negatives,不再使用momentum queue
  • ITM:对每个query与text的分类logits取平均,使用hard negatives mining挖掘难负例
  • LM:text token和frozen image encoder不能直接交互,要求query能提取有益的视觉特征
Refer to caption

第2阶段:可基于decoder-only/encoder-decoder LLM进行适配,FC层对齐维度

Refer to caption

LLaVA

【2023.04发布】https://arxiv.org/abs/2304.08485

  • 使用仅文本模态的GPT-4生成视觉-语言指令遵循数据,用于微调多模态LLM
    • 使用图片的dense captions和bounding boxes作为prompt,可以生成对话、细节描述、复杂推理等指令
  • CLIP ViT-L/14 + Vicuna,使用简单的线性层进行映射
    • 更复杂的:Flamingo中gated cross-attention,BLIP-2中的Q-former

Qwen-VL

【2023.08发布】https://arxiv.org/abs/2308.12966

支持中英双语、多图像输入

Qwen-7B + OpenCLIP ViT-bigG,输入图像直接resize到视觉编码器输入

位置感知的VL adapter:使用基于Q-former的单层的cross-attention,将图像特征维度压缩到256,在query-key pairs中引入2D绝对位置编码增强位置信息

图像输入:<img>256-dim图像特征</img>

bounding box输入输出:<box>(X_topleft, Y_topleft), (X_bottomright, Y_bottomright)</box>, <ref>…</ref>标记box所指内容

三阶段训练:

stage1. 预训练:基于大规模、弱标注、网络爬取的图像-文本对,输入分辨率224×224,冻住LLM,训练ViT和Q-former,主要目的是模态对齐

stage2. 多任务预训练:基于7种下游视觉-语言理解任务的高质量、细粒度标注数据训练,输入分辨率448×448,图像/文本数据交错,训练整个模型

stage3. 指令微调:提升指令遵循和多轮对话能力,冻住ViT,训练LLM和Q-former

Qwen-VL-Plus和Qwen-VL-Max提升了视觉推理能力、图像细节的识别/提取/分析能力(尤其是文本导向的任务)、支持高分辨率和极端纵横比的输入图像;在部分中文场景超过了GPT-4V和Gemini

InternLM-XComposer

【2023.09发布】https://arxiv.org/abs/2309.15112

交错图文构成:自动在输出文本中插入合适的图片

EVA-CLIP ViT + InternLM-7B + Q-former (将图像特征压缩到64个embedding)

两阶段训练:

stage1. 预训练:冻住ViT,训练LLM和Q-former

stage2. 监督微调:包括多任务训练和指令微调,冻住ViT和LLM,训练Q-former,对LLM进行LoRA微调,增强指令遵循和图文混排能力

Fuyu-8B

【2023.10发布】https://huggingface.co/adept/fuyu-8b

模型架构和训练过程简单,易于scaling;支持任意图像分辨率;推理速度快

decoder-only的transformer,没有专门的图像编码器;image patch直接线性映射到transformer第一层

LLaVA-1.5

【2023.10发布】https://arxiv.org/abs/2310.03744

仍使用MLP作为模态连接,突出了训练的数据高效性

CogVLM

【2023.11发布】https://arxiv.org/abs/2311.03079

深度视觉-语言模态融合,而不影响LLM原有的语言能力:冻住LLM和ViT,在attention和FFN层训练一份视觉专家模块

CogAgent

【2023.12发布】https://arxiv.org/abs/2312.08914

针对GUI场景的多模态理解和导引,使用高分辨率-低分辨率双编码器,支持1120×1120的屏幕输入

高分辨率分支使用更轻量的ViT,基于cross-attention将高分辨率图像特征与LLM每层进行融合

VILA

【2023.12发布】https://arxiv.org/abs/2312.07533

探索了视觉-语言模型训练的设计选择:

  1. 预训练阶段冻住LLM虽然能取得较好的zero-shot性能,但上下文学习能力依赖对LLM的微调
  2. 图文交错的预训练数据是有益的,只用图文数据对效果不够好
  3. 将纯文本的指令微调数据加入SFT阶段有助于缓解纯文本任务的能力退化,同时也能够增强视觉-语言任务的准确性

LLaVA-Next

【2024.01发布】https://llava-vl.github.io/blog/2024-01-30-llava-next/

相对于LLaVA-1.5,保持了极简的设计和数据高效性:

  1. 提高了输入图像的分辨率 (4x),支持3种纵横比:672×672, 336×1344, 1344×336
  2. 更好的视觉推理和OCR能力:更好的指令微调数据配比
  3. 更好的多场景视觉对话:更好的世界知识和逻辑推理
  4. 更高效的部署和推理:SGLang

动态高分辨率:视觉编码器支持336×336的图像输入,对于672×672的图像,按照{2,2}的grid split成4个图像patch过encoder,downsample到336×336也过encoder,特征拼接作为visual tokens输入到LLM中

收集高质量用户数据,包括真实场景中反映用户更广泛意图的指令数据,利用GPT-4V进行数据构造

多模态文档/图表数据,增强文档OCR和图表理解能力

InternLM-XComposer2

【2024.01发布】https://arxiv.org/abs/2401.16420

提出了新的模态对齐方法partial LoRA:只在image token上添加LoRA参数,保证预训练语言知识的完整性,这样一个更轻量的视觉编码器同样有效

OpenAI CLIP ViT-L/14 + InternLM2-7B + partial LoRA (rank=256)

两阶段训练:

stage1. 预训练:冻住LLM,微调ViT和partial LoRA模块,包括通用语义对齐(理解图像基本内容)、世界知识对齐(进行复杂的知识推理)、视觉能力增强(OCR、物体定位、图表理解)

stage2. 监督微调:微调整个模型,包括多任务训练、自由形式图文排布

InternLM-XComposer2-4KHD

2024.04发布了4KHD版本:https://arxiv.org/abs/2404.06512

支持动态分辨率(336px → 4K (3840×1600)):改进了patch division范式,保持训练图像原有的纵横比,自动变化patch数目,基于336×336的ViT配置layout

动态图像划分:将输入图像resize and pad到336的整数倍宽高

结合图像的global和local视角:global视角由输入直接resize到336×336,使用sep token分隔两种视角的token

图像2D结构的换行符:可学习的\n token分隔图像token行

Mini-Gemini

【2024.03发布】https://arxiv.org/abs/2403.18814

使用双视觉编码器提取低分辨率embedding作为query,高分辨率特征区域作为key/value,两者之间做cross-attention,输出挖掘的tokens作为prompt前缀,输入到LLM做推理,外接图像解码器生成图像(SDXL)

LLaVA-NeXT系列

LLaVA-1.5

23年10月,LLaVA-1.5发布,通过在视觉和语言模态间添加简单的MLP层实现了训练样本高效性,为多模态大模型在低数据业务场景的落地提供了可能。

[2310.03744] Improved Baselines with Visual Instruction Tuning

LLaVA-NeXT

24年1月,LLaVA-NeXT(1.6)发布,在1.5的基础上保持了精简的设计和数据高效性,支持更高的分辨率、更强的视觉推理和OCR能力、更广泛场景的视觉对话。模型分为两阶段训练:阶段1预训练只训练连接层,阶段2指令微调训练整个模型。

LLaVA-NeXT: Improved reasoning, OCR, and world knowledge

  • 动态高分辨率AnyRes:如上图,为了让模型能感知高分辨率图像的复杂细节,对图像进行网格划分。比如,对于672×672的图像,一方面按2×2的网格切分为4张336px的输入图像送给ViT编码成特征,另一方面将图像直接resize到336px进行编码,最后将两部分特征合并输入到LLM中,这样模型具备了全局和局部的视觉推理能力。
  • 指令数据混合:一方面保证指令数据具有高质量、多样性,反映真实场景的广泛用户意图;另一方面,补充文档和表格数据,提升模型的OCR和图表理解能力。
  • 扩大LLM尺寸:考虑了7B、13B、34B的LLM。

24年5月,团队发布基于更强LLM的LLaVA-NeXT版本,支持LLaMA3(8B)和Qwen1.5(72B/110B)。更大的LLM提供更好的视觉世界知识和逻辑推理能力,最大的模型接近GPT-4V的性能,同时保证了训练高效性。

LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild

LLaVA-NeXT-Video

24年4月,LLaVA-NeXT-Video发布,展现出强大的zero-shot视频理解能力。LLaVA-NeXT中的高分辨率图像动态划分可以很自然地迁移到视频模态用来表示视频的多帧,使得只在图文模态上训练的LLaVA-NeXT能在视频任务上泛化。此外,推理时的长度泛化用于有效处理超出LLM最大长度的长视频输入。基于LLaVA-NeXT-Image模型,作者发布了在视频数据上监督微调的LLaVA-NeXT-Video,以及在AI反馈的监督下使用DPO偏好对齐的LLaVA-NeXT-Video-DPO。使用SGLang部署和推理,支持可扩展的大规模视频推理。可以想到,这有助于海量视频的高效文本标注,催生了未来更强大视频生成模型。

LLaVA-NeXT: A Strong Zero-shot Video Understanding Model

  • AnyRes:可以将N帧视频看作{1xN}的网格,而LLM的最大长度限制了可以处理的帧数,很自然地会考虑对图像进行下采样减少每帧token数,但作者发现为保证效果仍只能处理16帧。
  • 长度泛化:基于LLM的长度外推技术(RoPE的线性扩展),推理时扩展2倍,从之前的16帧扩展到56帧,大大提升了模型分析长视频序列的能力。
  • 基于LLM反馈的DPO偏好优化:偏好数据由LLM生成,视频表示为详细的说明文字,带来了很大的性能增益。
  • 对于视频数据的微调,作者进行了ablation study:(1) 在LLaVA-NeXT图像级指令微调后,继续在视频级指令上增量微调;(2) 在LLaVA-NeXT图像级预训练后,在图像级和视频级数据联合微调,每个batch数据包含一种类型或者混合两种类型,实验表明混合图像和视频模态数据效果最佳。

指令微调Ablation Study


团队还分享了视觉指令微调过程中除数据之外的因素的ablation study,从模型架构、视觉表征、训练策略角度进行分析。

LLaVA-NeXT: What Else Influences Visual Instruction Tuning Beyond Data?

  • 模型架构:扩展LLM比扩展视觉编码器更有效,视觉输入配置(分辨率、token数)比视觉编码器大小更关键。
    • 学习率:为了训练更稳定,视觉编码器的学习率通常应该比LLM学习率小10倍~5倍,更大的LLM需要更小的学习率,尽量避免loss跑飞。
    • 视觉编码器:相较于模型大小,基于分辨率、token数的视觉特征支持编码更多的视觉细节,预训练数据支持编码更多的视觉知识,作用更重要。
  • 视觉表征:分辨率、特征空间视觉token数都重要,相对来说扩展分辨率更有效,建议使用AnyRes时下采样。
    • 对于更高分辨率图像或者更长的视频,AnyRes需要更多的格子。比如,对于超过768×768的图像,以前的方案首先resize到768×768会导致细节丢失。这里考虑划分成更多的格子,然后对编码的特征进行双线性插值(下采样)到更小的特征,以防止视觉token数过多。
  • 训练策略:在互联网级低质数据上大规模预训练后,指令微调前,增加一个阶段,使用一些高质量合成数据增强知识。

LLaVA-NeXT-Interleave

24年6月,LLaVA-NeXT-Interleave发布,提出图文交错格式可以作为通用模版统一不同的视觉模态,比如单图像(multi-patch)、多图像(multi-image)、视频(multi-frame)、3D(multi-view)。在保证LLaVA-NeXT单图像输入的性能下,可以提高其它模态任务的性能,而且在不同模态任务上具有初步的迁移能力。这种大一统的模型支持更广泛真实场景的应用,比如多页PPT的总结和问答、生成图像编辑的提示词、多文档的汇总和比较。

LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models

作者在训练策略上进行了ablation study:

  • 从LLaVA-NeXT单图像模型继续训练,从stage2单图像指令微调后的模型开始训练效果更好,可以继承单图像任务的指令遵循能力。
  • 两种组织格式:将所有图像token放在最前面,在文本中使用特殊token指代图像 (in-the-front),将图像token放在其原来的位置,与文本交错 (interleaved)。实验表明,在训练阶段混合两种格式有助于在推理阶段这两种格式都取得更好的性能。

InternVL系列

InternVL-1.0

23年12月,上海AI Lab @OpenGVLab发布InternVL。该工作在模态对齐中视觉编码器和LLM之间在参数规模和特征表征能力上存在较大的差距,自然地提出扩大视觉端的参数量到6B (InternViT-6B),然后使用不同质量的图文数据逐渐与LLM对齐。此外,连接层的参数量也扩大了,类似Q-Former,这里设计了一个8B的语言中间件QLLaMA,使用Chinese-LLaMA的参数初始化增强其跨语言理解能力,新增96个可学习query token和cross-attention层 (1B),实现视觉和语言模态进一步对齐。

[2312.14238] InternVL: Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks

下图是InternVL的三阶段渐进式训练策略,训练数据质量逐渐提高,最开始使用大规模有噪的图文对进行对比预训练 (类似CLIP),接着加入冻结参数的QLLaMA连接件,只学习cross-attention,使用图文匹配/对比/生成loss (类似BLIP),最后引入LLM进行监督微调,赋予多模态对话和问答能力。

InternVL训练的多阶段性赋予其内在的多功能性,通过灵活组合不同模块,可以支持各种视觉-语言任务,如下图。

这里值得讨论的一个点在于,InternVL为了让视觉端和语言端参数量平衡,对视觉端和连接层都进行了scale up。一个很自然的问题是,视觉端真的需要这么heavy的参数量吗?因为当前最新的LLaVA-NeXT仍然使用约300M的ViT和轻量的MLP连接层,仅通过扩展LLM提升多模态任务性能。我的个人拙见是,视觉理解包括感知和推理,感知部分可能并不需要那么大的参数量,而推理部分作用于high-level的视觉特征,通过微调LLM赋予其理解推理视觉模态的能力,所以为了性能、效率和稳定性的平衡,似乎这里scale up必要性不是很强,当然这里值得深入实验的验证和讨论。看到这篇论文中的图,让我想到了22年Google的Coca论文,作者把文本解码器按层对半划开,浅层一半用于文本单模态,深层一半用于图文多模态,可以看到下图视觉端参数量占比也相当高。

[2205.01917] CoCa: Contrastive Captioners are Image-Text Foundation Models

InternVL-1.5

24年4月,InternVL-1.5发布,综合性能更强,且支持推理时高达4K的分辨率。

[2404.16821] How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites

上图为模型整体架构,采用了类LLaVA的ViT+MLP+LLM范式,结合了增强的InternViT-6B-448px-V1.5和中英双语InternLM2-Chat-20B,总体参数约26B。相比于InternVL-1.0,在输入端支持了动态高分辨率,连接层改为轻量的MLP,使用pixel shuffle操作将输出的视觉token数减为1/4。训练分为两阶段,预训练阶段训练InternViT和MLP映射,随后微调整个模型。

  • 这里不再使用Q-Former作为连接层的原因,可以参考作者 @Weiyun 大佬的回答:多模态大语言模型(MLLM)为什么最近的工作中用BLIP2中Q-Former结构的变少了? – Weiyun的回答 – 知乎,大致意思是说相比于MLP,Q-Former参数量大收敛更慢,数据量小的场景无法达到LLaVA-1.5这样的性能,而且提高数据量和计算量,Q-Former也没有明显的性能优势。
  • 这里的pixel shuffle操作来源于16年的一篇论文,本质是对特征元素进行重排列,将 (𝐶×𝑟2,𝐻,𝑊) 的特征变换为 (𝐶,𝐻×𝑟,𝑊×𝑟) ,对特征进行了空间维度的上采样,但通道维度缩小为原来的 1/𝑟2 。这里输出的视觉token数可以理解为通道数,主要目的是通过提升特征维度换取更少的token数,从而可以支持更高的图像分辨率。这样,448×448的输入图像,patch size=14,总共有32×32=1024个token,设置上采样系数r=2,则该图像可以表示为256个token。

接着我们来看InternVL-1.5的三个重要改进:

  • InternViT增强:V1.2版本去掉了模型的最后3层,将分辨率扩展为固定448×448,而V1.5进一步扩展为动态448×448,即每张训练图像可分块,每块大小为448×448,支持1~12个块。此外,还增强了数据规模、质量和多样性,提高了OCR和高分辨率处理能力。
  • 动态高分辨率:基于图像的分辨率和纵横比,将图像切分为448×448的分块,训练阶段最多12块,测试阶段可以外推到40块,即4K分辨率,这样模型训练和推理能适应多种分辨率和纵横比,避免了强行resize带来的失真和细节丢失。如下图,具体来说,对于一张800×1300的图像,从预定义的纵横比中匹配一个最接近的纵横比2:3,然后将图像resize到896×1344,并切分为多个448×448的图像块,再添加一个缩略视图 (直接resize到448×448) 用于图像全局理解。
  • 高质量中英双语数据集:包含自然场景、图表、文档、对话等多样化的数据,借助LLM实现数据集英文到中文的转换。

此外,翻译的prompt值得我们学习:

System:
You are a translator proficient in English and {language}. Your task is to translate the following English text into {language}, focusing on a natural and fluent result that avoids “translationese.” Please consider these points:
1. Keep proper nouns, brands, and geographical names in English.
2. Retain technical terms or jargon in English, but feel free to explain in {language} if necessary.
3. Use {language} idiomatic expressions for English idioms or proverbs to ensure cultural relevance.
4. Ensure quotes or direct speech sound natural in {language}, maintaining the original’s tone.
5. For acronyms, provide the full form in {language} with the English acronym in parentheses.
User:
Text for translation: {text}
Assistant:
{translation results}

作者在ablation study部分研究了更大的LLM是否需要更大的视觉编码器,实际上是针对我们上面对InternVL-1.0视觉端参数量的问题的实验。实验对比了LLaVA-NeXT和InternVL-1.2,两者都使用34B的LLM,在尽量保证对比公平的条件下,实验证明更大的视觉模型能提供模型解决多模态任务的整体性能(不过原论文好像没有给具体数据?)。团队后续也发布了蒸馏版的视觉模型InternViT-300M-448px,与LLaVA-NeXT的视觉端保持了同等规模。

MiniCPM-V系列

MiniCPM-V是 @面壁智能 发布的一系列支持高效端侧部署的多模态LLM。

MiniCPM-V 2.0

24年4月,MiniCPM-V 2.0发布,仅有2.8B参数,整体性能超过了Yi-VL 34B、CogVLM-Chat 17B、Qwen-VL-Chat 10B等更大的开源模型,OCR能力突出,支持中英双语对话,部分指标接近Gemini Pro。
视觉编码器使用SigLIP SO400M/14-384px,LLM使用MiniCPM-2.4B,连接层使用Flamingo中的Perceiver Resampler (类似Q-Former使用可学习query提取显著视觉信息,但不以输入文本为条件)。基于自研的RLHF-V实现可信行为对齐,在缓解多模态幻觉问题上接近GPT-4V。基于自研的LLaVA-UHD支持高达1344×1344的分辨率和任意纵横比输入。基于自研的VisCPM实现跨语言的多模态能力泛化,进而有良好的中英双语能力。此外,该模型在端侧部署内存开销较小、速度较快,即便是处理高分辨率的图像。官方还提供了安卓端部署的mlc-MiniCPM示例。

MiniCPM-Llama3-V 2.5

24年5月,MiniCPM-Llama3-V 2.5发布,总共8B参数,整体性能超过了GPT-4V-1106、Gemini Pro、Qwen-VL-Max、Claude 3等闭源模型,OCR和指令遵循能力进一步增强 (增强了全文本OCR提取、表格到Markdown转换等功能),支持超过30种语言对话,在量化、编译优化、高效推理等加持下,同样可以在端侧高效部署。
在MiniCPM-V 2.0基础上,LLM替换为Llama3-8B-Instruct,基于更新的RLAIF-V进一步降低幻觉率。当前,官方支持了llama.cpp和ollama的高效CPU推理、GGUF 16-bit量化、LoRA微调等实用功能。

VILA1.5

24年5月,NVIDIA发布VILA1.5,提供视频理解能力,开源了3B/8B/13B/40B的模型,位于当前开源榜单MMMU和Video-MME前列。VILA详见我的上篇文章,这里简单回顾一下:VILA在大规模交错图文数据上预训练,从而具有多图理解能力,作者通过实验发现:(1) 图文交错排布比较关键;(2) 交错图文预训练过程中微调LLM能赋予其上下文学习的能力;(3) 混合只有文本的指令数据有助于提升性能;(4) 压缩视觉token可以扩展视频帧数。

CogVLM2

24年5月,智谱 @GLM大模型 发布CogVLM2,随后发布了GLM-4V。CogVLM2基于Llama3-8B-Instruct,支持8K上下文、1344×1344分辨率、中英双语对话。GLM-4V-9B替换为GLM-4-9B语言模型,采取同样的数据和训练策略,去除CogVLM原有的视觉专家,将模型大小减为13B。CogVLM和CogAgent详见我的上篇文章。

Cambrian-1

24年6月,LeCun&谢赛宁团队发布Cambrian-1,关注以视觉为中心的多模态LLM,开源了8B/13B/34B的模型。当前多模态LLM仍存在较大的视觉缺陷,需要增强视觉表征以更好地和语言模态交互,赋予模型在真实场景更强的感知定位能力。这项研究的一大意义在于影响多模态LLM的工作开始重视视觉表征质量的提升,而非一直scale up LLM。

[2406.16860] Cambrian-1: A Fully Open, Vision-Centric Exploration of Multimodal LLMs

如上图,该工作围绕多模态LLM的5个核心设计要素展开研究,分别是:视觉表征、连接器设计、指令微调数据、指令微调策略、评估基准。

  1. 视觉表征

作者评估了多种视觉编码器及其组合,下图表明以语言监督的CLIP模型优势较强,但自监督方法在提供充足数据和适当微调的情况下性能也能接近。而且,结合多种类型的视觉编码器有助于提升多模态LLM的性能,尤其是以视觉为中心的任务。注意到,高分辨率的编码器大大增强了图表和以视觉为中心任务的性能,而基于ConvNet的架构适合处理这类任务。

2. 连接器设计

提出Spatial Vision Aggregator (SVA),一个动态的、具备空间感知的连接器,以将 (来自多个视觉编码器的) 视觉特征与LLM深度融合。如下图,该方法设置一些可学习的latent query tokens,通过cross-attention与多个视觉特征交互 (视觉特征作为key/value)。SVA的设计有两点要素:(1) 通过显式定义每个query token对应的视觉特征图子区域,引入空间inductive bias,便于模型在处理视觉信息时保留对空间结构的理解,更准确地定位和整合局部特征;(2) 在LLM的多层聚合视觉特征,让模型在不同层级特征上反复利用视觉信息,增强模型对视觉内容的深入推理能力。该方法可以有效减少需要的视觉token数,例如相比于Mini-Gemini和LLaVA-NeXT,Cambrian-1的视觉token数是其20%。

3. 指令微调数据

作者发布了指令微调数据集Cambrian-10M,综合了OCR、通用VQA、纯语言等指令数据,还筛选了质量更高的7M版本。不同类型的视觉指令数据能赋予模型不同的能力,因此数据配比的平衡性也很关键,实验结果表明,平衡OCR、通用数据和语言数据的比例很重要。此外,在实验中作者发现,训练好的多模态LLM可能在基准测试上指标表现好,但实际对话能力弱,回复简短。因此,作者在训练期间引入了额外的系统提示,鼓励模型输出更长的回答和思维链推理,增强数学推理等任务的表现。

4. 指令微调策略

作者遵循LLaVA的两阶段训练策略,先使用适配数据只微调中间的MLP连接层,再打开LLM和连接器微调。结果表明,第一阶段对连接器的预训练可以提高性能,而使用更多的适配数据可以进一步增强。此外,作者对比了是否微调视觉编码器带来的性能影响,表明微调视觉编码器能增强性能,尤其对自监督预训练的视觉编码器 (如DINO v2、MoCo v3、MAE等),在以视觉为中心的测试上提升明显。

5. 以视觉为中心的基准CV-Bench

现有多数benchmark无法正确评估模型的视觉感知定位能力,而且相应的样本数量有限。CV-Bench重新利用现有视觉benchmark中的样本,包含2638个以视觉为中心的VQA问题,涉及2D的空间位置关系和物体计数、3D的深度次序和相对距离。


最后,让我们共同期待我国的AGI基础模型不断取得新的突破,引领世界潮流!

KAN网络-MLP网络的替代

论文: https://arxiv.org/abs/2404.19756

一、MLP 本质回顾

MLP 本质上是用一个线性模型外面包了一层非线性激活函数来实现非线性空间变换。线性模型的好处在于简单,每条边就是两个参数w和b,合到一起用向量矩阵表示 W。

比如下面的图,通过结合两个线性决策边界并在第二层应用激活函数,形成了一个非线性决策边界。两个线性决策边界,每一个由一条直线表示,分别是2x+3y+6=0 和5x+3y=0。这些直线分别在二维空间中划分出了不同的区域。第一层的输出通常会作为第二层的输入。然后再次通过激活函数进行处理。形成了一个复杂的、非线性的曲线形状的决策边界。这是因为每个线性模型的输出都受到了非线性激活函数的影响,允许模型捕捉更复杂的数据模式。多层结构和非线性激活函数使得网络能够将简单的线性决策边界通过复合和变换,转化为能够解决更复杂分类问题的非线性决策边界。

从理论上讲,一个包含足够多神经元的单隐藏层网络可以逼近任何连续函数(这是由通用近似定理保证的)。比如下图所示分类界面变成了曲面。

在 MLP 中,每层都进行线性变换后跟非线性操作。这种层级结构允许模型学习数据的多层次特征表示。随着层数的增加,模型的表示能力也随之增强。比如下图,这就是深度学习管用的根本原因,越深越牛逼。

表达成线性代数的矩阵形式,就是下面的公式:

在这个表达式中,圆圈(·)表示的是函数组合运算,也就是函数复合的意思。用图表示就是我们常见的形式:

注意:这里所有的激活函数都尽量使用相同的进行简化。当然有的时候用两个

你可能会问,为什么神经网络一定要用 MLP的形式?没有什么为什么,因为它管用,它简单,奥卡姆剃刀原理。适合用来理解基本的前向传播和反向传播算法,也易于用各种编程语言和框架实现。因此,历史上就选它了,整个深度学习的基础。

二、MLP 的硬伤注定了深度学习大厦的脆弱?

MLP 也就是全连接网络可以说是整个深度学习的基础,后面所有的网络无论CNN/RNN/transformer 都是在它基础上的修改,即便是现在吹牛逼不上税的大模型们。但谁能想到它居然是个有天然硬伤的豆腐渣模块呢?

1.梯度消失和梯度爆炸:当使用传统的激活函数(如Sigmoid或Tanh)时,MLP 在进行反向传播计算梯度时确实容易遇到梯度消失或梯度爆炸的问题,会出现激活函数的导数连乘积(画图)。当它非常小或非常大,网络又很深,连续乘积会使得梯度趋向于0(梯度消失)或变得异常大(梯度爆炸),从而阻碍学习过程。

2.参数效率低:MLP 通常使用全连接层,这意味着每层的每个神经元都与前一层的所有神经元相连接,导致参数数量迅速增加,尤其是对于输入维度很高的数据(比如图像数据)。这不仅增加了计算负担,也增加了模型过拟合的风险。这就是大模型的困境,拼参数量没出路,大部分学习都是浪费掉的,效率巨低下无比,好比人海战术。

3.处理高维数据的能力有限:MLP没有利用数据的内在结构(如图像中的局部空间相关性或文本数据的序列信息)。例如,在图像处理中,MLP 无法有效地利用像素之间的局部空间关联,这使得其在图像识别等任务上的性能不如卷积神经网络(CNN)。

4、长期依赖问题:虽然 MLP 理论上可以逼近任何函数,但在实际应用中,它们很难捕捉到输入序列中的长期依赖关系(长时间跨度的相关信息)。这一点在处理时间序列或自然语言处理任务时尤为明显,而循环神经网络(RNN)和transformer在这些任务中通常表现得更好。

但无论 CNN/RNN/transformer 怎么改进,都躲不掉 MLP 这个基础模型根上的硬伤就是这个线性组合+激活函数的模式。进而决定了整个深度学习大厦的脆弱。就好比板砖出现了问题。那能不能替换掉这种板砖呢?谈何容易,既要解决函数拟合准确度的问题,又要保证神经网络的效率,这不亚于重新发明深度学习这个学科。因此虽然理论上任何结构都可以,但并没有出现一种更好的基础模型组件。这恰恰是 KAN 网络带给大家的惊喜,也给发展了十几年沉闷的深度学习世界带来了一丝变革的曙光。我们来看它具体干了什么?

三、KAN 网络为什么牛逼?

Kolmogorov-Arnold Networks 顾名思义基于柯尔莫果洛夫-阿诺尔德表示定理。是由这两个俄罗斯数学家 1957 年提出的如何用一组较简单的函数来表示任何一个多变量的连续函数。

想象一下,你有一个非常复杂的配方,需要各种各样的原料和步骤来制作一道菜。柯尔莫果洛夫·阿诺尔德表示定理告诉我们,无论这个配方多么复杂,我们总能找到一种方法,通过一些简单的基本步骤(这里是一些基本的函数)来重现这道菜的味道。在上面的式子中,输入是x,φq.p(xp)是基本的一元函数,就像是青椒西红柿基本原料的处理。内层求和就是放到一起。中q是外层的函数,各自接受内层求和的结果作为输入。外层的求和 ∑表示整个函数 fx) 是子函数中q 的和。用图来表示就相当于一个两层的神经网络,区别在于一没了线性组合,而是直接对输入进行激活;二来这些激活函数不是固定的,而是可以学习的。

和 MLP 每层统一进行非线性空间变换相比,这相当于对每个坐标轴单独进行非线性变换,然后再组合形成多维度空间。(画个简图,先组合再变形和先单个变形再简单组合的区别)

公式写成向量的形式就是:

对比 MLP,没有了激活函数和参数矩阵的嵌套关系,而直接是非线性函数中的嵌套。对于多层网络,这相当于下面的结构:

注意这里所有的非线性函数中都采用同样的函数结构,只是用不同的参数来控制其形状。具体来说,文章选择了数值分析中的样条函数 spline。这个英语单词 spline 来源于可变形的样条工具,那是一种在造船和工程制图时用来画出光滑形状的工具。

对比 MLP 和 KAN,最大的区别就是变固定的非线性激活+线性参数学习为直接对参数化的非线性激活函数的学习。因为参数本身的复杂度,显然单个spline 函数的学习难度要比线性函数难,但 KANS通常允许比 MLPs 更小的计算图,也就是实现同样效果,需要的网络规模更小。例如,文章展示了在解偏微分方程(PDE)的过程中,一个2层宽度为10的 KAN 比一个4层宽度为 100 的 MLP 具有更高的准确度(均方误差 10^-7 对比10^-5)并且具有更高的参数效率(参数数量100 对比 10000)。
到这里为止,你一定好奇,这IDEA不复杂啊,难道以前没人想到,有,但是卡壳在都坚持使用原始的二层宽度为(2n+1)的表示方法,并没有机会利用更现代的技术(例如,反向传播)来训练网络。KAN 模型的贡献就在于通过进一步简化推广到任意宽度和深度,同时通过广泛的实证实验论证了在 AI + 科学方面的效果,而且具备很好的准确性和可解释性。这就牛通了啊,深度学习最大的问题就是个黑盒子,训练网络像是炼丹。大模型越弄越大,很可能一条道走到黑就进死胡同了。好比芯片的摩尔定律。现在出现了量子芯片,原理上就不同,从而有可能实现根本性的变革。当然,原来的各种网络结构还能平替重做一遍,有没有感觉一片 AI 新大陆向你招手了。我一直劝大家别太短视,成天只盯着transformer,大模型兜兜转转,撑死了也是井中之蛙。

四、KAN 的架构细节

4.1 详细解释

整个网络架构原理看图一目了然。很多个这种类似四分之三个周期的正弦函数组合起来就能拟合任意形状的函数。换句话说,用 B-spline 这一种激活函数两次求和就够了。

图中展示的结构中,使用了两种尺度或分辨率的组合:粗粒度和细粒度网格,在保持计算效率的同时,更加精确地捕捉和适应函数的变化。这种基础结构其实并不是很难想到,以前就有了,但难点是怎么把它变深,否则单靠这么点玩意儿是不能逼近复杂函数的。这就是本文的主要贡献了。

这里I是层编号,右边为输入,左边为输出。看上面左图就大致明白对应关系,输入为2 个,因此第二层是 2*2+1=5个。中j 就是每条边上的激活函数,也就是非线性变换。相当于每个x都有5个分身,然后再分别组合。其中i用来标记当前层的节点而j用来标记下一层的节点。每个节点 x_i 的输出通过激活函数 φ_l,ij处理后,贡献到所有下一层的 x_l+1,的计算中。对应上面左图,输入层2个节点,第二层5个节点,因此矩阵为 5*2。矩阵的第一列表示x_0,1对应的5个激活函数,第二列对应x_0,2的,然后两两组合。

因此,这里需要强调的是 KAN 网络层节点数不是随便搞的,由输入节点个数确定2n+1个,然后所需要的参数或者连接数为(2n+1)*n,明显比全连接少了不少,看图就知道。

SiLU(Sigmoid Linear Unit)是一种神经网络激活函数,也被称为 Swish 函数。这个函数由一篇 Google Brain 的论文首次提出,并因其在某些任务上表现出的优异性能而受到关注。你可以认为它就是 sigmoid 函数的一种变体。
2.假设层宽相等,L层,每层 N 个节点。
2.每个样条函数的阶数通常为k=3,在G个区间上G +1个网格点。”G 个区间”指的是样条函数的分段定义的区间数。
那么总共大约有 O(N’L(G +k))或O(N2LG)个参数。相比之下,具有深度工 和宽度 N 的多层感知机(MLP)只需要O(N2L)个参数,这看起来比KAN更有效率也就是说单看计算复杂度好像 KAN 还不如 MLP 简单,但是幸运的是,KANS 通常需要比 MLPS 小得多的 NN,这不仅节省了参数,而且还提高了泛化能力,并且有助于解释性。

换句话理解,就是借助 spline 样条函数的表达能力,无需很多节点就能实现比较强的表达能力,因此总的来说,可以比 MLP 节省不少参数量。

4.2 逼近能力和缩放定律的讨论

文章花了一页的篇幅推导证明了定理

这部分讲的不是人能听懂的话,看不懂很正常。简单说,就是从数学上证明可以通过构建多层的 B样条函数网络来有效逼近复杂函数。尽管增加网络的深度和复杂度,KANS能够通过细致的网格划分来逼近高维函数,而不会受到维数灾难的影响,也就是在高维空间中,数据的稀疏性和处理复杂度急剧增加的问题。而残差率不依赖于维度,因此战胜了维数灾难!

再来看看所谓的缩放定律。注意这里的缩放定律与大模型领域的不同。后者是说模型大小(如参数数量)的增加,模型的性能(例如在语言任务中的准确性)通常会提高,并且有时这种提升的速度可以用某些数学关系(如幂律关系)来描述C=6ND。这里更偏重于理论和数学上的分析,当然背景相似,都是讨论随着参数数量的增加,模型表现的提升。这部分内容基本上也可以暂时略过,主要就是简要对比了几种理论关注于如何通过理论来指导实际的神经网络设计,以实现更有效的学习和泛化能力。后面还有讨论,这里暂时可以忽略。

好,我们接下来重点看看 KAN 准确性和可解释性的改进

4.3 如何提升准确性?

MLPs通过增加模型的宽度和深度可以提高性能,但这种方法效率低下,因为需要独立地训练不同大小的模型。KANS:开始可以用较少的参数训练,然后通过简单地细化其样条网格来增加参数,无需重新训练整个模型。
基本原理就是通过将样条函数(splines)旧的粗网格转换为更细的网格,并对应地调整参数,无需从头开始训练就能扩展现有的 KAN 模型。这种技术称为“网格扩展”(grid extension)

文章用了一个小例子来证明这一点。用 KAN 网络逼近一个函数。上图中横轴的每个grid-x“标签代表了在特定训练步骤时进行网格细化的时点。每次这样的标记出现,都意味着网格点数量在这个步骤有所增加,从而使模型能够更细致地逼近目标函数,这通常会导致误差的下降。表明网格点的增加直接影响了模型的学习效果,提高了逼近目标函数的精度。左右图表示了两种不同结构的网络。

下面两个图分别展示了测试误差随网格大小变化(左下图)和训练时间随网格大小的变化(右下图)。结论就是误差loss 随网格大小 grid size G在不同的规模上显示出不同的缩放关系;训练时间随网格大小增加而增长,特别是在网格非常大时(接近1000),训练时间急剧上升。

这些观测结果支持了文章中关于 KANS 利用网格扩展可以有效提高精度而无需重新训练整个模型的说法,同时也提示了在选择网格大小时可能需要在模型精度和训练效率之间做出权衡。简单说,网格太密了也不好,太费时。

4.4 如何提升可解释性?

尽管上面介绍了 KAN 的不少好处,但遇到实际问题时该怎么设计网络结构依然是个玄学。因此需要有种方法能自动发现这种结构。本文提出的方法是使用稀疏正则化和剪枝技术从较大的 KAN 开始训练,剪枝后的 KAN 比未剪枝的 KAN 更易解释。为了使 KAN 达到最大的可解释性,本文提出了几种简化技术,并提供了一个示例,说明用户如何与KAN 进行交互以增强可解释性。

1.稀疏化:使用数据集训练一个 KAN 模型,使其能够尽可能地拟合目标函数。MLP通常使用 L1 正则化来促进权重的稀疏性,L1正则化倾向于推动权重值向零收缩特别是那些对模型输出影响不大的权重。权重短阵的“稀疏化“可以降低模型的复杂性,减少模型的存储需求和计算负担,因为只需要处理非零权重;还能提高模型的泛化能力,减少过拟合的风险。

2.剪枝:在稀疏化后,进一步通过剪枝技术移除那些不重要的连接和神经元。设定特定激活函数:根据剪枝后各神经元的特性,手动设置或调整特定神经元的激活函数

3.训练仿射参数:在调整了激活函数后,对模型中的剩余参数进行再次训练,优化这
些参数以最好地拟合数据。
4.符号化:最终,模型将输出一个符号公式,这个公式是对原始目标函数的一个近似表示,但通常会更简洁、更易于理解和分析。

五、实验验证

5.1 KAN准确性

比较了 KAN 与 MLP 在逼近5个典型函数上的性能,横轴是参数量,纵轴为均方根误差(RMSE)。总的来说,KAN和MLP随着参数数量的增加,RMSE都在下降。
在大多数情况下,KAN(蓝色线)比相同深度的MLP具有更低的 RMSE,尤其是在参数数量较少时。这表明 KANS 在参数利用效率上可能更高。MLP 在参数数量增加后性能提升逐渐放缓并迅速达到平台期,这可能是因为 MLP 对于这些类型的函数拟合存在固有的限制。KAN 在多个测试案例中都接近或跟随理论曲线。
这表明 KANS 在处理复杂函数和高维数据时可能是更优的选择,具有更好的扩展性和效率。这种性能优势特别重要,当我们需要从有限的数据中学习复杂的模式时,如在物理建模、声音处理或图像处理等任务中。当然目前还都是比较理论化的实验数据。

接着对比了 KAN 和 MLP 在高难度的特殊函数拟合任务上的性能,结论类似。随参数量增多KAN(蓝色)表现稳定,越来越好,而MLP(黄色)出现平台期。KANS在维持低误差的同时,表现出更好的参数效率和泛化能力。这一点对于设计高效目精确的机器学习模型来说是极其重要的,特别是在资源受限或对精度要求极高的应用中。

5.2 KAN 的可解释性

借助前面提升模型可解释性的小技巧,包括稀疏化、剪枝等,KAN 网络最终形成的网络结构不仅能够实现数学函数的拟合,而且其形式本身能反映出被拟合函数的内在结构。

以第一个图为例
函数:f(T,y)= xy
解释:图中的结构利用了恒等式 2xy =( X+ y)?-x2- y2 来计算乘法。这说明KAN通过结合基本运算(加法、平方)来实现复杂的乘法操作,展示了KAN如何通过基本的数学操作构造更复杂的函数。
x和y各自经过线性函数求和,然后平方,同时再减去x和y的平方。
因此可以看出,KAN 模型的牛逼之处在于两点:首先,不仅仅在于自身的型结构,MLP是先组合再非线性激活,KAN 是先非线性激活再组合;其次,KAN的训练能实现自身结构上的优化,有点自组织的味道了。

八、小结

1.MLP 的硬伤:我们回归了 MLP的核心原理,线性组合+非线性激活。深层次化网络后,反向传播求导数时单一激活函数的连乘积会产生很多问题,而且全连接网络导致参数效率低下。
2.KAN 的原理:用单一架构的参数化可学习非线性激活函数直接组合,实现非线性空间变换。模型表征能力大大提升。
3.KAN 训练算法:通过 grid extension,也就是激活函数分辨率提升,以及稀疏化、剪枝等结构自优化技巧,实现了准确性和可解释性的提升。能够在参数量大大减少的情况下实现相同甚至更有的拟合效果。
4.实验验证:仿真实验提供了有效的量化的效果证明,展示了非常有前景的方向,但目前显然还比较初级。不过,提供了一条新的道路。

Mamba-2 模型解读

Github:https://github.com/state-spaces/mamba

论文:https://arxiv.org/abs/2405.21060

Mamba2 模型再次回归,引发 AI界新的雀跃。它重要性在于很可能开启了一个新的时代,注意力机制2.0时代,由单一注意力机制变成混合注意力机制。人们苦Transformer 这个大语言模型核心模块的硬伤久矣(传统自注意力机制运算效率低),初代的 Mamba 理论上仍然不够完善,所以才会被人诟病。不过,这种争论充分说明了 LLM核心架构的变革已经势在必行。

标题剑指“Transformers 就是 SSM 状态空间模型”,俨然是要做个大一统的工作,气吞山河Generalized models and efficient algorithms 叫法感人,可以想象后续工作会一片一片。

一、重要结论

这次的摘要非常简短,但字少事更大,人狠话不多。三句话三层意思:
1.尽管 Transformers 多年来一直是深度学习在语言建模领域成功的主要架构,但近年来,状态空间模型(SSMs)如 Mamba 已被证明在小到中等规模上能够匹敌甚至超越。人话就是SSM 是新的政治正确!
2.本文展示这些横型家族实际上是密切相关的,可以在一个称为“状态空间对偶“理论框架下连接 SSM 和注意力变体。选择性 SSM 就是一种新的注意力机制,看,这么快就应验了,当然人家牛的是数学上证明了。别死盯着传统的 Transformer 架构了,它的内核其实还是注意力。就像雷达一样,要搞新体制雷达。新型注意力才是正途!但方法并非从 Transformer 侧搞。
3.据此设计的 Mamba-2速度比 Transformers快 2-8倍,而准确率更优。算力昂贵的今天效率将是新模型竞争的重点!

理解整篇文章的核心其实就是这个对偶 duality,我们将围绕着它展开,明白了它也就是彻底搞懂了 Mamba2 的精髓。

二、统一 SSM 和注意力机制(两性话题)


2.1 什么是对偶关系

在数学、物理学乃至哲学中,“对偶性”是指两种看似不同的理论或模型之间存在的一种深层次的等价关系。通过这种对偶关系,可以将一个复杂的问题转化为另一个相对简单的问题来解决,或者在一种表示形式下无法轻易看到的性质在另一种表示形式下变得显而易见。比如太极图就是典型的对偶关系,阴阳对应。

本文提出的状态空间对偶一边是结构化状态空间模型(SSMS),一边是注意力变体,关联方式也就是分界线是具有次平方参数和乘法复杂性的结构化矩阵

这个理论的基础其实是线性注意力(LA)框架,它也是一种对偶:一边是线性递归神经网络(RNN),一边是传统的自回归注意力机制。也就是刚说的Transformeris RNN.

核心思路就是传统的自回归注意力机制在处理长序列时复杂度较高,而通过对偶关系,可以将其转化为线性 RNN 来处理,从而显著降低计算复杂度。这就好比一个忙碌的咖啡馆,顾客们不断地排队下单。自回归注意力机制传统方法中,咖啡师逐个处理每个订单,计算复杂度随着订单数量增加而成平方增长。通过对偶关系,我们可以转换成一种更高效的处理方法:先收集一批订单根据订单类型(拿铁、卡布等)分类,然后批量制作再分发,这就是线性 RNN 处理长序列数据的思路。

简单总结一句话:阳谋不行咱们来阴谋啊,哪个管用来哪个,效果类似能达到目的就行

既然是在 SSM 和注意力之间进行对偶,我们分别回归一下它们都是啥。

2.2 时空模型 SSM 的本质

2.1洋洋酒酒公式不少,看起来费劲,其实就是在讲下面这幅图,上上期我们分析过。本质上左边就是一个简单的线性时不变系统建模,中间是离散化后的模型就是个RNN,最右边是并行化用卷积核进行处理,也就是 CNN 化的模型。这种表示方式是用图模型来建模,强调的是序列数据之间的依赖关系和动态变化。所谓的 SSM 其实可以理解为就是 RNN,只不过更强调通过线性代数方程来描述系统状态的变化,利用状态空间模型中的状态转移矩阵和观测矩阵来进行建模。也就是下面的两组公式。

前者 ABC 都是时不变的,后者变成了随时间变化的,也就是 Mamba 模型的工作,放宽了对系数矩阵的约束。因此某种程度上说,SSM 就是线性代数版的RNN,借助线代这个数学工具来更深入的分析 RNN。差分方程进行时序动态建模,CNN能并行化处理,而 ABC 系数矩阵能实现特征空间的结构化分析。

2.3 注意力机制的本质

再来看看注意力机制。本意是给序列中每个位置的元素分配分数,使每个元素能够“关注”序列中的其他元素。前最常见和重要的注意力机制变体是 softmax 自注意力
Y = softmax(QKT).V
其中 QK 成对比较的机制引发了注意力机制的平方训练成本。你看它的本质上就是矩阵运算。这与 SSM(结构化状态空间横型)是一致的,两者在线性代数层面上有很强的关联性,统一分析和优化的视角可以帮助我们更好地理解它们的内部机制,并发现新的改进方法。
以前学深度学习,坦率说 CS 占优势,需要的数学知识不多,代码实现能力强就可以了。但现在的趋势明显不同了,对线性代数等熟悉理论知识的功力要求越来越高,自动化或者应用数学背景的同学优势更加明显。

2.3 怎么实现二者的对偶

例如,Toeplitz 矩阵是指每条对角线上的元素都相同的短阵。Cauchy 矩阵是指每个元素都由两个向量的元素之间的差的倒数来定义的矩阵。Vandermonde 矩阵是由一个向量的幂组成的矩阵。低秩矩阵是指其秩远小于其行或列数的矩阵。类似的特殊结构矩阵,通过压缩表示可以用更少参数和更快算法计算,减少存储需求,加快运算速度,在大规模数据处理和机器学习中尤为重要。SSM 本质上也是一种结构化矩阵。讲到这你明白了吧,这文章就是玩儿线性代数,在注意力机制和 SSM 之间建立统一的对偶关联关系。

那到底是怎么用所谓的结构化矩阵让二者勾连的呢?其实也很简单就是在 SSM 的计算中,特别是矩阵 A,引入了类似注意力机制的公式和方法。

具体来说:

1.简化A矩阵的结构,使其可以用标量乘以单位矩阵表示,从而减少计算复杂度

2.类似 Transformer 中多头注意力的概念,增加了头维度(P)以增强模型的表达能力。

3.使用类似注意力的对偶形式,去除了softmax,并引入了一个额外的掩码矩阵L,根据数据生成,控制信息在时间上的传递量。

圆圈表示元素相乘,也就是哈德马积。右边的式子有点难懂,给你写成矩阵形式好理解,假设a =[a1, a2, a3]这个 mask 其实就是个下三角矩阵:

还是看不懂是吧,行标是i,列标是j,i<j 的部分全是零,意味着只考虑时间上早于或同一时间点的元素之间的关系。换句话说,它是一种类似 GPT 模型中的单向注意力机制,只考虑过去的时间步,而不考虑未来的时间步。通过这种下三角矩阵,可以有效地控制信息在时间上的流动,确保信息只能从过去传递到现在,而不能反向传播。

很多人到这里可能会很困惑?为什么这就是对偶啊?明明就是在 A中应用了类似注意力的计算方式嘛!我们打两个比方加深你的理解。如同太极图中的阴阳互补相互转换一样,无论是注意力机制还是 SSM,它们的核心都是处理和更新信息状态,只是方式不同。SSM 像个男人强调逻辑和顺序,注意力机制像个女人本强调细节和关联,对偶就像是二者的结婚,你中有我我中有你,到底是谁干谁也说不清。站在SSM 视角,内嵌了注意力,站在注意力机制的视角,这是新型计算方式。它不是一种简单的内嵌,而是有着数学和功能上的深层联系,这就牛逼了。正如你能说生活不需要技巧,随随便便就能登峰造极吗?

到此,你已经 get了本文的核心思想,来看看它的组织结构。

这张图乍一看是一脸惜逼的,不过借用刚才打的比方让你秒懂。SSM 是个男的,Attention 是个女的,SSD 是二者的对偶就是性生活,这种结合不仅仅是表面的叠加,而是深层次的融合。那我问你 Mamba-2是啥?造人生娃啊!通过深层次结合,达到最佳的计算效果和性能。上边的 Structured Matrices 就是爱爱的结品,受孕的胚胎,叫结构化矩阵,它能带来更加高效的算法。

到目前为止,我们讲完了原文1和2部分的内容,接下来正如上图所示,分别详细讲述几条边。首先是 SSM 如何生成结构化短阵,也就是这个所谓的半可分矩阵。

三、SSM 矩阵的巧妙设计(MAN)

半可分矩阵对应原文第3部分,充斥大量线性代数定理推导,读起来基础不扎实那是相当的费劲。人话总结归纳,其实说穿了就是两层意思:一是 SSM 可以表示为 y=Mx 的形式,其中 M 是 ABC 的表达式:二是 M 具有专门设计的半可分结构,能简化运算。

3.1 SSM的表达式

原来的 RNN 也好,时序系统也好,可以卷积化,写成y=Mx的形式。这里的M 是什么就有讲
头了。

3.2 顺序半可分矩阵 SSS

对矩阵 M 专门设计,可以实现更高效的计算。看图一目了然

首先是序列化的,其次是下三角的,第三是低秩的。顺序半可分矩阵(sss,SequentiallySemiseparable Structure)的原因是:半指主要关注下三角部分,可分指的是每个蓝色小块的秩较小,不超过N,意味着可以用更少的独立成分表示,从而实现高效计算。 y=SSS(A, B,C)·x
而定理 3.5指出,人话状态空间模型 SSM,如果状态的维度为 N,等价于一个秩为 N 的SSS

绕来绕去,就是说任何 SSM 其实都可以转写成一个等价的局部下对角阵M 的形式。

SSM 是一个整体的框架,用于处理输入x并生成输出y,如果把它类比为一个男人,那么 M短阵在其中起到核心作用。提升了计算效率。
比如来个更特别的,让秩 N=1,就是 M=1SS

归纳起来,到目前为止,和 Mamba 初代的区别在于两个:

一是在 A 矩阵的计算中嵌入了注意力公式;

二是让 M 矩阵设计为顺序半可分的形式;

说到底,都是在设计更为机巧特殊的矩阵结构,也就是我们先前所说的阀门结构,从而能更好的控制记忆的流淌融合,也是一种新型注意力机制的体现。这里没有说清楚A和M之间的关系,其实是A设计成了半可分,也就是下对角阵的形式,然后M继承了过来,而BC 没变

3.3 张量收缩下的 SSM 计算

这部分写的堪称混乱,不过说穿了就一句话:SSS 的计算过程可以被看作是一系列张量收缩操作,借助顺序半可分矩阵的特殊结构能实现高效计算。理论上,所有 SSM 的计算都可以通过这种方式优化,从而在处理大规模数据时显著提高计算效率。具体来说,可以分解为三个步球:

我们知道,张虽收缩是矩阵乘法的扩展,允许处理更高维的张量,并进行复杂的维度变换和求和操作。使用诸如 NumPy 的 einsum 函数(爱因斯坦求和约定),可以方便地实现这种操作。如下图所示

三个步骤第一步将输入矩阵X 与矩阵 B 进行结合,以产生一个中间结果 Z。矩阵A没有出现,它体现在第二步因状态更新中,L的定义依赖于 A。第三步是最终输出。

讲完了 SSM 并行化时矩阵 M的形式进行了哪些机巧设计,以及张量收缩视角下的 SSM 计算。再来看看注意力机制怎么也能统一到同样的框架下。

四、注意力机制的通用实现(WOMAN)

这部分原文写的比较啰嗦,让人有点晕头转向。梗直哥把其表达的意思用人话翻译帮你理解

4.1 张量收缩视角下注意力计算

简单说,就是换一种视角,把不同注意力机制的计算一般化,用更加通用的形式来描述,具体来说,就是站到张虽收缩的视角来看,纳入一个统一的框架
比如,常见的注意力计算形式如下:
Q=input (T,N)
K=input (S,N)
V =input(S,P)
G =QKT(T, S)
M=f(G)(T,S)
Y= GV(T,P)
这里的S 和 T 代表源序列和目标序列的长度,N 代表特征维度,P 代表头部维度,最常见的 softmax 这里用f来更一般化的表示。自注意力机制比较简单,就是源序列和目标序列相同S=T,特征维度和头部维度相同 N=P。4.1.1-4.1.3:主要是回顾和总结已有的注意力机制和方法,没有啥新内容。

4.1.4 回顾了矩阵 A 中嵌入注意力机制的掩码注意力
y=(L°(QKT))·V
所有的这些计算方式都可以写成张量收缩的形式,也就是矩阵乘法的扩展,允许处理更高维的张量,并进行复杂的维度变换和求和操作。使用诸如 NumPy 的 einsum 函数(爱因斯坦求和约定),可以方便地实现这种操作。
Y = contract(TN, SN, SP, TS → TP)(Q, K, V, L)

其中还可以进一步拆解成多步收缩,可以更高效地实现注意力机制,提升计算效率。
G= contract(TN, SN → TS)(Q,K)
M = contract(Ts,TS → TS)(G, L)
Y= contract(TS, SP → TP)(M, V)
(T,S)(T,S)(T,P)
也就是先计算相似性矩阵 G=QK^T,应用掩码矩阵L后得到新的相似性矩阵

然后再计算最终的输出 Y=MV,后面对应的是相应的维度。归纳来说,用张量收缩来实现,使得注意力机制的计算过程更加清晰和高效。

4.2 线性注意力

线性注意力及其他许多高效注意力变种,通常通过改变矩阵关联的顺序来实现,例如(OKT)V=Q(KTV).
具体来说,它将标准注意力中的复杂计算简化为累加和操作,从而提高了计算效率。
Y=Q·cumsum(KTV)
同时可以证明它一样可以用张量收缩表达

4.3 结构化掩码注意力

结构化掩码注意力是结合了结构化矩阵和掩码注意力,更加高效的计算方式,用张量收缩来实现:
Z= contract(SP,SN → SPN)(V,K) (S,P,N)
H = contract(TS, SPN → TPN)(L,Z) (T,P,N)
Y = contract(TN, TPN → TP)(Q, H) (T,P)

第一步为扩展操作V和K运算,然后是线性部分计算(L,Z),然后是收缩操作。
如下图所示不同的掩码矩阵(如因果掩码、衰减掩码、Toeplitz 矩阵等)L定义了不同的序列变换矩阵 M,从而实现不同形式的结构化注意力。

如同前面类比 SSM 为男人,SSS为精子,这里不同的注意力机制就像是女人,而structuredmasked attention,SMA 就像是卵子
对比公式 15 和公式8,陡然之间发现他们一样,说明无论是从状态空间模型(SSM)侧,还是从注意力机制侧来看,都可以统一到张量收缩的视角下进行操作。还是借用人类的比方类比,张虽收缩就如同生殖过程,或者说遗传生物学,对男人女人都是类似的,都遵循着同样的DNA 遗传信息传递。

五、状态空间的对偶性(SEX)

原文第 5部分用一个简单的例子进一步验证了状态空间的对偶性。首先,从 SSM 侧来看,假设Aj是个标量,那么SSM的矩阵M为:

而这正是二次掩码核注意力定义的原始定义。换句话说标量结构下的SSM,通过明确写出M的矩阵形式,然后执行二次矩阵-向虽乘法来计算输出,本质上是在执行与二次掩码核注意力相同的计算步骤。这意味着从计算角度来看,二者是等价的,可以相互转换。在特定情况下,SSM 的计算方式可以视为一种注意力机制,特别是当我们使用标星结构和半可分矩阵时。
反之,从注意力机制侧看,当掩码矩阵 LLL 具有特定的结构时(如因果掩码或1-半可分矩阵),注意力机制的计算可以视为 SSM 的计算。

这种关系从图4看的更加清楚:

核心思想:SSM 和 SMA在很多情况下是等价的,可以通过相同的数学操作实现。这种等价性为理解和实现这些模型提供了统一的视角和方法。无论是在状态空间模型还是注意力机制的应用中,都可以利用这种统一视角来选择最合适的计算方法,从而提高模型的效率和效果。图中显示了一大类的状态空间双重模型(SSD),这些模型捕捉了许多序列模型的特性。特别的,1-半可分 SMA 和标量恒等 SSM 在这一交集中。交集部分确实可以看作是不同模型(SSM 和SMA)结合后的统一体,就像结婚后的一家人,不再分你我。

六、算法示例(造人)

6.1 算法原理

既然 SSM 和注意力机制两种对偶等价,结合起来进行计算效率显然更高,如同男女搭配干活不累,你中有我,我中有你。整体上是SSM,但是通过块分解把大矩阵拆解成小的子矩阵每个小问题特别是低秩块,再用注意力机制计算,利用矩阵乘法上的高效性和并行计算能力使得计算过程更加高效。如下图所示,一个大的M矩阵,分解成9块,其中蓝色块用矩阵乘法。

整个计算过程如下图所示。半可分矩阵 MMM 被分解成多个子矩阵块,包括对角块(Diagonal Block)和低秩块(Low-Rank Block)。前者表示输入到输出的计算,后者被进步分成三类:从输入到状态(Input-State),从状态到状态(State- State)从状态到输出(State – Output)

图的下半部分展示了通过这种块分解方法进行计算的流程。输入序列 X 被分解成多个块,每个块对应图中的一个黑色虚线框。输入块通过低秩块(绿色箭头)和对角块(色箭头)进行计算,得到中间的状态块 H。状态块之间通过低秩块(黄色箭头)进行计算,表示状态间的传递。最终,状态块通过对角块(蓝色箭头)计算得到输出块Y。

通过这种块分解方法,可以将一个大规模矩阵运算问题分解成多个小规模的块级别运算问题,每个块可以独立进行计算。这种方法利用了半可分短阵的低秩特性,减少了计算复杂度,同时提高了并行计算能力,使得计算过程更加硬件友好。

6.2代码实现

https://github.com/state-spaces/mamba

segsum(x): 是一个辅助函数,用于计算分段累加和。ssd(X,A,B,C,block len=64,initial states=None): 这是主函数,用于计算 SSD 模型A,B,C:分别表示状态短阵、扩展短阵和收缩矩阵先对输入张量 X、A、B 和 C 进行重排,将它们重排成块的形式。
1.对每个块内的对角块(Diagonal Block)进行计算,使用 torch.einsum 计算块内的矩阵来法。
2.计算每个块内的低秩块(Low-Rank Block),用于生成下一个块的输入状态。
3.生成块间的状态转移,确保在块边界处的状态正确。
4.对块内的低秩块进行计算,将状态转换为输出。最后将块内和块间的输出汇总,得到最终输出Y和最终状态final state。该代码的主要目的是通过块分解的方法,将一个大规模的状态空间模型问题分解成多个小规模的块级别运算问题。这种方法利用了半可分矩阵的特性,能够提高计算效率和并行性,适合硬件加速。

七、Mamba2 网络架构(KIDS)

7.1 多种设计

Mamba-1基于SSM(状态空间模型)设计,线性投影之后生成SSM参数 A,B,C。Mamba-2的两种块设计:顺序 Mamba块和并行 Mamba块。SSM 层从 A,X,B,C直接映射到输出Y,前者在序列变换中并行生成参数 A,B.C,后者在块开始时并行生成,适合更大规模的并行处理。这种方法类似于标准注意力架构中的并行生成 Q.K,V。这种设计减少参数数量,适合更大模型的张量并行计算。每个 Mamba 块中增加额外归一化层,改善模型稳定性,尤其是大模型。总体来说,Mamba-2模型通过并行化和增加归一化层来优化原始 Mamba 模型的计算效率和稳定性。

7.2 并行化处理方法

主要分为两种类型:张量并行(TensorParallelism)和序列/上下文并行(Sequence/ContextParallelism)。如下图所示,左边输入和输出投影矩阵分割,并在单个设备上处理。每个SSM头(即 A、B、C、X)到Y 的映射都在单个设备上进行。最终归一化层选择 GroupNorm,以避免额外的通信。右图将序列维度上的计算分配到多个设备上,每个设备负责一部分序列的计算,然后将结果传递给下一个设备。

八、实验验证

8.1合成记忆任务

上图展示了不同模型在多查询关联记忆(MQAR)任务中的表现。三张子图对应不同的序列长度(256、512、1024)。横轴代表模型的维度(32、64、128、256)纵轴代表准确率(从0到 1)Mamba-2 系列模型在较大的模型维度下表现优异,特别是当维度达到 128 和 256 时,其准确率接近 1.0。Mamba-2 模型明显优于 Mamba-1 和普通注意力模型,尤其在更大的状态规模(N=256)下表现尤为显著。

8.2语言模型预训练与评估

(缩放定律)在 The Pile 上进行训练的模型,Mamba-2的性能匹配或超过了 Mamba 和强大的“Transformer++”方案。与我们的 Transformer基线相比,Mamba-2 在性能(困惑度)、理论 FLOPs 和实际壁钟时间上都是帕累托占优的。
零样本评估:在每种模型规模中,Mamba-2 模型的表现普遍优于其他模型。特别是Mamba-2 在较大的模型规模下(2.78 参数)表现尤为突出,证明其在不同任务上的泛化能力更强。

不同数量的注意力层下的困惑度。大约10%的注意力层比例表现最佳。

适量的注意力层可以显著提高模型性能,超过了完全不使用注意力层或完全使用注意力层的情况。

(零样本评估)比较了 SSD、MLP 和注意力层的不同组合方式,在2.78规模上进行评估困惑度(ppl)和准确率(acc)

Mamba-2与注意力层的结合(后4个)在多个任务上的表现优于其他模型组合,显示出更强的泛化能力和任务适应性。

8.3 速度性能

左图不同方法在处理序列长度(从512到512k)时所需的时间(以毫秒为单位)。右图展示了在处理固定序列长度(4K)时,不同状态维度(从16到256)下所需的时间(以毫秒为单位)。SSD 方法在处理大状态扩展时表现优异,比Mamba的融合扫描快2到8倍(比如64k时紫色线1毫秒,Mamba为 10毫秒),并且在序列长度超过2k时也比FlashAttention-2更快。

总体结论:
Mamba-2 模块在结合并行处理和额外归一化后,显著提升了模型性能,表现优于传统的Mamba-1 模块。

多头结构中,复杂的头组合和状态扩展通常可以提高性能,特别是当型规模增大时。

对于核近似,Swish和LayerNorm 方法通常效果较好,且适用于不同规模的模型·增加复杂度和头的数虽一般有助于提高模型性能,但需要权衡参数数呈的增加。

九、小结

1.统一的理论框架:将 SSM 和注意力机制在张量收缩视角下实现统一,并建立对偶性关联无疑是本文最大的贡献。这意味着注意力机制与 RNN两种网络在底层逻辑上被关联起来了打破了男人和女人的界限,实现了灵活的性别转换。
新型注意力机制:借助 SSM 时空建模,实现创新。从男人的视角研究女人,而线性代数要在注意力机制创新中起到越来越重要的关键性作用。
3.混合模型成为新的趋势。整体 SSM,局部注意力机制,实现灵活组合,提升整体性能。
如果今天的分享对你有所帮助,欢迎三连支持。我是直哥。学好AI不迷路。只说人话,专治好奇。

Sequence Modeling With CTC

网址: https://distill.pub/2017/ctc/

在语音识别中,我们的数据集是音频文件和其对应的文本,不幸的是,音频文件和文本很难在单词的单位上对齐。除了语言识别,在OCR,机器翻译中,都存在类似的Sequence to Sequence结构,同样也需要在预处理操作时进行对齐,但是这种对齐有时候是非常困难的。如果不使用对齐而直接训练模型时,由于人的语速的不同,或者字符间距离的不同,导致模型很难收敛。

我们可以设计一个规则,比如“一个字符对应十个语音输入”。但是人们的语速是不同的,所以这种规则总是可以被打破的。另一种方法是手动将每个字符与其在音频中的位置对齐。从建模的角度来看,这工作得很好,我们知道每个输入时间步的基本事实。 然而,这对数据集的标注工作是非常耗时的。

这个问题不仅仅出现在语音识别中。我们在许多其他地方看到它。来自图像或笔画序列的手写识别就是一个例子。

CTC(Connectionist Temporal Classification 连接时序分类)是一种避开输入与输出手动对齐的一种方式,是非常适合语音识别或者OCR这种应用的。

给定输入序列 𝑋=[𝑥1,𝑥2,…,𝑥𝑇] 以及对应的标签数据 𝑌=[𝑦1,𝑦2,..,𝑦𝑈] ,例如语音识别中的音频文件和文本文件。我们的工作是找到 𝑋 到 𝑌 的一个映射,这种对时序数据进行分类的算法叫做Temporal Classification。

对比传统的分类方法,时序分类有如下难点:

  1. 𝑋 和 𝑌 的长度都是变化的;
  2. 𝑋 和 𝑌 的长度是不相等的;
  3. 对于一个端到端的模型,我们并不希望手动设计𝑋 和 𝑌 的之间的对齐。

CTC提供了解决方案,对于一个给定的输入序列 𝑋 ,CTC给出所有可能的 𝑌 的输出分布。根据这个分布,我们可以输出最可能的结果或者给出某个输出的概率。我们会要求CTC有效地完成下面这两件事。

1、损失函数:给定输入序列 𝑋 ,我们希望最大化 𝑌 的后验概率 𝑃(𝑌|𝑋) , 𝑃(𝑌|𝑋) 应该是可导的,这样我们能执行梯度下降算法;

2、测试:给定一个训练好的模型和输入序列 𝑋 ,我们希望输出概率最高的 𝑌 :

当然,在测试时,我们希望 𝑌∗ 能够尽快的被搜索到。

算法详解

给定输入 𝑋 ,CTC输出每个可能输出及其条件概率。问题的关键是CTC的输出概率是如何考虑 𝑋 和 𝑌 之间的对齐的,这种对齐也是构建损失函数的基础。所以,首先我们分析CTC的对齐方式,然后我们在分析CTC的损失函数的构造。

1.1 对齐

需要注意的是,CTC本身是不需要对齐的,但是我们需要知道 𝑋 的输出路径和最终输出结果的对应关系,因为在CTC中,多个输出路径可能对应一个输出结果,举例来理解。例如在OCR的任务中,输入 𝑋 是含有“CAT”的图片,输出 𝑌 是文本[C, A, T]。将 𝑋 分割成若干个时间片,每个时间片得到一个输出,一个最简答的解决方案是合并连续重复出现的字母,如图2.

这个问题有两个缺点:

  1. 几乎不可能将 𝑋 的每个时间片都和输出Y对应上,例如OCR中字符的间隔,语音识别中的停顿;
  2. 不能处理有连续重复字符出现的情况,例如单词“HELLO”,按照上面的算法,输出的是“HELO”而非“HELLO”。

为了解决上面的问题,CTC引入了空白字符 𝜖 ,例如OCR中的字符间距,语音识别中的停顿均表示为 𝜖 。所以,CTC的对齐涉及去除重复字母和去除 𝜖 两部分,如图3。

这种对齐方式有三个特征:

  1. 𝑋 与 𝑌 之间的时间片映射是单调的,即如果 𝑋 向前移动一个时间片, 𝑌 保持不动或者也向前移动一个时间片;
  2. 𝑋 与 𝑌 之间的映射是多对一的,一个或多个输入元素可以与单个输出元素对齐,但反之则不然,所以也有了特征3;
  3. 𝑋 的长度大于等于 𝑌 的长度。

1.2 损失函数

CTC对齐为我们提供了一种从每个时间步的概率到输出序列的概率的自然方法。

也就是说,对应标签 𝑌 ,其关于输入 𝑋 的后验概率可以表示为所有映射为 𝑌 的路径之和,我们的目标就是最大化 𝑌 关于 𝑥=𝑦 的后验概率 𝑃(𝑌|𝑋) 。假设每个时间片的输出是相互独立的,则路径的后验概率是每个时间片概率的累积,公式及其详细含义如图5。

上面的CTC算法存在性能问题,对于一个时间片长度为 𝑇 的 𝑁 分类任务,所有可能的路径数为 𝑁𝑇 ,在很多情况下,这几乎是一个宇宙级别的数字,用于计算Loss几乎是不现实的。在CTC中采用了动态规划的思想来对查找路径进行剪枝,算法的核心思想是如果路径 𝜋1 和路径 𝜋2 在时间片 𝑡 之前的输出均相等,我们就可以提前合并他们,如图6。

其中,横轴的单位是 𝑋 的时间片,纵轴的单位是 𝑌 插入 𝜖 的序列 𝑍 。例如对于单词“ZOO”,插入 𝜖 后为:

𝑍={𝜖,𝑍,𝜖,𝑂,𝜖,𝑂,𝜖}

我们用 𝛼𝑠,𝑡 表示路径中已经合并的在横轴单位为 𝑡 ,纵轴单位为 𝑠 的节点。根据CTC的对齐方式的三个特征,输入有9个时间片,标签内容是“ZOO”, 𝑃(𝑌|𝑋) 的所有可能的合法路径如下图:

图7:CTC中单词ZOO的所有合法路径

有两个有效的起始节点和两个有效的最终节点,因为序列开头和结尾的 𝜖ϵ 是可选的。完全概率是最后两个节点的和。现在我们可以有效地计算损失函数,下一步是计算梯度并训练模型。CTC损失函数相对于每个时间步的输出概率是可微的,因为它只是它们的总和和乘积。考虑到这一点,我们可以解析地计算损失函数相对于(未归一化的)输出概率的梯度,并从那里像往常一样运行反向传播。

图7

对于数据集 𝐷 ,模型的优化目标是最小化负对数似然:

1.3 预测

当我们训练好一个RNN模型时,给定一个输入序列 𝑋 ,我们需要找到最可能的输出,也就是求解

𝑌∗=arg⁡max𝑌⁡𝑝(𝑌|𝑋)

求解最可能的输出有两种方案,一种是Greedy Search,第二种是beam search

1.3.1 Greedy Search

每个时间片均取该时间片概率最高的节点作为输出:

1.3.2 Beam Search

Beam Search是寻找全局最优值和Greedy Search在查找时间和模型精度的一个折中。一个简单的beam search在每个时间片计算所有可能假设的概率,并从中选出最高的几个作为一组。然后再从这组假设的基础上产生概率最高的几个作为一组假设,依次进行,直到达到最后一个时间片,下图是beam search的宽度为3的搜索过程,红线为选中的假设。

到目前为止,我们提到了CTC的一些重要属性。在这里,我们将更深入地了解这些属性是什么以及它们提供了什么样的权衡。

CTC的性质:

  1. 条件独立:CTC的一个非常不合理的假设是其假设每个时间片都是相互独立的,这是一个非常不好的假设。在OCR或者语音识别中,各个时间片之间是含有一些语义信息的,所以如果能够在CTC中加入语言模型的话效果应该会有提升。
  2. 单调对齐:CTC的另外一个约束是输入 𝑋 与输出 𝑌 之间的单调对齐,在OCR和语音识别中,这种约束是成立的。但是在一些场景中例如机器翻译,这个约束便无效了。
  3. 多对一映射:CTC的又一个约束是输入序列 𝑋 的长度大于标签数据 𝑌 的长度,但是对于 𝑌 的长度大于 𝑋 的长度的场景,CTC便失效了。

Practitioner’s Guide

到目前为止,我们已经对 CTC 有了概念性的理解。在这里,我们将为从业者提供一些实现技巧。

软件:即使对 CTC 有深入的了解,实施也很困难。该算法有几个边缘情况,应该用较低级别的编程语言编写快速实现。开源软件工具使入门变得更加容易:

  • 百度研究已经开源了warp-ctc。该软件包是用 C++ 和 CUDA 编写的。CTC 损失函数在 CPU 或 GPU 上运行。绑定可用于 Torch、TensorFlow 和 PyTorch。
  • TensorFlow 为 CPU 内置了 CTC loss and CTC beam search 束搜索函数。
  • Nvidia 还在cuDNN 版本 7 及更高版本中提供了 CTC 的 GPU 实现。

Numerical Stability:计算 CTC 损失在数值上是不稳定的。避免这种情况的一种方法是在每个时间步长归一化 α。在实践中,这对于中等长度的序列来说已经足够好了,但对于长序列来说,它仍然会下溢。更好的解决方案是使用 log-sum-exp 技巧计算对数空间中的损失函数。 在对数空间中计算两个概率之和时,使用恒等式.还应使用 log-sum-exp 技巧在 log-space 中进行推理。

Beam Search: 

使用波束搜索解码器时的一个常见问题是要使用的波束的大小。准确性和运行时间之间存在权衡。我们可以检查光束尺寸是否在良好的范围内。为此,首先计算推断输出 ci​. 的 CTC 分数,然后计算真值输出 cg​. 的 CTC 分数 如果两个输出不同,我们应该有cg​<ci​. Ifci​<<cg​ ,那么真值输出在模型下实际上具有更高的概率,并且光束搜索未能找到它。在这种情况下,可能需要大幅增加光束尺寸。

MambaOut-视觉领域的探索

https://arxiv.org/pdf/2405.07992
GitHub – yuweihao/MambaOut: MambaOut: Do We Really Need Mamba for Vision?

作者的主要观点可以概括如下:

  1. Mamba架构的适用性:作者认为Mamba架构,具有类似于RNN的token混合器状态空间模型(SSM),最适合处理具有长序列和自回归特性的任务。
  2. Mamba在视觉任务中的表现:尽管Mamba被引入以解决注意力机制的二次复杂性问题,并应用于视觉任务,但在图像分类等任务中,其性能通常不如卷积神经网络和基于注意力的模型。
  3. Mamba在图像分类任务中的不必要性:作者提出,由于图像分类任务既不符合长序列也不符合自回归特性,因此引入Mamba是不必要的。
  4. Mamba在检测和分割任务中的潜力:尽管检测和分割任务不是自回归的,但它们符合长序列特性,因此作者认为探索Mamba在这些任务中的潜力是有价值的。
  5. MambaOut模型的构建与验证:为了验证上述假设,作者构建了一系列名为MambaOut的模型,这些模型在不使用Mamba核心token混合器SSM的情况下堆叠Mamba块。实验结果强烈支持作者的假设。
  6. MambaOut模型的性能:MambaOut模型在ImageNet图像分类任务上超越了所有视觉Mamba模型,表明Mamba对于图像分类任务确实不是必需的。然而,在检测和分割任务上,MambaOut未能达到最先进的视觉Mamba模型的性能,显示了Mamba在长序列视觉任务中的潜力。
  7. 未来研究方向:由于计算资源的限制,本文仅验证了Mamba在视觉任务上的概念。作者提出,将来可能会进一步探索Mamba和RNN概念,以及RNN和Transformers在大型语言模型(LLMs)和大型多模态模型(LMMs)中的集成。

结论:本文从概念上讨论了Mamba机制,并得出结论认为它非常适合具有长序列和自回归特性的任务。我们根据这些标准分析了常见的视觉任务,并认为在ImageNet图像分类中引入Mamba是不必要的,因为它不符合这两个特性。然而,对于与长序列特性相符合的视觉检测和分割任务,Mamba的潜力值得进一步探索。为了实证支持我们的观点,我们开发了MambaOut模型,这些模型采用了没有核心标记混合器SSM的Mamba块。MambaOut在ImageNet上超越了所有视觉Mamba模型,然而与最先进的视觉Mamba模型相比,它表现出明显的性能差距,从而验证了我们的主张。由于计算资源的限制,本文仅验证了视觉任务中的Mamba概念。将来,我们可能会进一步探索Mamba和RNN概念,以及将RNN和Transformers集成到大型语言模型(LLMs)和大型多模态模型(LMMs)中。

一、Mamba 到底 OUT 了没?


文章首页展示方式非常赞的:
首先标题开宗明义,强调讨论点其实是 Mamba对 Vision 的应用?NLP 排除了,因为那就是纯纯的序列问题,Mamba 还是非常值得继续深入研究,还有很大水文的空间哈。其次,非常显著的给出了 github 代码连接,一句对科比的致敬也让技术型文章多了点人文色彩。它来自于科比 2016年4月 14日最后一场比赛后,为了感谢全场球迷有感而发脱口说的告别词。Mamba 原意是毒蛇,象征着科比在球场上的攻击性和坚韧不拔。后来几乎成了互联网上的一个梗。从 Transformer开始,AI界写论文都兴起玩梗了,既是为了宣传,也是为了突出与众不同,在万千文章中让人记住,用心良苦着实让人感动。但是不是Mamba 真的 OUT,咱们讲完就知道了。
第三,图1一目了然的给出了与 Mamba 模块的主要区别,就是干脆做减法去掉了 SSM 结构,右图体现了性能上的差别,注意这仅在ImageNet的分类任务上。没有 NLP 数据集,没有其他视觉任务,仅仅图像分类!

整个摘要的重点,也是结论性的东西,作者其实用斜体给你标出来了
1.long-sequence and autoregressive 这方面 Mamba 依然擅长,承认优点。
2.图像分类不是 autoregressive 自回归任务,也不是long-sequence,因此用不着 Mamba所以 MambaOut。比如在ImageNet 分类任务上
3还有第三个结论也很有意思,即使是视觉领域,目标检测和实例分割任务上 Mamba还OUT不了,依然有潜力。明白了吗?

二、如今该爱谁:Transformer、Mamba和MambaOUT

虽然标题大胆而耸人听闻,但引言部分还是很旗帜鲜明的给足了Mamba Credit。简单的说就Transformer 有硬伤,面对长序列时自注意力机制计算的复杂度会出现随窗口长度二次方增加的问题。

Mamba 模型的出现引起了 A1 社区的广泛兴趣,因为具有可并行训练和高效长序列推理的能力。除了 Mamba 外,很类似的还有 RWKV 模型大家也可以关注一下。最近这半年出来了一批模型。简单的说都是“RNN+注意力机制”相结合的产物,区别在于适用任务和架构设计上的差异,有的更专注于 NLP 任务,有的尝试用在视觉上。
整篇文章的研究重点其实就是前言中几行斜体字:

Do we really need Mamba for Vision? 视觉问题真得需要Mamba 模型吗Hypothesis 1:SSM 对于图像分类没有必要,因为该任务既不具有长序列特征也不具有自回归特征。
Hypothesis 2:sSM 可能对对象检测和实例分割有潜在好处,因为这些任务具有长序列特征,但不具有自回归特征。
重要的是三个问题:怎么分析的,模型怎么实现的,以及怎么用实验证明的。

第二部分相关工作简要小结了 Transformer 典型模型 BERT和 GPT系列,以及 ViT 强调了Transformer 中的注意力模块会随序列长度增加而扩展,带来显著的计算挑战。许多研究探索了各种策略来缓解这一问题,如低秩方法、内核化、token 混合范围限制和历史记忆压缩。这都是水文章的号方向。最近,RNN-like方法(特别是 RWKV和Mamba)因其在大规模语言模型中的出色表现而受到关注,这点到目前为止还是毋庸置疑的。

对于Transformer 的改进或者说平替,现在学界的一种典型思路就是回归传统模型,从故纸堆里找灵感。这篇文章的作者显然也认同这种观点,而且直接露骨的把它们称作 RNN-like 方法,其实最新的还有 xLSTM。但这种视角还是浅了,仅仅是从模型结构视角来看,做一种时序回归而已。第二段小结了 Mamba 最新的各种变体,包括 Vision Mamba 整合了 SSM 来开发类似ViT 的等向性视觉模型;VMamba 则利用 Mamba 构建类似 AlexNet和 ResNet 的分层视觉模型;LocalMamba 通过引入局部归纳偏置来增强视觉 Mamba 模型;

PlainMamba 旨在进一步提升等向性 Mamba 模型的性能,还有 EfficientVMamba等等。你看事实上大家这半年来以及像吸血鬼一样迅速扑上去搞 Mamba 了,把它作为Transformer 的平替。而这篇文章试图把自己打扮成“半血猎人“来拯救世界。你说它是Mamba吧,它说自己不是,你说它就是CNN吧,它非要把自己和Mamba比,还起了这么个名。有意思,也很拧巴。

三、核心原理

论文第三部分正式开始分析 Mamba优缺点,适合什么任务。直接看公式容易迷迷瞪瞪,不明觉厉:

1.Mamba 的本质回顾

文章中提到 Mamba 是个 token 混合器 mixer,这和我上期给大家讲的如出一辙,咱们当时是说“掺和”,看下图,B掺和了三次,C掺和了两次。

其实用流体力学视角 看 Mamba 更透彻,本质上就是当成一个记忆流淌的管道系统而 selective SSM 就是个带着总开关 delta + 两个阀门 BC + 主管道A的系统。因为A 与时间无关,因此隐藏状态 h可以视为固定大小的记忆,存储所有历史信息。固定大小意味着记忆不可避免地丢失,但保证了与当前输入集成的计算复杂性保持不变而通过总开关 delta + 两个阀门 BC 门控机制实现了一种选择性注意力机制。这种设计更加的高效,从更抽象的数学角度理解,是用李指数隐射拟合数据,替换了原有的牛顿力学运动方程。

2.自注意力机制的类别

相比之下,Transformer 中的自注意力机制更加复杂,如下图有两种:一种叫因果模式,其实就是只能看过去,不能看未来,只有记忆没有未卜先知;另一种是全可见模式,左右都知道。Transformer 本质上两种都可以,因果模式的比如 GPT,全可见的比如 BERT,前者适合自回归用来生成和预测,以史为鉴,后者适合理解,左顾右看瞻前顾后。

按照这种分类,Mamba的选择性机制算那种呢?显然不是全可见模式,看公式就知道是因果模式,但和 Transformer 的有什么不同呢?

下图展示了 Transformer 因果注意力与 Mamba 中因果注意力的区别,前者是组合(叠加)之前所有的记忆,记忆无损但复杂度增加,越累越长,计算复杂度同样为0(L^2):后者融合之前的记忆到新的隐藏状态,记忆有损但复杂度恒定

基于 Mamba 的这种特性,显然它适用于以下特征的任务

·特征 1:任务涉及处理长序列,因为复杂度低,更高效

·特征 2:任务需要因果 token 混合模式。

但这样以来还怎么 OUT 呀!于是他们反向思考:首先,什么时候不需要长序列呢?视觉作为空间数据,那种最不需要呢?你说是鸡蛋里挑骨头也好,逆向思维也好。既然逻辑上它擅长长序列,那就说明短序列一般,那咱们就摁着短序列搞不就成了。
其次,什么时候不需要因果注意力呢?什么问题需要全局可见注意力呢?着这个方向搞,不也能证明 Mamba不行吗?这种创新的思维方式确实聪明,典型的田忌赛马思路,你打你的,我打我的,拉到我擅长的地方打,你还打得过吗?

3.视觉任务的特点分析

在视觉识别任务中,感觉上图像分类就不属于长序列任务,因为主要关注整体特征空间特征就够了,目标也只是粗犷的类别标号,因此不涉及什么序列信息,而且需要全局信息。但是目标检测和语义分割则不一定,比如要考虑边缘的连贯性,因此可能有序列问题。但是,这种假设或者感觉怎么证明呢?首先,文章针对图像分类任务,做了全可见模式和因果模式性能的分析实验。如图所示:

左图是全可见模式,横纵轴互相都能看,BERT和ViT 的自注意力机制都是这种。中间图是因果模式,GPT的自注意力机制和 Mamba的 SSM 是这种,比如y3 只能看到x1-x3,看不见 x4-x5。右图显示以 ViT 为例,将自注意力机制从全可见模式切换到因果模式后,性能有所下降,说明对于图像分类问题,用因果模式没必要。
既然注意力机制的类型明确了,在图像分类这一亩三分地上干掉 Mamba 的可能性暴增。但老问题又回来了,怎么确定它是不是长序列任务呢?整篇文章最有点数学理论含量,也是最有看点的就是32关于图像处理任务是否属于长序列问题的分析。

分析的切入点选择了 Transformer 的浮点运算次数公式,也就是一个Transformer 块的
计算量:
FLOPs = 24D2L + 4DL2
其中 L是 token 长度(即输入序列的长度),D 是通道维度(即特征维度),加号前后分别代表线性复杂度和二次复杂度。为什么要看这个,因为Transformer的硬伤不就是长序列运算量随着窗口长度暴涨吗?针对具体任务,如果我们能知道它们的计算量是不是对L敏感,那不就知道它是不是需要长序列建模了吗?
这个公式可能没几个人熟悉,也不知道怎么来的呢?我们来拆解下加深理解,因为它对理解这篇文章的核心思想非常关键。
在 Transformer 模型中,自注意力机制是计算量最大的部分。为了估算 Transformer 块的计算量(即浮点运算次数,FLOPS),需要考虑自注意力机制和前馈神经网络(FFN)的计算。

在 ViT 中图像首先被分割成多个固定大小的 patches,每个的大小通常为 16×16 像素
Token 数:假设输入图像大小为 224 x 224,则生成的 token 数为(器)=14×14 =196.
每个这样的 patch 通道被展平成一个长向量,RGB 三通道就是 16163=768,然后通过一个线性投影层(粉色的)映射到高维空间,也就是通道维度D,它是个指定的超参数。

对于 ViT-S,常见的通道维度 D为 384。对图像分类任务L=196,远小于6D=6384=2304,因此不涉及长序列建模。

目标检测和实例分割问题:在COCO 数据集上推理图像大小为800×1280,生成的token 数约为 4000,大于 6D=6384=2304,因此涉及长序列建模。

这个结论和我们先前的直觉分析是一致的:图像分类模型不需要处理非常长的序列来捕捉远距离的依赖关系,序列长度L相对较短,模型的注意力窗口不需要很大。

在实例分割和目标检测任务中,虽然图像同样被分割成 patches,并作为序列输入到模型中,但不仅需要识别图像中的对象,还需要确定对象的位置和边界。由于输入图像通常更大,生成的 tokens 数量(序列长度L)也更多。需要捕捉远距离的依赖关系,例如物体的边缘和不同部分之间的关系。模型需要处理较长的序列,自注意力机制的窗口需要更大,以捕捉这些复杂的依赖关系。

3.3 和 3.4 进一步讨论了视觉识别任务是否需要因果注意力以及 SSM 机制的必要性。其实结论已经很明显了,既然它和长序列没有关系,那就是理解任务,当然需要全可见型注意力机制,不需要时序记忆,而需要空间全局可见的高屋建瓴。

因此本文主要 IDEA 在于验证了两个假设假设 1:在 ImageNet 上的图像分类任务中引入 SSM 没有必要,因为这个任务不需要长序列建模或自回归特性。
假设 2:尽管目标检测和分割任务不需要自回归特性,但由于这些任务涉及长序列建模,因此值得探索 SSM 在这些任务中的应用潜力。
明白了分析的过程,咱们看看模型架构。

  1. 模型架构
    下图展示了 MambaOut 模型的总体框架以及 Gated CNN 块的具体结构。整体框架类似于 ResNet,通过降采样逐步减少特征图的尺寸,同时增加特征的抽象层次。

输入图像大小为 HxWx3,表示图像的高度、宽度和 RGB 三个颜色通道。采用了分层架构,共有四个阶段,每个阶段进行特征提取和降采样。每个阶段包含若干个 GatedCNN 块,用于特征提取。每个阶段之间有降采样操作,将特征图的大小逐渐减小,从而增加特征的抽象层次。通道维度为D1,D2,D3,D4。

右侧是基本组件,Gated CNN块包含两个线性层、中间夹一个卷积层和归一化层,通过残差连接实现输入和输出的融合。和 Mamba块的区别在于前者没有 SSM(状态空间模型)。

MambaOut 的架构与 Swin Transformer 和 DenseNet 在分层结构和降采样方面有相似之处,但在特征提取和信息混合机制上有所不同。MambaOut使用 GatedCNN块而 Swin Transformer 使用窗口注意力机制的 Transformer块,DenseNet 则使用密集连接的卷积层。这些差异决定了它们在处理不同任务时的特性和优势。

这么看,MambaOut 实际上就是在 Gated CNN 基础上的优化版本,通过结构简化和实验验证。更准确的说主要就是和 Mamba 做针对性的对比而已,在视觉识别任务中进行了优化和验证。
结论:如果说 Mamba是回归RNN+新型注意力机制,那 Mamba0ut 其实是回归CNN+新型注意力机制。
既然如此,读到这里时其实产生了很大的困惑,这不就是纯纯的CNN吗?不就是 ResNet 吗?它怎么还好意思叫 MambaOUT呢?这个 Gated CNN 神奇到哪里了呢?确实,网络结构上的改进比较小,但人家起作用了啊。魔鬼在细节:

对比左边的 ResNet,两大微小的区别:

一是使用线性层进行升维操作,使得 Gated CNN 块能够在特征空间中进行更灵活的变换,这与传统的 ResNet 中主要使用卷积操作进行特征提取有所不同。

二是跳线增加了非线性激活函数可以被看作一种简单的门控机制,根据输入值调整输出信息量。增加了模型的非线性能力,使得模型能够学习更复杂的特征。

5.代码实现

第 4.1 节更侧重于解释模型设计和架构上的区别,为后续章节的实验结果分析提供背景和依据。
来看代码实现,Gated CNN块通过线性变换、卷积操作和残差连接,实现了对输入特征的扩展、局部特征提取和信息保留。结合了深度卷积网络和残差网络的优点,同时通过门控机制(如激活函数)来控制信息流。

外部结构是个四级堆叠,具体看 Github。重点看这段代码,整体比较简单,因为去掉了 SSM。需要注意的事只对部分通道进行深度卷积,看看是怎么实现的。这里的conv_channels 定义了要进行深度卷积的通道数,conv_ratio 是一个控制参与卷积的通道比例的参数。这意味着卷积操作只在部分通道上进行,而不是所有通道。

四、 实验效果


1、图像分类比较
4.2 汇报了在 ImageNet 上图像分类的比较。太细节实现的我们跳过,可以看原文。着重结果分析,比较了MambaOut,VMamba及其他基于卷积和注意力机制的模型在ImageNet 上的表现如表1所示。

这个表看着很吓人,比较了几十种模型及其变体的性能,但其实结论并不复杂:

1.SSM 有没有对图像分类意义不大,因为时序关系不重要。
2.不如最新的 CAFormer-M36 使用简单的可分离卷积和原始注意力机制,比所有同等大小的视觉 Mamba 模型高出超过1%的准确率85.2%。人家才是纯种的CNNtransformer 啊。

2.目标检测和实例分割

使用标准的 COCO 数据集,Mamba0ut 作为 Mask R-CNN 的主千网络使用,结果尽管 MambaOut 在 COCO 上的目标检测和实例分割任务中可以超越一些视觉 Mamba模型,但它仍然落后于最先进的视觉 Mamba模型,例如 VMamba 和LocalVMamba。这种性能差距强调了在长序列视觉任务中整合 Mamba 的好处。当然,与最先进的卷积-注意力混合模型 TransNeXt相比51.7%%,视觉 Mamba 仍表现出显著的性能差距49.2%。仍然需要努力!这个合理也不合理,两点:

1.Transformer优化了多少年了,Mamba 才多久

2.即使是实例分割问题,所谓的长序列建模,但序列长度并没有 NLP那么长,因此效果有限正常。

3.语义分割的比较
结论与实例分割类似,SSM 模块在这些任务中的重要性,同时也验证了Mamba0ut在某些情况下的有效性。视觉 Mamba 需要进一步展示其在长序列建模任务中的强大性能,以在语义分割任务中实现更强的性能。

五、小结与探讨:What canlsay!

本文主要的贡献在于:
1.定量分析论证了图像分类任务不是长序列建模问题,而目标检测和实例分割是。前者不需要 RNN 这种机制,因此MambaOut,后者 OUT不了。
2.借鉴 Mamba的 GatedCNN 结构微调了 ResNet,实现了一种新型全局可见注意力机制下的改进版模型。

GPT-4o背后的语音技术

5月14日凌晨,OpenAI推出了最新的生成模型GPT-4o,带来了一系列震撼的功能,用技术彻底颠覆了产品形态。产品最大的亮点在于:以近乎完美的交互方式,为每位用户带来GPT-4级别的智能体验。在语音方面,GPT-4o做到了实时低延迟,平均响应时间与人类反应速度相当,输出的语音能够理解极度贴合对话上下文,能够理解人类的情感情绪,听觉质量上佳,与真人无异。

OpenAI的博客:https://openai.com/index/hello-gpt-4o/

GPT-4o是一个any2any的多模态模型,能够接受文本、音频、图像、视频等多模态输入,也能够生成包含文本、语音、图像和视频等混合内容的多模态输出。限于篇幅,本文主要谈谈语音多模态的实现,并分享一些对于语音研究未来发展的看法。

当我们主要关注文本和语音模态时,GPT-4o其实就是一个语音语言模型(speech language model, SLM)。该SLM同时具备语音理解能力和语音合成能力,输入端和输出端均支持文本和语音的混合多模态。那么,这一SLM应该如何实现呢?在大语言模型(large language model, LLM)滥觞的今日,不难想到这样一种方法:将连续的语音数据离散化成如同单词(或者称token,词元)一样的表示,并入到LLM的词表中,再走一遍训练LLM的老路。

基于上述思想来构建SLM,需要解决以下几个问题:

  1. 语音如何离散化?
  2. 如何让LLM理解语音的token?加入语音token之后,LLM在语音数据的理解上是否具有涌现性?
  3. LLM如何合成/解码语音?

接下来,我们按图索骥,分别看看上述三个问题应该如何解决。看完现有的方案之后,也会谈谈一些关于工程实现的思考以及新兴语音技术对于游戏业务的影响。最后,我会给出一个完整的roadmap来收束全文。

语音的离散化:向LLM看齐!

在谈及语音离散化之前,我们先来看看语音和文本作为两种不同的模态,有什么区别,有什么联系。这直接关系到后文建模方法的选择以及离散化特征的关注点。

语音和文本的差别主要体现在:文本离散、序列短、信息密度高(几乎每个词都包含语义);语音则连续、序列长、信息密度低。语音序列长、信息密度低的特点,意味着语音数据有很大的压缩空间,这一点和图像非常类似。因此,一些用于图像的离散化压缩方法也可以用在语音上。

除了差异,语音和文本也有一定的联系:语音是文本的超集,既包含文本内容(说话人说了什么,也就是语义信息),也包含语音特有的音色、韵律、语速等声学信息(也叫做副语言)。既然语音包含文本,那么在NLP中预训练语言模型也可以用来建模语音中的上下文依赖关系,从而得到语音的离散化token。基于这些方法得到的token主要包含语音的语义信息。

花开两朵,各表一枝。我们先来看看语音的语义token如何获取。

语义token:  用MLM建模语音的上下文依赖

语音的语义建模方法,最常用到的就是BERT的MLM方法,比较经典的工作有三个:wav2vec 2.0[1]、HuBERT[2]和w2v-BERT[3]。

类似于BERT,wav2vec 2.0[1]在隐空间(latent space)随机mask了一定比例的语音输入,然后用基于对比学习的训练目标学习帧的表征。值得注意的一点是,对比学习中目标帧的离散化处理是一个非常巧妙的操作,它将无限的连续特征空间坍缩为有限的离散空间,让帧特征的鲁棒性更强了。这在语音领域上非常有用的trick,允许模型接受带有噪声的语音作为输入。

图1:wav2vec 2.0的模型架构

wav2vec 2.0只是借用了BERT中mask的操作,训练目标大体上是基于对比学习的范式。那么,能直接用BERT的MLM建模目标来得到高质量的语音表征吗?其后的HuBERT[2]做的就是这个事情。HuBERT[2]的核心点在于使用简单的KMeans聚类方法为语音数据抽取离散化的分类标签,也就是文中所说的hidden unit/acoustic unit。有了分类标签,然后就是用BERT的MLM loss来学习语音数据中内在的上下文依赖关系。对于KMeans聚类对初始值和K值高灵敏的特点,作者设计了ensemble和iterative refinement方法予以解决。前者就是多个聚类模型ensemble,后者就是先在基于MFCC的聚类标签上进行学习,学习到一定程度时,在模型学习到的表征重新聚类,再做一次BERT的学习。

图2:HuBERT的模型架构

既然对比学习可以学习语音的语义表征,BERT的MLM也可以,那将二者结合起来,会不会有互补的效果呢?w2v-BERT[3]做的就是这个事情。注意到:HuBERT中语音的离散token不是端到端获得的,需要用KMeans算法对特征进行离线聚类,而wav2vec 2.0又正好提供了音频帧的量化离散表征,HuBERT和wav2vec 2.0很容易就能缝合在一起。缝合的方法也是显然的:前面若干层做类似wav2vec 2.0的对比学习,学习出HuBERT要用的离散表征,然后在后面若干层做类似HuBERT的MLM训练。

图3:w2v-BERT的模型架构

声学token:压缩+离散

上一部分介绍的预训练模型做的是上下文关系的预训练,学习到的表征主要包含与上下文相关的语义信息。要想将语音的token还原成为真正具有真人表现力的信号,还需要有包含音色、韵律、语速等副语言信息的声学特征。声学特征的学习在很大程度上参考了图像领域的工作,用到的主要是类似于VQVAE[4]、VQGAN等的离散化压缩方法,并针对语音数据的特性做了优化。这一部分比较经典的工作就是SoundStream[5]和Encodec[6],二者的工作高度类似,我们放在一起来看。

说到压缩,最先想到的模型当然就是AutoEncoder(自编码器)。为提升压缩效率,有利于数字传输和存储,以及离散化建模的要求,压缩模型中还需要包含量化(quantization),将连续的音频信号转换为离散的数值。基于上述考虑,模型大体上应该是VQVAE[4]的结构。为了平衡VQ(Vector Quantization,向量量化)与音频实时高保真传输的矛盾,通常采用多个残差连接的codebook来进行量化,这个就是所谓的RVQ(具体分析过程可以参见知乎文章)。采用RVQ的好处主要有两个:其一,区分不同quantization block的分工,第一个block包含最重要的语义信息,后续的block包含还原语音的副语言信息;第二,模型训练时可随机采样前面若干个block来训练,保持一定精度,实现对比特率的动态适应。

总而言之,SoundStream[5]/Encodec[6]其实就是一个RVQ-VAE,它们所建模的语音离散化token包含了层次化的语义信息和声学信息。

图4:Encodec的模型架构

语音的统一表征?

不难发现,虽然说SoundStream[5]和Encodec[6]这样的基于RVQ-VAE的压缩建模方法包含了语音的声学特征,但其中也不可避免地带入了语义特征。二者提取的实际上更像是一种语义特征和声学特征的混合体。基于此,SpeechTokenizer[7]在二者的基础上,引入了语义引导信息来解耦语义特征和声学特征。语义特征和声学特征的解耦对于最终的语音合成有着相当的重要性。SpeechTokenizer的具体做法是:使用HuBERT[2]的特征对RVQ1的特征做语义蒸馏,其余部分保留声学信息。

图5:SpeechTokenizer的模型架构


语音的其他表征:MEL依旧有用!

上述的语音离散表征,不管是基于HuBERT[2]的语义token,还是基于Encodec[6]的声学token,它们都是直接基于原始的音频波形抽取的。除此之外,也可以基于语音的中间表征来抽取。最典型的语音中间表征就是梅尔谱(MEL spectrogram,下文简称MEL)。梅尔谱本身就对语音进行了压缩,将梅尔谱类比于图像,使用单码本的VQ也可以达到与SoundStream和Encodec那样类似的压缩程度。这种MEL+VQ的做法在各种语音合成模型中也相当常见。我们在语音合成部分会详细介绍。

让LLM理解语音token!

有了上面所说的语义token和声学token之后,其实就可以利用它们来构建语音层面的语言模型了。比较经典的工作有:谷歌的AudioLM[8]和AudioPaLM[9]、字节的SALMONN[10]、复旦的SpeechGPT[11]/SpeechGPT-Gen[12]/SpeechAlign[13]、阿里的LauraGPT[14]和新加坡国立大学的NextGPT[15]。它们的做法其实都大差不差,我们看几个就知道是怎么回事了。

AudioLM:最初的SLM

见名知义,AudioLM[8]构建的是语音层面的语言模型——给定一段语音,模型预测后续的语音。输入侧和输出侧都只有语音模态。这个任务形式和GPT-4o非常类似,不会经历ASR->LM->TTS的过程,而是直接从语音上下文中推理语义信息,再结合声学信息合成贴合上下文的高表现力语音。而上文所述的语义token和声学token正好就能满足这个任务的要求。

AudioLM的具体做法是:用SoundStream[5]提取声学token,用w2v-BERT[3]提取语义token,模型主体就是一个常规的GPT,词表包含所有的声学token和语义token。它的建模过程也相当有意思,有很大的参考意义:先做最重要的语义建模,然后先预测SoundStream的前若干层特征,建模粗糙的声学特征,在预测SoundStream的剩余层特征,建模声音的细节信息,最后基于所有的声学token还原为语音。这种层次化的建模在诸如VALL-E[16]这样的语音合成模型中也非常常见。

图6:AudioLM的tokenizer

图7:AudioLM的建模流程

当然,AudioLM[8]仅仅关注语音模态,LM也很常规,不具备如同GPT-4o一样强悍的指令遵循能力和对话能力,语音对话的连贯性和表现力都相当弱。但这一工作仍然具有相当的启发性和开拓性,证明了:即使是常规的LM,照样也能理解语音token。

AudioPaLM[9]:整合LLM

这个就是AudioLM的后续了,谷歌将常规的LM替换成已经训练好的、具有强大文本理解能力和生成能力的大语言模型——PaLM-2[17],既继承了AudioLM保留副语言的能力,又融合了PaLM-2强大的语义理解能力和推理能力。而且,该模型的词表同时包含大语言模型的token和语音token,可以同时做语音理解任务和合成生成任务,第一将这些任务整合在一个模型中进行解决。

不过,需要指出地是,文中的语音token embedding是直接输入到Transformer中的,并没有使用音频编码器做一次转换。而且,AudioPaLM的训练更加接近文本多任务的T5,并未用到复杂的、丰富多样的指令来表达任务的意图,还不能算是真正严格的instruction fine-tuning。

图8:AudioPaLM的模型架构

SALMONN[10]:让LLM理解语音

这是字节跳动和清华大学电子系(也是我们实验室)的合作成果。虽然这个工作的目的是让LLM能够理解语音,还不能生成语音,但它的训练方法和LLM比较接近,而且在诸多语音相关的任务上都显示出了涌现性,可以用作universal的特征提取器,这对于构建高质量的、包含语音-文本多模态的指令微调数据集具有相当大的意义。

图9:SALMONN的模型架构

SpeechGPT/SpeechGPT-Gen/SpeechAlign:向LLM的训练方法看齐

这算是复旦大学邱锡鹏组在这个领域一个成系列的工作,我们一个一个来看。

SpeechGPT[11]做的也是兼具语音理解能力和语音生成能力的多模态模型。在模型的训练上,SpeechGPT大幅度向LLM看齐,使用了三段式的训练方法:第一阶段先做模态适应的预训练,其实就是拿ASR的语音数据来做预训练;第二阶段和第三阶段都是指令微调,不过根据指令模态的不同,细分为了跨模态的指令微调和模态链指令微调。指令微调的数据集都是来自ASR数据集。描述任务需求的指令由GPT-4生成。

在我看来,这个工作还是相当偏学术化的作品,文中有不少点都有值得商榷的地方:第一,语音的离散化仅仅用了HuBERT[2],模型只能看到语音的语义特征,这对模型合成语音的音质和表现力有非常大的影响,demo的语音也验证了我的判断;第二,指令微调数据集的构造上有问题。他们用的是ASR数据集,其实更好的选择应该是TTS数据集,可惜高质量的TTS数据集实在是太少了。ASR数据集中的文本和语音可能并不是严格对齐的,GPT-4产生的meta-prompt和语音本身的特征也有可能是对不上的,比如prompt要求大声朗读,但语音本身可能是特定低沉的。meta-prompt本身就无法做到足够复杂丰富,不能描述到语音的一些细粒度信息。

这一部分,最好要有像诸如SALMONN[10]这样的多模态语音理解模型的介入,像DALLE3一样丰富指令的多样性。至于语音方面,可以考虑引入zero-shot的语音合成模型或者变声模型来做合成数据。第三,文中的训练方法也没有与人类偏好做对齐。

图10:SpeechGPT的模型架构

对于上面的第一个问题,作者在其后的SpeechGPT-Gen[12]中做了解决。解决思路的核心点就是:让模型不仅看到语音的语义token,也要看到语音的声学token。具体做法是:SpeechGPT的HuBERT特征替换成了SpeechTokenizer[7]中的语义特征,用SpeechGPT这一LLM来自回归地建模语义特征,有了语义特征之后,再使用Flow-Matching这样的扩散模型来建模声学特征。这里选用Flow-Matching扩散模型,可能是受了SD3和Voicebox/Audiobox的影响。为了增强两阶段建模的依赖关系,作者将语义特征的先验信息注入到第二阶段扩散模型的先验分布中。可以看到,这里语音的解码其实也是一种层次化渐进式解码。

图11:SpeechGPT-Gen的模型架构

SpeechAlign[13]做的则是SLM与人类偏好的对齐,彻底地向LLM的训练方法看齐。该工作构建了对比gold token和合成token的encodec数据集,然后进行偏好优化来进行改进。使用的偏好优化方法包括RLHF和Chain of Hindsight。

图12:SpeechAlign的流程图

简单总结一下上面这些工作中值得关注的点:

  1. 要想让LLM输出上下文连贯的高表现力语音,必须要让LLM看到语义token和声学token,只有语义token,那语音就会显得呆板机械,只有声学token,那语音就不知所云;
  2. LLM的指令微调同样可以迁移到语音-文本多模态领域中,LLM的指令微调同样可以带来如同NLP一样的涌现性;
  3. 高质量指令微调数据集的构建应该是最大的瓶颈!一下子让LLM同时做语音理解和语音生成,难度非常大。不如分步进行。
  4. 如果要分步进行的话,要先实现一个类似于SALMONN[10]那样的多模态理解模型和一个强大的Zero-shot TTS模型。前者用于给语音数据打上丰富的标签,可以是情感情绪、韵律、音高、语速,也可以是口音、意图和说话环境;后者则用于生成高质量的语音数据。毕竟,高质量的、文本和语音严格对齐的TTS数据实在是太少了,尤其是中文领域。有了这两个模型的加持,我们其实就能够构造出高质量的指令微调数据集。我不知道OpenAI是否有SALMONN这样的模型,但OpenAI的OpenVoice模型应该足够为其提供高质量的语音数据了。

既然我们在上面的篇幅中论述了语音理解多模态模型的构建,那我们在下一部分就重点关注zero-shot TTS模型,它对高质量指令微调数据集的构建同样至关重要。同时,LLM解码语音的方法也能从zero-shot TTS方案中得到不少的启发。

LLM如何合成语音:Zero-shot TTS

前面说到,SLM词表中包含了语音的语义token和声学token。语义token保证生成语音与对话上下文的连贯性,声学token保证了合成语音的质量和表现力。要想做到合成上下文连贯的高自然度语音,有两个问题必须要解决:

  1. 语音既有语义token,又有声学token,应该要如何解码成语音?
  2. SLM在合成语音的过程中是否能够遵循多轮对话中的文本指令和语音指令?这个很重要!这允许模型根据用户的即时要求来生成语音回复。比如说,OpenAI演示视频中出现的:“将语速提高两倍”、“采用更加机械化的语气”这样的要求。

对于第一个问题,以VALL-E[16]为代表的诸多zero-shot TTS模型给出了不同的解决方案,这些方案虽有不同,但也有不可忽视的共同点;对于第二个问题,以VoiceLDM[18]和ParlerTTS[19]为代表的text/prompt-guided zero-shot TTS工作给出了肯定的答案。简单解释一下text/prompt-guided zero-shot TTS是怎么回事,通常的语音合成就是将文本(transcription)转换成声音,该任务在transcription之外,又增加了description的输入,来描述合成语音的情感情绪、口音、语气、语速、音高、说话环境、氛围等等信息。我们逐个来看这些工作。

Zero-shot TTS

2023年以来,学术界和工业界出了不少具备in-context learning(zero-shot/few-shot)能力的TTS模型。这些TTS模型通常会将低信息密度、长序列的连续语音数据压缩为高信息密度的tokens或者latents(其实就是码本中具体的token embedding)。这些模型本质上做的事情就是:如何高效实现语音tokens/latents到音频波形的映射。

这些模型给出的解决方案基本上都遵循一个准则:语义token和声学token层次化解码,先语义后声学,或者先解码成MEL再后接声码器,并且非必要不做自回归(毕竟自回归上线虽高,但太吃数据了)!我们一个个来看。

基于声学token或语义token的工作

先是微软的VALL-E[16]。这是zero-shot TTS的开山之作,首次在TTS任务上采用了上万小时的数据。它采用Encodec将语音转换为离散的token,然后用GPT在token上做语言模型的任务。但是,语音毕竟不是文本,如果直接在语音的所有特征上都做自回归的话,那训练的成本会相当高。考虑到Encodec RVQ特征的层次性,低层特征表示语义内容这样的重要特征,高层特征则表征声学细节。前者具有比较强的上下文依赖关系,适合用自回归来建模,后者诸如音色这样的特征,具有全局性,用非自回归特征也可以搞定,所以就有了VALLE中自回归+非自回归的层次建模方式。

图13:VALL-E的模型架构

尽管VALL-E[16]在用GPT建模token的上下文关系的时候,基于token的层次化特性做了分治处理,可能是限于当前语音数据集的规模(几万小时可能不够),这种GPT自回归的难度还是相当大的,解码过程存在常见的错误传播现象,鲁棒性非常差,极其不稳定。根据Ilya Sutskever此前对于自回归的论述,GPT自回归相比于BERT这种双向结构是非常data-hungry的,万小时的数据可能不够。根据本人以及一些同行的经验,VALL-E模型这一类的自回归模型,也包括tortoise-tts[20]和xtts v2,要想显出威力,至少要有十几万小时的数据才行。

既然GPT自回归的难度这么大,就有不少人想方设法地来降低GPT学习的难度了。他们的解决方案也非常类似:给GPT提供额外的条件信息不就行了。比较典型的工作就是微软的RALL-E[21]和吉利的HAM-TTS[22]。RALL-E先生成了时长信息和音高信息,作为GPT自回归的先验,之所以会补充时长和音高,这大概是受到FastSpeech2[23]这样的非自回归模型的启发,这两个指标的引入,有助于提升合成的鲁棒性;HAM-TTS则是补充了基于HuBERT的语义信息。值得注意地是,HAM-TTS将模型的训练数据扩充到了65万小时,其中有50万小时的数据是合成数据。合成数据也能大幅度提升合成语音的音质。

图14:RALL-E的模型架构,框出来的就是辅助信息

图15:HAM-TTS的模型架构

说到VALL-E的后续改进,VoiceCraft不得不提。我愿意称之为“优雅的VALL-E”。它的优雅主要体现在两个方面:casual masking和delayed stacking。所谓的causal masking,是为了用自回归GPT架构来做语音编辑任务,就是把被mask的部分移动到序列末尾去预测,一套架构同时做合成和编辑任务;所谓的delay stacking,是为了适配自回归和RVQ,通过delay错位让当前码本的token预测正好可以利用前面那些token的预测结果,比起VALL-E那样自回归和非自回归缝合在一起的结构要优雅不少。

图16:VoiceCraft的建模流程

基于声学/语义latents的工作

我们通常所说的语音token是离散的。如果使用对应码本中的embedding来表示语音的话,它也可以是连续的低维度的latent变量。既然是低维度的连续latent变量,那图像合成领域中大火的LDM(latent diffusion model,其实就是stable diffsion 1&2采用的模型)模型[]自然也可以用到语音的合成上。这方面的经典工作有很多,比如说:NaturalSpeech 2&3[25, 26]、AudioLDM 2[27]、VoiceLDM[18]。但这里面只有NaturalSpeech2用到了语音离散化部分提及的声学/语义token,NaturalSpeech3的属性分解形式的VQ更像是另一种形式的RVQ。我们先来看NaturalSpeech 2&3,其他的工作后面再来看。

首先是NaturalSpeech 2[26],它基本上就是VALL-E的连续版本。它用的latent也是来自Encodec,对其中不同层次的latent做了求和,然后将其作为扩散模型的训练目标。值得注意地是,扩散模型和FastSpeech2一样也用了时长和音高作为合成的先验条件。这一点也被后来的RALL-E采用。该工作中的扩散模型采用WaveNet实现,同时预测不加噪的latent和后验均值,和图像合成领域的扩散模型在实现方式上还是有所不同的。

图17:NaturalSpeech2的模型架构

然后是NaturalSpeech 3[26],还是非自回归的,而且非自回归的正统性味道更加浓厚,借用了不少FastSpeech2和megatts1&2(后面会讲)[27, 28]的设计思想。像megatts 1&2一样,同样采用(自)监督信号对语音token编码的内容做了限制,而不再像是VALL-E/NaturalSpeech2那样一把抓。相应地,语音token化的方法也用VQ就行。具体而言,文章将语音信号分解为时长、内容、韵律和细节四个部分,然后每个部分用离散化的扩散模型来建模。不过,原文使用GRL来促进语音属性的分解,这一点的靠谱程度存疑。我也尝试过文章的FACodec,但效果很差。三级扩散模型级联的结构,预测起来似乎也非常麻烦。

图18:NaturalSpeech3的模型架构

基于MEL谱+VQ的TOKEN的工作

当然,也有不少工作用了MEL谱作为中间特征,然后在梅尔谱的基础上,或是用VQ提供离散token,或是用CNN来提取连续latent。对于MEL+VQ的工作,有tortoise-tts[20]、xtts 1&2、megatts1&2[28, 29]、base TTS[30]。对于MEL+latents的工作,有:AudioLDM 1&2[27]、StyleTTS 1&2[31, 32]。我们来简单看看是它们是怎么做的。

Tortoise-tts[20]。该工作是著名的开源英文TTS模型。其作者目前在OpenAI就职,同时也是GPT-4o的重要Contributor(他自个儿在博客中说的)。Tortoise-tts使用MEL+VQVAE的方法得到语音的MEL token,然后对MEL token以及text token做GPT自回归建模。对于语音的解码,自然也是分为两步:先是用扩散模型将MEL token转换为MEL谱,这一步和文生图很像,用扩散模型是很自然的选择;然后用声码器将MEL谱转换为音频波形。tortoise-tts和VALL-E的主体都是自回归建模,二者的不同主要在于token的不同。

图19:tortoise-tts的模型架构

MegaTTS 1&2[28, 29]。字节跳动的MegaTTS系列对语音token编码信息做了显式的信息压缩处理,让语音token仅编码上下文依赖强的韵律信息,然后用GPT自回归来建模语音的韵律。对于其他方面的信息,模型的处理显得较为常规:音色一般具有全局性,使用单一的音色编码器从参考音频中提取就性;对于文本语义内容的处理,模型在很大程度上参考了非自回归的FastSpeech 2。

对于语音的解码,也是分为两步:先通过MEL decoder还原为MEL谱,然后通过声码器解码为音频波形。MegaTTS 2和1总体上类似,在音色编码(音素级编码、多条参考音频)、语音提示长度(扩展同speaker语音上下文长度硬train,音频prompt长度更长)和时长建模(也用GPT自回归)上做了改进,同时堆了更大规模的数据。剪映的后端TTS模型用的就是megatts2。该工作在各论文的评测中表现也都不错。

图20:megatts1的模型架构

基于MEL谱+VAE的latents的工作

AudioLDM 1&2[27]。AudioLDM 1&2使用的语音latents是一致的,均通过MEL+VAE获得。既然是连续的latents,使用扩散模型来建模也合情合理。解码过程也相当简单:VAE decoder获得梅尔谱,然后用声码器转换为音频波形。该系列工作的核心创新点是利用多模态模型统一了扩散模型条件输入侧的信息:AudioLDM 1用CLAP统一了文本模态和音频模态,用单模态的音频数据就能完成模型的训练;AudioLDM 2则包含了图像、文本、转录文本等更多模态,模型泛用性也更强,既能做语音合成,也能做音乐生成、音频事件生成。

图21:AudioLDM 1的模型架构

图22:AudioLDM2的模型架构

StyleTTS 1&2[31, 32]。StyleTTS系列的模型一众zero-shot TTS模型显得比较老派,整体结构基本上沿袭了非自回归的FastSpeech 2,不同之处在于增加了基于参考音频抽取的风格信息。说是风格,其实跟megatts的音色很像。StyleTTS 2的工作则将风格进一步拆分成声学风格和韵律风格。训练时的风格信息由音频提供,推断时的风格信息则由扩散模型提供。StyleTTS 2通过一个扩散模型桥接了文本韵律和语音风格之间的联系,摆脱推断时对参考音频的依赖。不用参考音频其实对产品的意义还挺大的,要都用现实世界中真人尤其是名人的声音作为参考音频,那这势必会引起版权纠纷。这种纠纷在国内国外都有相关的事件。最近寡姐投诉OpenAI的事件就是一例。

图23:StyleTTS 1的模型架构

图24:StyleTTS 2的模型架构

TTS对指令的遵循

SLM不仅要合成合乎上下文语义的高表现力语音,合成的语音还要符合用户的即时要求。一些text-guided zero-shot TTS的工作值得参考。这些工作一般都是在已有的zero-shot TTS模型或者text-to-audio模型上改造而来,同时吸收transcription和description两路条件。其中的重点还是在于数据集的构建。这方面的工作有:PromptTTS[33]、InstructTTS[34]、ParlerTTS[19]、VoiceLDM[18]和Audiobox[35]。我们主要谈谈ParlerTTS和VoiceLDM。

ParlerTTS[19]。VALL-E/VoiceCraft的增强版,通过T5编码器和cross-attention旁路引入了描述性文本的信息。该工作的目的是想使用自然语言prompt来指定说话风格和环境信息,摆脱对参考音频的依赖。描述性标签文本的收集过程也显得相当朴素:通过定制化的监督式模型获取语音数据的口音特征、录音质量特征、音高语速特征。然后用LLM将这些特征转换为自然语言的描述。在我看来,这个工作有这么几点局限性吧:其一,缺乏情绪标签;其二,语音描述性标签的收集并不具备通用性,较为繁琐,远不如一个强大的多模态语音理解模型来得实在。文章demo虽然达到了预期的效果,但场景似乎局限在朗读的情景中。

图25:ParlerTTS的模型架构

VoiceLDM[18]。在VoiceLDM1的基础上增加了转录文本的输入。这个工作和AudioLDM 1很像,同样使用CLAP注入语音的描述性信息。不同地是,为了做TTS任务,该工作通过cross-attention旁路增加了transcription的信息。

图26:VoiceLDM的模型架构

TTS总结

林林总总说了这么多zero-shot的TTS方法,我想说明的结论有这么几点:

  1. 在LLM大行其道、scaling law大显神威的时代,TTS模型的训练数据规模已经突破了万小时,甚至达到了数十万小时的级别。在大数据的加持下,TTS任务上也涌现出了in-context learning能力。
  2. 语音信息的解码通常都要层次化或者多步进行,不能一步到位。自回归、扩散模型和流匹配都能在TTS中发挥作用;
  3. 借鉴NLP instruction fine-tuning和文生图的经验,TTS模型同样可以遵循文本指令或者语音指令,合成符合用户即时要求的语音,摆脱对参考音频的依赖,这或许也能规避一些知识产权的困扰(比如最近有名的寡姐投诉OpenAI事件)。同时,用户也能在对话过程中随时切换语音回复的风格,这一点在OpenAI的demo中有很明确的体现。另外,不知道大家有没有注意,GPT-4o合成的语音是可以是放映所处的声学环境的:有一段语音背后似乎是有钢琴声的。
  4. text-guided zero-shot TTS在模型架构上和zero-shot TTS有非常大的相似性。但训练数据可能较为缺乏。先开发zero-shot TTS,再用类似SALMONN那样的多模态理解模型来打标签(类似DALLE3的做法),这样数据集构造方式,可能会是更好的选择。

另外,对于语音的解码方案,我倾向于是这样的:

  1. 如果要做流式推理,外接类似HIFIGAN这样的声码器的方式可能不是好的选择。HIFIGAN并不天然支持流式解码。相反地,诸如SoundStream和Encodec这样的方法,同时有流式变体和非流式变体;
  2. 先做语义token的解码,这个解码大概率是自回归解码。语义token毕竟是建模上下文依赖关系,自回归方法已经在NLP上证明了这一点;
  3. 然后做声学token的解码,扩散或者flow-matching可能是更好的选择。扩散模型或者流匹配可以很好地修补语音的细节;

当然,除了上面讲到的,zero-shot TTS还有很多值得研究的方法。限于篇幅,仅列举于此,不再详述:HierSpeech++[36]、base TTS[30]、Voicebox/Audiobox[35]、UniAudio[37]、Make-a-Voice[38]等等。

其他问题

对于GPT-4o模型,如果仅仅聚焦于语音多模态,还有下面的问题值得关注:

  1. 语音交互如何做到低延迟?大概率要求流式切片处理,主要工作在于工程优化,用C++重写算子。推理框架的话,用tensorrt、mnn这些都行。上下文所述的音频离散化方法,诸如SoundStream和Encodec,其实也支持流式处理。
  2. 语音对话中的打断如何实现?个人认为有两种可能的方案:turn-based和流式处理。所谓的turn-based方案,是比较工程化的,简答概括一下就是:检测是否有停顿,如果一段时间内没有声音,模型就开始返回语音回复。另一种流式方案,则是:模型一直在接受用户的流式语音输入,判断是否应该输出语音回复,一个充分训练的模型应该是能够准确预测出语音词表中的[START]和[END]的。

对游戏配音业务的思考

text/prompt-guided zero-shot TTS方法对游戏的AI配音意义重大。主要体现在:

  1. 用自然语言提示去合成音色稳定的语音,摆脱对参考音频的依赖,在业务中能够更加灵活,至少比克隆已有人物/角色的语音的方式更加方便,更不容易出戏。举个例子,在开放世界剧情类游戏的研发阶段,我们会设定一些profile赋予NPC,让玩家跟NPC聊天。我们曾经用克隆《原神》、《崩坏:星穹铁道》已有角色的方式赋予这些NPC角色语音,但放在那些欧美背景的NPC中,就是很有违和感,没有现实世界中的accent,不够decent。
  2. 剧情任务中的配音会更加真人化、更有沉浸感。过年期间过《崩坏:星穹铁道》花火和黑天鹅的同行任务的时候,部分NPC角色会有六公主的翻译腔,这是花火行于欢愉命途的恶趣味,空气中顿时充满了快活的味道。如果走bv2、gsv的语音克隆方案,应该是很难有这种效果的。而且,玩家在剧情任务中势必会经过不同的地势地貌,至少室内、室外的声音听起来是有不同的。室内的声音至少会有回响、混响的吧。这种感觉语音克隆方案也是无法做到的。

全文总结

总结一下本文说谈的内容,我认为GPT-4o语音多模态的实现可能是走了以下的技术路线:

  1. audio & text tokenizer的实现应该是语音离散化部分所用的技术,例如SoundStream、Encodec、SpeechTokenizer,或者是MEL+VQ最后配合声码器来解码;参考zero-shot TTS、AudioLM/AudioPaLM、SpeechGPT-Gen等工作的结果,LLM中语音token的解码应该是要走层次化或者多步的方法,先解码语义特征,再解码声学特征,或者是先解码MEL,再加一个HIFIGAN这样的声码器。另外,如果做audio/speech/music这样的通用声合成的话,可能也能通过prompt来控制。AudioLDM2虽然做了这方面的工作,但audio/music和speech的参数其实是不一样的,说到底还不是同一个模型。
  2. 对于指令微调,数据集的构造非常重要,大概率要用到合成数据。其一,网络上高质量语音数据的量级远远不及文本,直接拿ASR数据来做肯定会影响模型合成语音的音质;其二,大语言模型合成的instruction往往触及不到语音的细粒度特征,这样的instruction其实无法准确详尽地描述text和speech之间的关系。因而,需要引入强大的zero-shot TTS模型合成高质量语音,然后用多模态语音理解模型来为合成语音打标签,当然也可以评分做筛选什么的。
  3. 最后是要让大模型的输出对齐人类的偏好。这方面的方法有很多,有DPO、PPO什么的,都可以用。

图27:全文总结,可能的roadmap

参考文献

[1] Baevski A, Zhou Y, Mohamed A, et al. wav2vec 2.0: A framework for self-supervised learning of speech representations[J]. Advances in neural information processing systems, 2020, 33: 12449-12460.

[2] Hsu W N, Bolte B, Tsai Y H H, et al. Hubert: Self-supervised speech representation learning by masked prediction of hidden units[J]. IEEE/ACM Transactions on Audio, Speech, and Language Processing, 2021, 29: 3451-3460.

[3] Chung Y A, Zhang Y, Han W, et al. W2v-bert: Combining contrastive learning and masked language modeling for self-supervised speech pre-training[C]//2021 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU). IEEE, 2021: 244-250.

[4] Van Den Oord A, Vinyals O. Neural discrete representation learning[J]. Advances in neural information processing systems, 2017, 30.

[5] Zeghidour N, Luebs A, Omran A, et al. Soundstream: An end-to-end neural audio codec[J]. IEEE/ACM Transactions on Audio, Speech, and Language Processing, 2021, 30: 495-507.

[6] Défossez A, Copet J, Synnaeve G, et al. High fidelity neural audio compression[J]. arXiv preprint arXiv:2210.13438, 2022.

[7] Zhang X, Zhang D, Li S, et al. Speechtokenizer: Unified speech tokenizer for speech large language models[J]. arXiv preprint arXiv:2308.16692, 2023.

[8] Borsos Z, Marinier R, Vincent D, et al. Audiolm: a language modeling approach to audio generation[J]. IEEE/ACM Transactions on Audio, Speech, and Language Processing, 2023.

[9] Rubenstein P K, Asawaroengchai C, Nguyen D D, et al. Audiopalm: A large language model that can speak and listen[J]. arXiv preprint arXiv:2306.12925, 2023.

[10] Changli Tang, Wenyi Yu, Guangzhi Sun, Xianzhao Chen, Tian Tan, Wei Li, Lu Lu, Zejun Ma, Chao Zhang. SALMONN: Towards Generic Hearing Abilities for Large Language Models

[11] Zhang D, Li S, Zhang X, et al. Speechgpt: Empowering large language models with intrinsic cross-modal conversational abilities[J]. arXiv preprint arXiv:2305.11000, 2023.

[12] Zhang D, Zhang X, Zhan J, et al. SpeechGPT-Gen: Scaling Chain-of-Information Speech Generation[J]. arXiv preprint arXiv:2401.13527, 2024.

[13] Zhang D, Li Z, Li S, et al. SpeechAlign: Aligning Speech Generation to Human Preferences[J]. arXiv preprint arXiv:2404.05600, 2024.

[14] Chen Q, Chu Y, Gao Z, et al. Lauragpt: Listen, attend, understand, and regenerate audio with gpt[J]. arXiv preprint arXiv:2310.04673, 2023.

[15] Wu S, Fei H, Qu L, et al. Next-gpt: Any-to-any multimodal llm[J]. arXiv preprint arXiv:2309.05519, 2023.

[16] Wang C, Chen S, Wu Y, et al. Neural codec language models are zero-shot text to speech synthesizers[J]. arXiv preprint arXiv:2301.02111, 2023.

[17] Anil R, Dai A M, Firat O, et al. Palm 2 technical report[J]. arXiv preprint arXiv:2305.10403, 2023.

[18] Lee Y, Yeon I, Nam J, et al. VoiceLDM: Text-to-Speech with Environmental Context[C]//ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2024: 12566-12571.

[19] Lyth D, King S. Natural language guidance of high-fidelity text-to-speech with synthetic annotations[J]. arXiv preprint arXiv:2402.01912, 2024.

[20] Betker J. Better speech synthesis through scaling[J]. arXiv preprint arXiv:2305.07243, 2023.

[21] Xin D, Tan X, Shen K, et al. RALL-E: Robust Codec Language Modeling with Chain-of-Thought Prompting for Text-to-Speech Synthesis[J]. arXiv preprint arXiv:2404.03204, 2024.

[22] Wang C, Zeng C, Zhang B, et al. HAM-TTS: Hierarchical Acoustic Modeling for Token-Based Zero-Shot Text-to-Speech with Model and Data Scaling[J]. arXiv preprint arXiv:2403.05989, 2024.

[23] Ren Y, Hu C, Tan X, et al. Fastspeech 2: Fast and high-quality end-to-end text to speech[J]. arXiv preprint arXiv:2006.04558, 2020.

[24] Rombach R, Blattmann A, Lorenz D, et al. High-resolution image synthesis with latent diffusion models[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 10684-10695.

[25] Shen K, Ju Z, Tan X, et al. Naturalspeech 2: Latent diffusion models are natural and zero-shot speech and singing synthesizers[J]. arXiv preprint arXiv:2304.09116, 2023.

[26] Ju Z, Wang Y, Shen K, et al. NaturalSpeech 3: Zero-shot speech synthesis with factorized codec and diffusion models[J]. arXiv preprint arXiv:2403.03100, 2024.

[27] Liu H, Tian Q, Yuan Y, et al. AudioLDM 2: Learning holistic audio generation with self-supervised pretraining[J]. arXiv preprint arXiv:2308.05734, 2023.

[28] Jiang Z, Ren Y, Ye Z, et al. Mega-tts: Zero-shot text-to-speech at scale with intrinsic inductive bias[J]. arXiv preprint arXiv:2306.03509, 2023.

[29] Jiang Z, Liu J, Ren Y, et al. Mega-tts 2: Zero-shot text-to-speech with arbitrary length speech prompts[J]. arXiv preprint arXiv:2307.07218, 2023.

[30] Łajszczak M, Cámbara G, Li Y, et al. BASE TTS: Lessons from building a billion-parameter text-to-speech model on 100K hours of data[J]. arXiv preprint arXiv:2402.08093, 2024.

[31] Li Y A, Han C, Mesgarani N. Styletts: A style-based generative model for natural and diverse text-to-speech synthesis[J]. arXiv preprint arXiv:2205.15439, 2022.

[32] Li Y A, Han C, Raghavan V, et al. Styletts 2: Towards human-level text-to-speech through style diffusion and adversarial training with large speech language models[J]. Advances in Neural Information Processing Systems, 2024, 36.

[33] Guo Z, Leng Y, Wu Y, et al. Prompttts: Controllable text-to-speech with text descriptions[C]//ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2023: 1-5.

[34] Yang D, Liu S, Huang R, et al. Instructtts: Modelling expressive TTS in discrete latent space with natural language style prompt[J]. arXiv preprint arXiv:2301.13662, 2023.

[35] Vyas A, Shi B, Le M, et al. Audiobox: Unified audio generation with natural language prompts[J]. arXiv preprint arXiv:2312.15821, 2023.

[36] Lee S H, Choi H Y, Kim S B, et al. HierSpeech++: Bridging the Gap between Semantic and Acoustic Representation of Speech by Hierarchical Variational Inference for Zero-shot Speech Synthesis[J]. arXiv preprint arXiv:2311.12454, 2023.

[37] Yang D, Tian J, Tan X, et al. Uniaudio: An audio foundation model toward universal audio generation[J]. arXiv preprint arXiv:2310.00704, 2023.

[38] Huang R, Zhang C, Wang Y, et al. Make-a-voice: Unified voice synthesis with discrete representation[J]. arXiv preprint arXiv:2305.19269, 2023.