DPO为什么会让大语言模型输出变长

摘自:https://zhuanlan.zhihu.com/p/5830338806

总的来说,DPO让模型输出变长主要可以分为以下几个原因:

  1. RM和模型评测的长度偏好。不管是Reward Model还是当前用与评测的模型(即便是GPT4)都会存在比较明显的长度偏好,即倾向于给更长的回答一个更高的分数。这一点已经有非常多工作给出过分析了。
  2. 训练数据本身长度分布不均衡。实战过程中往往就是用RM进行排序构造训练数据,RM的长度偏好就是会导致训练数据中容易出现chosen比rejected更长的情况。训练数据的长度差异(chosen比rejected长)就会导致训练后模型输出变长。
  3. 数据长度差异导致的reward被高估或低估。《Eliminating Biased Length Reliance of Direct Preference Optimization via Down-Sampled KL Divergence》中发现,DPO的算法本身也存在对response长度的依赖,chosen和rejected之间的长度差异可能会导致reward被高估/低估(overestimated or underestimated rewards)。即,当chosen过短时,reward会被低估,而当chosen过长时,reward会被高估
  4. DPO算法本身的长度敏感性。《Length Desensitization in Direct Preference Optimization》中提到,response长度会影响到似然概率的大小,并且进一步影响到训练优化方向:当chosen更长时,DPO会往chosen的方向进行优化(增大chosen概率),从而使输出变长;而rejected更长时,DPO会往远离rejected的方向优化(降低rejected概率),但却未必会让输出变短。

如何解决:

  1. RM的优化:前面讲的都是对DPO进行长度控制的工作,但对RM本身的长度偏好进行优化的工作没有看到太多,如果大家有看到相关的也可以在评论区提供一下。如果将RM本身的长度偏好问题解决的话,那就可以极大程度上解决训练数据的长度分布均衡问题了。
  2. 数据的优化:有些工作会在数据构造时对长度进行综合考虑,如对RM打分进行长度归一后再排序、采样多个答案进行排序时根据均值方差限制chosen的长度等,通过这些方式可以减少长度差距过大的情况。如果数据本身的长度分布均衡了,也能一定程度上减缓这种问题。
  3. 训练算法上的优化:如果从LD-DPO的分析上看,即便数据分布比较均衡,只要存在长度差异,DPO本身的长度敏感性就是会导致模型输出变长,因此可能还是需要一些算法层面的优化,比如在DPO阶段加入SFTloss就是一种简单有效的方法,在很多公开的大模型技术报告中也都有用到该方法。另外R-DPO、SamPO和LD-DPO的长度控制效果都算是比较好的方法。

DPO面临的一个问题(准确来讲是一种现象)就是会让大模型的输出变长,且多轮DPO的话会让模型输出越来越长。本篇文章我们将结合搜集到的一些相关工作,探讨一下业界对该现象的一些分析,探究这一现象产生的根本原因,以及如何有效地解决。

首先我们需要思考一个问题,模型输出变长到底是不是一件坏事?一般来说,输出变长可能会使内容更加详细,信息量更丰富,回复质量更高,用户体验更好。但如果过度长,输出了很多冗余信息,回复质量没有明显改善,反而带来了推理成本的增加,回复变得啰嗦,用户体验反而变差了。

因此,无论是从用户体验的角度还是多轮DPO能否run下去的角度,做好长度控制都是有必要的。

相关工作

先简要介绍一些相关工作,然后后面详细总结。

1.《Disentangling Length from Quality in Direct Preference Optimization》(简称R-DPO)

在这之前的一些RL的工作也有分析过长度爆炸问题,但该文章可能是第一个提出DPO的长度爆炸问题的。

文章中发现,无论是RL训练中使用的Reward Model还是用来评测模型效果的打分模型(如GPT-4)都表现出明显的长度偏好,即会给更长的答案一个更高的分数(如下图)。且在一些公开的DPO训练数据集中,chosen的长度往往会比rejected更长,而这可能就是DPO后的模型输出长度明显比SFT模型更长的原因

为了解决这个问题,该文章提出了一种长度正则化的策略,即在计算loss的时候添加一个长度正则项,从而避免模型对长度的过度拟合,公式如下:

其中 |yw| 表示chosen的长度, |yl| 表示rejected的长度,从公式中可以看出,当chosen与rejected的长度差距越大,正则项的值越大,从而实现对长度的“惩罚”效果。

从文章中的实验结果可以看出,该方法确实可以在尽可能减少性能损失的前提下有效解决长度增长问题。(有时还是会损失一定的性能。)

2.《SimPO: Simple Preference Optimization with a Reference-Free Reward》(简称SimPO)

陈丹琦团队的工作,直接去掉了reference model,用长度归一的方式实现长度控制。其loss如下:

文章中提到了很多输出长度相关的内容,但核心贡献并不是做长度控制,而是用一种更简单高效的方法实现偏好训练。从公式上看,和原始DPOloss相比主要有两处不同,一个是分母从reference model的logp替换成了长度,另外就是增加了一个 γ ,类似一个offset的作用。不过其中对chosen和rejected的reward做长度归一的部分,直觉上看起来应该是能起到一定的长度控制效果的。

不过从论文中的实验结果看,该方法的效果还是比较好的(当时声称训出最强8B大模型),但与标准DPO相比似乎并没有实现长度控制的效果。

3.《Eliminating Biased Length Reliance of Direct Preference Optimization via Down-Sampled KL Divergence》(简称SamPO

这篇论文对DPO后长度变长的问题进行了一定的分析,提出的一个核心观点是:DPO的算法本身也存在对response长度的依赖,chosen和rejected之间的长度差异可能会导致reward被高估/低估(overestimated or underestimated rewards)。即,当chosen过短时,reward会被低估,而当chosen过长时,reward会被高估。

这篇工作中提出的一种方式就是在token级别下采样的概率特征,以计算正则化的KL散度,从而减少因pair长度不同而导致的奖励偏差。其loss的计算如下:

从公式可以看出,该方法的核心就是在计算reward的时候不再是全部token的条件概率的累乘(取log后就是累加),而是随机采样公共数量的token进行累乘。这样即便chosen和rejected长度不同,参与reward计算的token数是一样的。也就是说,在SamPO训练过程中,魔都看到的chosen和rejected相当于是完全等长的。

从文章中的实验结果看,该方法确实能有效控制模型输出长度的增长,甚至在多轮DPO依然能有效控制长度。但是在性能上看依然做不到碾压标准DPO的效果。

但该方法有两个风险便是:

  1. 本身DPO就存在一定的波动,随机下采样可能会导致训练的稳定性不强;
  2. 随机采样必然会导致一些信息缺失,如果采样时舍弃掉了一些非常重要的token可能会影响到训练效果。

4.《Length Desensitization in Direct Preference Optimization》(简称LD-DPO)

该论文可能是第一个从理论层面分析DPO后模型输出变长的原因的,其核心分析主要包括两方面:

  1. DPO的梯度优化方向和chosen/rejected的似然概率成反比。
  2. Response长度对似然概率的影响极大,因此长度会直接影响reward的计算,并影响到DPO的优化方向。

上图是一个对训练数据的统计热力图,图中,横坐标为chosen的长度,纵坐标为rejected的长度,颜色深度表示 log⁡πθ(yl|x)−log⁡πθ(yw|x) 值的大小。第一张图(a)是标准DPO,可以看出长度差距越大时,颜色越深,也就说明长度差距可能会导致reward计算产生bias,且长度差距越大这种bias越大。而这种bias会进一步影响到DPO的优化方向,使其往输出更长的方向进行优化。

该文章提出的解决方案是在计算似然概率时对长度进行解耦,将更长的答案拆成“公共长度部分”和“额外部分”,并进一步将后者拆分为真实偏好和冗余偏好,并对其中的冗余部分进行降权操作,通过一系列推导后将 πθ(y|x) 转化为如下的形式(可近似理解为完整似然概率部分与公共长度部分似然概率的加权和):

从公式上看,这种方式可以让长度更长的那个response(不管是chosen还是rejected)实现一定的缩放(),减少长度带来的似然概率的断崖式下滑,使其与另一个短response(不受影响)之间更具可比性,同时又不会像SamPO那样完全舍弃掉额外部分的信息。

从论文中的实验结果看,这种方法能够实现比较好的长度控制,且模型性能还能有一定提升,并且可以通过调整参数 α 可以实现不同程度的控制效果。另外文章还提出一个比较有意思的发现,就是过度冗余的回答可能反而会损害模型的推理能力,他们通过这种方法控制长度后,模型的推理能力也有明显提升。

其他工作

除此之外,还有一些工作直接在数据上做文章,通过控制chosen和rejected的长度差距来实现长度控制,如《Following Length Constraints in Instructions》(简称LIFT-DPO)。以及在一些开源模型的技术报告中我们也能看到一些相关的长度控制方法,如在利用RM打分排序时就综合考虑长度问题等,这些数据工作就不再详细展开了。

如何实现有效的长度控制?

  1. RM的优化:前面讲的都是对DPO进行长度控制的工作,但对RM本身的长度偏好进行优化的工作没有看到太多,如果大家有看到相关的也可以在评论区提供一下。如果将RM本身的长度偏好问题解决的话,那就可以极大程度上解决训练数据的长度分布均衡问题了。
  2. 数据的优化:有些工作会在数据构造时对长度进行综合考虑,如对RM打分进行长度归一后再排序、采样多个答案进行排序时根据均值方差限制chosen的长度等,通过这些方式可以减少长度差距过大的情况。如果数据本身的长度分布均衡了,也能一定程度上减缓这种问题。
  3. 训练算法上的优化:如果从LD-DPO的分析上看,即便数据分布比较均衡,只要存在长度差异,DPO本身的长度敏感性就是会导致模型输出变长,因此可能还是需要一些算法层面的优化,比如在DPO阶段加入SFTloss就是一种简单有效的方法,在很多公开的大模型技术报告中也都有用到该方法。另外R-DPO、SamPO和LD-DPO的长度控制效果都算是比较好的方法。

最后结合我自己的一些尝试来直接对比一下上面的四种方法:

  1. R-DPO是通过加正则项的方式实现长度控制,说是正则项,但其实只是一个常数,其原理相当于是对每条数据加上一个权重(文章中也提到了这点),即当chosen和rejected长度差距大时降低该数据的权重。也就是说,该方法其实是让模型减少对长度差距大的数据的学习权重。这种方法确实可以实现一定的长度控制效果,但必然会减少一些数据的利用率,这可能也是训练效果会有一定损失的原因。我自己尝试了一下该方案,实验下来确实可以做到长度控制效果,但大部分情况下性能都会比标准DPO差一些。
  2. SimPO是用长度归一来替换Reference Model的KL约束,理论上和长度控制其实没有太大关系,更多的是简化训练和提升性能。实验结果确实也体现了并不会比标准DPO更短。(该方法热度很高,但网络上褒贬不一,很多人表示无法复现结果。)根据我自己实验经验来看,跑出好的结果需要仔细调参,论文推荐的超参不一定适合所有情况。
  3. SamPO是直接用下采样的方式,强行将模型视角下的长答案变得和短答案一样长,该方法给人的直观感受就是长度控制效果肯定很好,但是很可能会有性能损失。但我自己实验下来,长度控制效果和R-DPO差不多,但是性能也比较不稳定,更换随机种子就会导致性能产生波动。我也尝试过将随机下采样改为top-k采样,即保留概率最大的top-k个token,但效果并不会比随机更好(这么直觉的方法可能论文作者也尝试过了)。
  4. LD-DPO的方法是只对答案过长的部分做了解耦和降权处理,通过降低过长部分的权重来实现整个条件概率的缩放,看起来是四种方法中实现最优雅的一种,既降低了长度差异带来的reward bias问题,又不会丢弃信息,相当于是用极小的代价实现了概率缩放目的。从论文中贴出的结果看,确实也是性能最强的一个,长度控制效果也是最好的。但论文代码没有开源,所以没有实验验证。但从公式上看复现难度应该不是很大,有能力的可以尝试复现一下看看效果。

transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读

这个类对外提供的方法是 generate(),通过调参能完成以下事情:

  • greedy decoding:当 num_beams=1 而且 do_sample=False 时,调用 greedy_search()方法,每个step生成条件概率最高的词,因此生成单条文本。
  • multinomial sampling:当 num_beams=1 且 do_sample=True 时,调用 sample() 方法,对词表做一个采样,而不是选条件概率最高的词,增加多样性。
  • beam-search decoding:当 num_beams>1 且 do_sample=False 时,调用 beam_search() 方法,做一个 num_beams 的柱搜索,每次都是贪婪选择top N个柱。
  • beam-search multinomial sampling:当 num_beams>1 且 do_sample=True 时,调用 beam_sample() 方法,相当于每次不再是贪婪选择top N个柱,而是加了一些采样。
  • diverse beam-search decoding:当 num_beams>1 且 num_beam_groups>1 时,调用 group_beam_search() 方法。
  • constrained beam-search decoding:当 constraints!=None 或者 force_words_ids!=None,实现可控文本生成。

参数列表

核心代码详见:generate()入口函数定义, GenerationConfig类

1.控制生成长度的参数

参数类型缺省值含义
max_lengthint20表示 prompt + max_new_tokens 累加的最大长度,如果max_new_tokens也设置了,会覆盖这个参数
max_new_tokensint生成部分的tokens的最大长度 (忽略prompt部分的长度)
min_length0表示 prompt + min_new_tokens 累加的最小长度,如果min_new_tokens也设置了,会覆盖这个参数
min_new_tokensint生成部分的tokens的最小长度 (忽略prompt部分的长度)
early_stoppingbool, strFalse对于beam search方法的控制终止的配置。
False: 当有’num_beams’个候选生成,则终止
True: 应用一些启发式规则判断不能找到更好的生成候选,来提前终止生成
“never”: 当判断没有更好的可生成的candidate, beam search 过程终止
max_timefloat执行生成的最大时间(s秒数)
stop_stringsstr, array[str]配置模型生成的终止字符串,当模型生成参数配置的字符串,则终止生成。

2. 控制生成策略的参数

参数类型缺省值含义
do_sampleboolFalseTrue: 生成过程使用采样逻辑
False: 使用greedy做生成
num_beamsint1设置beam search 束的数量。如果是1不做beam search 搜索
num_beam_groupsint1为了保证生成的多样性,将num_beams 设置成多组。参考文献: https://arxiv.org/pdf/1610.02424.pdf
penalty_alphafloatcontrastive search decoding的配置项,用于平衡生成置信度和衰减的惩罚
dola_layersstr, List[int]str :
“None”: 不使用dola
“low” : 较低的一半layers, 最多20层使用dola
“high”: 较高的一半layers, 最多20层使用dola
List[int] : 通过指定一个index数组,指定dola 层
“low”: 提升长答案的task,
“high”:提升短答案的task

3.cache配置参数

参数类型缺省值含义
use_cacheboolTrue是否使用KV cache 加速推理速度
cache_implementationstr指定cache实现的name,在调用generate()时,实例化cache。
”static”: [StaticCache]
“offloaded_static”: [OffloadedStaticCache]
”sliding_window”: [SlidingWindowCache]
“hybrid”: [HybridCache]
“mamba”: [MambaCache]
”quantized”:[QuantizedCache]
cache_configCacheConfig , dictNonecache类使用的参数
return_legacy_cacheboolTrue当DynamicCache 被使用时,是否返回历史的和新格式的cache

4.操作模型输出logit的配置参数

参数类型缺省值含义
temperaturefloat1.0这个值用于建模下一个token的概率, 这个值被设置在generation_config.json文件中
top_kint50筛选最高概率的top k个词, 这个值被设置在generation_config.json文件中
top_pfloat1.0当设置<1时,筛选概率最高的token,累加概率不超过top_p的token
min_pfloat配置筛选概率最低的一批token, 累加概率不超过min_p,裁剪掉,该配置相当于top_p的反向操作
typical_pfloat1.0测量两个分布的相似性: 预测下一个目标token的概率 and 预测下一个随机Token的条件概率期望。如果设置<1,则筛选最典型的token。
epsilon_cutofffloat0.0按设置的值,卡掉低概率值的token,一般设置为:3e-4 to 9e-4
eta_cutofffloat0.0混合局部典型性采样和epsilon采样方法
diversity_penaltyfloat0.0只对group beam search方法生效,如果在某个特定时间生成的token与任何beam 组生成的token一致,则beam的score减去这个值
repetition_penaltyfloat1.01.0 默认不惩罚
encoder_repetition_penaltyfloat1.0对于不在原始输入的token,指数级的惩罚
length_penaltyfloat1.0对于beam 类的生成方法的长度惩罚,由于序列score是 log likelihood , > 0 倾向于更长的 <0 倾向于更短的
no_repeat_ngram_sizeint0如果大于0, 则对应的size的ngram只能出现1次
bad_words_idsList[List[int]]列出不允许生成的tokens_id
force_words_idsList[List[int]] or List[List[List[int]]]必须被生成的words_ids。 如果配置List[List[List[int]]] 设置对于每个token的约束
renormalize_logitsboolFalse对于所有的logits做后处理后,是否要再做下normalize
constraintsList[Constraint]通过定义一个List[Constraint] 对象数组,来确保输出是在某些限制的场景下。一般用于安全的场景
forced_bos_token_idintmodel.config.forced_bos_token_id强制跟在decoder_start_token_id之后的第一个token,对多语言模型是有用的
forced_eos_token_idint or List[int]model.config.forced_eos_token_id当生成的token达到max_length上限时,最后一位输出的token
remove_invalid_valuesboolmodel.config.remove_invalid_values是否移出可能生成的nan and inf 值,配置这个会减慢生成速度
exponential_decay_length_penaltytuple(int, float)指数级增加长度的惩罚,tuple(start_index, decay_factor) start index 指示惩罚的开始i,decay_factor 指数衰减的惩罚因子
suppress_tokensList[int]通过设置禁止的token的logit为-inf,来禁止token被sample
begin_suppress_tokensList[int]通过设置首位禁止的token的logit为-inf,来禁止首位这部分token被采样到,进而导致被生成
forced_decoder_idsList[List[int]]一个整数pair的数组,格式[生成index, token_index]指示固定位置强制生成某个token,例如[[1, 123]] 第二个位置总是生成token 123
sequence_biasDict[Tuple[int], float]token list -> bias的映射,正的bias提升几率,负的bias降低几率
token_healingboolFalse对prompt尾部的token做相似替换,以提升生成质量
guidance_scalefloat是一个缩放因子,当>1时,这个因子越高,越鼓励模型生成与prompt接近的samples 。
watermarking_configBaseWatermarkingConfig or dict对输出结果增加水印

5.输出结果配置参数

参数类型缺省值含义
num_return_sequencesint1对于batch中的每个元素,设置独立计算的返回的sequence的数量
output_attentionsboolFalse是否返回所有的attention的向量
output_hidden_statesboolFalse是否返回所有网络层的隐层状态
output_scoresboolFalse是否返回prediction scores
output_logitsbool是否返回未处理过的的logit score
return_dict_in_generateboolFalse除了返回生成序列,是否还返回a [`~utils.ModelOutput`]

6.生成时使用的特殊token的配置参数

参数类型缺省值含义
pad_token_idintpadding token ID
bos_token_idintbeginning -of – sequence token ID
eos_token_idUnion[int, List[int]]end-of-sequence token ID

6.辅助生成的配置参数(投机采样)

参数类型缺省值含义
is_assistantboolFalse指定是否模型是一个assistant(draft) model
num_assistant_tokensint20投机采样过程,每次迭代 assistant model 要输出多少个token,给到目标模型做check。配置更高的值,如果assistant model 效果好 能带来更好的加速比
num_assistant_tokens_schedulestrconstant“heuristic” : 当所有投机采样的token都正确时,将num_assistant_tokens增加2,否则减少1。
“constant”: num_assistant_tokens 保持固定不变
“heuristic_transient”: 类似于启发式方法,每次生成调用,都置成初始化的num_assistant_tokens值
assistant_confidence_thresholdfloat0.4当assistant model预估当前token的置信度 小于 阈值时,提前终止assistant model的生成
prompt_lookup_num_tokensint作为候选token 要输出的token的数量
max_matching_ngram_sizeint2match prompt的最大ngram的数量
assistant_early_exitint
assistant_lookbehindint10如果设置为正整数,则重新编码过程将额外考虑最后的assistant_lookbehind个辅助标记,以正确对齐标记。此设置仅可在推测解码中使用不同的分词器时使用。
target_lookbehindint10如果设置为正整数,则重新编码过程将额外考虑最后的target_lookbehind个辅助标记,以正确对齐标记。此设置仅可在推测解码中使用不同的分词器时使用。


如有整理错误,欢迎指正~

语音理解模型—OSUM

OSUM: Advancing Open Speech Understanding Models with Limited Resources in Academia

大型语言模型(LLMs)在各种下游任务中取得了显著进展,启发了业界对语音理解语言模型(speech understanding language models, SULMs)的研发,以期实现基于语音情感、性别等副语言的高表现力交互。然而,大多数先进的SULM是由行业头部公司开发的,消耗大规模的数据和计算资源。而这些资源在学术界并不容易获得。此外,虽然训练好的模型和推理代码被开源了,但训练框架和数据处理流程依然缺乏透明度,这也为进一步研究产生了障碍。在本研究中,我们提出了OSUM,一个开放的语音理解模型,旨在探索在有限的学术资源下训练SLUM的潜力。OSUM模型将Whisper编码器与Qwen2 LLM相结合,支持广泛的语音任务,包括语音识别(ASR)、带时间戳的语音识别(SRWT)、语音事件检测(VED)、语音情感识别(SER)、说话风格识别(SSR)、说话者性别分类(SGC)、说话者年龄预测(SAP)和语音转文本聊天(STTC)。通过采用ASR+X训练策略,OSUM通过同时优化模态对齐和目标任务,实现了高效稳定的多任务训练。除了提供强大的性能,OSUM还强调透明度,提供公开可用的代码,并详细介绍了数据处理流程,以期为学术界提供有价值的参考,旨在加速先进SULM技术的研究和创新。

方案设计 

OSUM模型将Whisper编码器与Qwen2 LLM相结合,支持广泛的语音任务,包括语音识别(ASR)、带时间戳的语音识别(SRWT)、语音事件检测(VED)、语音情感识别(SER)、说话风格识别(SSR)、说话者性别分类(SGC)、说话者年龄预测(SAP)和语音转文本聊天(STTC)。通过采用ASR+X训练策略,OSUM通过同时优化模态对齐和目标任务,实现了高效稳定的多任务训练。

模型结构

模型的输入包括语音和自然语言提示。不同于 Whisper 和Qwen-Audio 依靠指令标签,Osum采用描述性文本,将所有八个支持任务转换为图2所示。当前,我们的模型仅支持基于文本的响应,但是音频输出功能正在积极开发。

如图2所示,OSUM模型由一个Speech Encoder、一个Adaptor和一个LLM组成。在训练过程中,Speech Encoder和Adaptor中的所有参数都会更新,而大语言模型则使用LoRA方法进行微调。各部分具体配置如下:

  • Speech Encoder: Whisper-Medium (769M);
  • Adaptor: Conv1D * 3 + Transformer * 4,4倍下采样;
  • LLM: Qwen2-7B-Instruct带LoRA。LoRA hyperparameters-α, rank, and dropout ratio are set to 32, 8, and 0.1,

多任务监督训练

训练过程包括两个阶段:

首先,在没有LLM的情况下,对原始的Whisper模型进行多任务监督微调,多任务数据微调了 Whisper ,以确保OSUM模型的更快收敛。此外,此阶段使我们能够验证多任务数据的可靠性。具体来说,我们扩展了Whisper的指示标签,以适应更多的任务,每个前向推理仅执行一个任务。

其次,将微调后的Whisper编码器与Qwen2大语言模型相结合,构建出完整的OSUM系统,然后使用更大的数据集进行进一步的监督训练。

OSUM模型的输入包括一段语音和一个自然语言描述的prompt,而输出在现阶段仅支持文本回复,音频输出功能正在开发中。为节省计算资源,OSUM的多任务训练引入了一种“ASR+X”范式,即同时训练ASR任务和一个附加任务X。这在加速训练的同时,允许执行X任务时参考文本和声学两种特征,从而提升性能和训练稳定性。“ASR+X”范式是在LLM的自回归框架内通过调整预测标签来实现的,无需对模型架构或损失函数进行修改。执行不同的X任务是通过给LLM不同的自然语言prompt来实现的,每个任务有5个候选prompt,训练时随机选择一个。prompt的示例如表1所示。

训练数据

OSUM旨在使用多样化的语音数据集进行多任务训练,目标是构建一个能够在对话场景中全面理解输入语音的统一模型。多任务训练过程使各个任务能够从共享学习中获益,从而提升模型的整体性能。有关用于训练的数据集的详细信息见表2所示,本版本模型的训练数据规模大约为5万小时。

技术性能

总览

如图2所示,OSUM 模型和Qwen2-Audio 相比,在大多数任务中,尽管 OSUM 使用的计算资源和训练数据明显更少,但它的表现优于Qwen2-Audio。

图2 OSUM与Qwen2-Audio各项任务性能对比的雷达图。雷达图中每个模型各项任务的值是基于公开测试集和内部测试集的平均结果得出的

各项指标与性能演示

ASR(语音识别):如表4所示,OSUM在中文ASR上表现优越,具体地,在WenetSpeech test meeting、3个AISHELL-2子测试集以及4个内部使用的SpeechIO测试集上优于其他模型。OSUM在英语测试集上性能也可与SenseVoice-S相媲美。值得注意的是,这些结果是在使用少得多的训练数据的情况下取得的。此外,我们发现,即使在训练过程中未纳入中英混语料数据集,OSUM在识别中英混语音方面也展现出了令人惊讶的出色能力。

表4公开测试集和内部测试集上ASR任务的评估结果。加粗字体表示同一测试集中的最佳结果。所有内部测试结果均由我们自行推理得出

表45公开测试集和内部测试集上多任务的评估结果。每个测试集的最佳结果都用粗体突出显示。蓝色字体显示的结果以及内部测试集的结果,均是我们使用原始发布的模型自行推理得出的

SRWT(带时间戳的语音识别):如表5所示,OSUM模型在SRWT任务上的性能显著优于Whisper-Large-v3,相对优势达到了36.70%,并且也超过了Qwen-Audio。此外,OSUM的表现甚至略微超过了GMM-HMM模型,而后者在时间戳预测任务被广泛使用。另外,此功能不仅使得OSUM能够以端到端的方式预测时间戳,更重要的是,它引导OSUM模型理解了“时间”这一概念。在将来,我们将会利用这一能力继续开发更灵活的应用,例如判断音频中何时出现了语音事件,何时出现了说话人转换等。

VED(语音事件检测):我们首先在公开测试集ESC-50和VocalSound上评估OSUM的性能。ESC-50包含大量的非人声音频事件,我们将它们归类为“其他”。表45示的实验结果表明,OSUM可以成功地将这些非人声音频事件归类为“其他”。此外,在VocalSound数据集上的结果显示,OSUM与Qwen2-audio相比虽然存在一定差距,但也取得了超过80%的准确率。值得注意的是,为更加符合真实使用场景,我们的训练数据是语音和音频事件拼接而成,但公开测试集只有孤立的音频事件而没有说话语音。即便存在这一不匹配的情况,OSUM模型的在公开测试集上的结果也证明了其有效性和泛化性。与公开测试集不同,我们人工录制了同时包含语音和声学事件的内部测试集。表45结果表明,PANNs由于其仅为孤立音频事件检测而设计,在我们内部测试集中基本处于不可用状态。Qwen2-audio的表现相对较好,但也出现了性能下降。相比之下,OSUM模型在公开测试集和内部测试集上都取得了较为均衡的结果,展现出了更强的泛化能力。

SER(语音情感识别):如表45示,对于SER任务,使用公开数据集的实验中,OSUM在MER2023测试集上展现出了卓越的性能,超过了一些近期的公开基准模型。在MELD数据集上,OSUM的性能略低于SenseVoice-L模型,这很可能是因为后者在更大规模的语音情感数据集上进行了训练。此外,OSUM在内部测试集上的结果与EmoBox模型相当,显著优于其他对比方法。但是,我们也观察到,厌恶和恐惧这两种情感尤其难以识别,其归因于这两种情感的训练数据更加稀缺,也容易和其他情感混淆。

SSR(说话风格识别):表5中实验表明,OSUM所采用的声学-文本双模态风格分类方法的表现显著优于GLM-4-9B-Chat所采用的单文本模态方法,这充分证明了“ASR+X”策略的价值。现阶段OSUM能够区分八种风格:“新闻科普”,“恐怖故事”,“童话故事”,“客服”,“诗歌散文”,“有声书”,“日常口语”以及“其他”。我们详细分析了测试集上各类别的准确率,发现OSUM在对“新闻科普”、“有声书”、“童话故事”以及“客服”风格类别上表现出色;然而,在“诗歌散文”、“恐怖故事”类别上仍有提升空间。有趣的是,我们发现从实际测试的主观体验上来说,OSUM风格分类正确率是超过测试集的,总体来说可以让人满意。

SGC(说话者性别分类):在SGC公开测试集上的结果表明,OSUM在AISHELL-1测试集上达到了100%的准确率。这一结果在一定程度上表明该任务上存在说话人过拟合现象。此外,在Kaggle测试集上,我们的方法略优于Qwen2-Audio。但在我们的内部测试集上,OSUM的性能略低于Qwen2-Audio,但依然超过了95%。总之,OSUM在SGC任务上展现出了不错的性能,而且实测效果很少出现性别判断错误的情况。

SAP(说话者年龄预测):在SAP任务上,由于我们发现青少年和成年人的声学相似度非常高,这使得有效区分他们变得很复杂。因此,我们将年龄分为三类:儿童、成年人和老年人。尽管我们努力调试了prompt,但Qwen2-Audio在Kaggle测试集和我们的内部测试集上,年龄分类准确率都较低。这可能是因为这些模型对年龄的分类过于细致,从而影响了Qwen2-Audio模型的最终效果。表4中结果显示,OSUM在Kaggle测试集上显著优于Qwen2-Audio,达到了76.52%的准确率。在我们的内部测试集上OSUM分类准确率虽然略有下降,但仍然超过了Qwen2-Audio。这表明OSUM在不同的数据上表现出了很强的泛化能力。

STTC(语音转文本聊天):如表5所示,在STTC任务中,我们在所有测试集上都遵循了AirBench的评估协议。这包括提供音频查询的文本以及两个不同答案的文本,让基于文本的大语言模型(LLM)给出1到10的主观评分。这两个答案一个是真实回复,另一个是语音大语言模型(SULM)生成的答案。测试结果表明,在AirBench的官方speech子测试集上,OSUM的得分虽然低于Qwen2-Audio,但也处于一个合理范围。这主要是因为我们没有使用英语对话数据进行训练,目前的得分完全依赖于大语言模型自身的表现。反之,在我们内部的中文对话测试集上,OSUM的表现优于Qwen2-Audio,这充分证明了OSUM在中文对话任务上性能是不错的。总体而言,我们的OSUM模型在对话能力方面与Qwen2-Audio相当。

更多功能

OSUM理解大模型在将来会提供更多的功能,可作为通用语音打标工具使用。此外,我们正在开发的功能包括:

  1. 同时支持ASR+X和单X任务模式,在执行单X任务打标时推理速度更快。
  2. 同时输出ASR+X1+X2+..Xn的多任务打标模式,一次性提供几乎全部所需标签。
  3. 增加更多的理解任务。

Step-Audio:产品级开源实时语音对话模型

阶跃星辰:Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤),方言(如 粤语,四川话),可控制语速及韵律风格,支持RAP和哼唱等。其核心技术突破体现在以下四大技术亮点:

  • 1300亿多模态模型: 单模型能实现理解生成一体化完成语音识别、语义理解、对话、语音克隆、语音生成等功能,开源千亿参数多模态模型 Step-Audio-Chat
  • 高效数据生成链路: 基于130B 突破传统 TTS 对人工采集数据的依赖,生成高质量的合成音频数据,并同步开源首个基于大规模合成数据训练,支持 RAP 和哼唱的指令加强版语音合成模型 Step-Audio-TTS-3B ,该模型具有增强的指令遵循功能以控制语音综合的能力。
  • 精细语音控制: 支持多种情绪(如生气,高兴,悲伤)、方言(包括粤语、四川话等)和唱歌(包括 RAP、干声哼唱)的精准调控,满足用户对多样化语音生成的需求。
  • 扩展工具调用: 通过 ToolCall 机制和角色扮演增强,进一步提升其在 Agents 和复杂任务中的表现。
端到端语音相互作用的人类评估。

模型组成

图2 采用了AQTA(音频输入,文本输出) + TTS框架 进行实时语音对话

Step-Audio的体系结构。 Step-Adio主要由三个组成部分组成:语音令牌,LLM和语音解码器。语音令牌器负责将输入语音离散到令牌中。LLM模型接收文本和语音令牌,输出文本,而语音解码器生成波形输出。

传统的语音对话系统通常采用包括ASR的级联建筑,LLM和TTS模块。但是,我们提出的模型在训练阶段进行了全面的多模式培训以及对文本和音频的一致性,已经具有端到端的语音对话功能。尽管对替代设计进行了广泛的探索,但我们最终采用了AQTA(音频输入,文本输出) + TTS框架 进行实时语音对话,如图2所示,这是由以下考虑的驱动的:

  • 高质量的纯净对话数据的稀缺性:纯净对话数据的可用性有限,再加上其受限的场景,限制了端到端语音对话模型的训练效率。
  • 输出语音的可控性和自定义:通过引入TTS模块,我们可以灵活地控制语音参数,例如音色和音调,以满足用户的个性化需求,同时不断增强模型的表现力能力。

在Step-Audio系统中,音频流采用Linguistic tokenizer【语义】(码率16.7Hz,码本大小1024)与Semantice tokenizer【声学】(码率25Hz,码本大小4096)并行的双码本编码器方案,双码本在排列上使用了2:3时序交错策略。通过音频语境化持续预训练和任务定向微调强化了130B参数量的基础模型(Step-1),最终构建了强大的跨模态语音理解能力。为了实现实时音频生成,系统采用了混合语音解码器,结合流匹配(flow matching)与神经声码技术。此外,采用语音活动检测(VAD)模块提取声段。

Tokenizer

我们通过token级交错方法实现Linguistic token与Semantic token的有效整合。Linguistic tokenizer的码本大小是1024,码率16.7Hz;而Semantic tokenizer则使用4096的大容量码本来捕捉更精细的声学细节,码率25Hz。鉴于两者的码率差异,我们建立了2:3的时间对齐比例——每两个Linguistic token对应三个Linguistic token形成时序配对

语言模型

为了提升Step-Audio有效处理语音信息的能力,并实现精准的语音-文本对齐,我们在Step-1(一个拥有1300亿参数的基于文本的大型语言模型LLM)的基础上进行了音频持续预训练。

在多轮对话系统中音频令牌和文本令牌之间的长度差异需要有效的处理策略。为了解决这个问题,历史信息最初是在系统输入之前使用ASR模型转录为文本格式的,从而优化了计算效率。但是,应注意的是,模型体系结构在需要时保持处理和使用音频令牌作为历史上下文的能力。

语音解码器

Step-Audio语音解码器主要是将包含语义和声学信息的离散标记信息转换成连续的语音信号。该解码器架构结合了一个30亿参数的语言模型、流匹配模型(flow matching model)和梅尔频谱到波形的声码器(mel-to-wave vocoder)。为优化合成语音的清晰度(intelligibility)和自然度(naturalness),语音解码器采用双码交错训练方法(dual-code interleaving),确保生成过程中语义与声学特征的无缝融合

实时推理管线

为了实现实时的语音交互,我们对推理管线进行了一系列优化。其中最核心的是控制模块(Controller),该模块负责管理状态转换、协调响应生成,并确保关键子系统间的无缝协同。这些子系统包括:

  • 语音活动检测(VAD):实时检测用户语音起止
  • 流式音频分词器(Streaming Audio Tokenizer):实时音频流处理。输入音频流是通过两个平行的令牌管道处理的,每个管道都采用固定持续分段。将所得令牌无缝合并为2:3交织比的单个序列。没有流音频令牌,根据音频输入的长度,推理时间将明显较慢。
  • Step-Audio语言模型与语音解码器:多模态回复生成
  • 上下文管理器(Context Manager):动态维护对话历史与状态。我们的系统利用文本转录而不是原始的音频令牌来实现历史上下文,因为它提供了更紧凑的表示(平均文本审计代币比率为1:14),提高性能,并启用更长的对话,对质量的影响最小的影响很小。 ASR异步将用户语音转录为文本,并保持准确,最新的对话历史记录。

后训练细节

在后训练阶段,我们针对自动语音识别(ASR)与文本转语音(TTS)任务进行了专项监督微调(Supervised Fine-Tuning, SFT)。对于音频输入-文本输出(Audio Question Text Answer, AQTA)任务,我们采用多样化高质量数据集进行SFT,并采用了基于人类反馈的强化学习(RLHF)以提升响应质量,从而实现对情感表达、语速、方言及韵律的细粒度控制。

TTS模型:

解决TTS任务中高质量语音数据的稀缺性

Training Detail

与传统的语音合成(TTS)系统注重对说话人特征、情感表达、语言特征和风格元素的精细控制不同,我们的方法采用了基于聊天的范式和大型语言模型(LLMs)的训练方法。这一战略对齐显著增强了系统的灵活性,同时建立了一个可扩展的框架,以支持未来模型和数据的扩展,从而解决了语音合成系统在可扩展性方面的关键挑战。

监督的微调格式:

SFT格式包括三个基本组成部分:系统提示、人类输入和助手回复,采用两轮对话结构。在这种格式中,系统提示作为指定说话人属性和定义支持的指令标签的基础元素。人类输入和助手回复部分则专门用于处理文本内容和双词典表示。第一轮的文本和音频标记可以用来保持领域内说话人的音色和风格一致性,同时也支持领域外的零样本克隆。

指令标签

指令标签分为两种不同的类别:描述性标签和比较性标签。描述性标签用于控制语言、方言、声音和风格等方面,而比较性标签则用于情感和语速控制的层次化区分。描述性标签的数据是通过Step-Audio模型克隆生成的,支持包括日语、韩语、粤语、四川方言、可爱声音、说唱和唱歌等语言和风格。比较性标签的数据则是通过Audio Edit模型生成的,支持诸如快乐、愤怒、悲伤等情感,以及快慢等语速变化,每种变化都被分为五个层级。

我们使用第5.1.1节中概述的SFT数据,并采用一个具有30亿参数的模型,训练一个周期,初始学习率为 2×10−5。学习率采用余弦衰减策略进行调整,最低值设置为 2×10−6。

AQTA:

我们为AQTA任务应用了基于人类反馈的强化学习(RLHF),从而创建了Step-Audio-Chat模型,如图6所示。

说明:

用了AQTA(音频输入,文本输出) + TTS框架 情况下是如何实现多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤),方言(如 粤语,四川话),可控制语速及韵律风格,支持RAP和哼唱

通过TTS【cosyvoice】代码可知,LLM的文本输出中会包含 {语言}【情感】 [语速] 这样的文本输出,然后TTS用于合成对应的音频: 使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏

    self.sys_prompt_dict = {
        "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。",
        "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。",
        "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。',
        "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。',
    }

VITA-1.5:GPT-4o级别的实时视觉和语音交互模型

[📖 VITA-1.5 Paper] [🤖 Basic Demo] [🍎 VITA-1.0]

[📽 VITA-1.5 Demo Show! Here We Go! 🔥]

引言

近年来,多模态大语言模型(MLLMs)在视觉和文本的结合上取得了显著进展。然而,随着人机交互需求的增加,语音在多模态对话系统中的作用变得愈发重要。语音不仅是信息传递的关键媒介,还能显著提升交互的自然性和便捷性。因此,如何将视觉和语音模态高效整合,实现高性能的多模态交互,成为了当前研究的重点。

VITA-1.5的提出正是为了解决这一挑战。通过精心设计的多阶段训练方法,VITA-1.5逐步训练大语言模型(LLM)理解视觉和语音信息,最终实现了流畅的视觉和语音交互。与现有模型相比,VITA-1.5不仅保留了强大的视觉-语言能力,还实现了高效的语音对话能力,显著加速了多模态端到端的响应速度。

VITA-1.5

模型架构

图 2:VITA-1.5 的整体架构。输入端由视觉和音频编码器及其连接到 LLM 的适配器组成。输出端有一个端到端的语音生成模块,而不是像初始 VITA-1.0 版本那样直接使用外部 TTS 模型。

VITA-1.5的整体架构如图2所示。输入侧与VITA-1.0版本相同,采用“多模态编码器-适配器-LLM”的配置。它将视觉/音频Transformer和多层连接器与LLM结合进行联合训练,旨在增强对视觉、语言和音频的统一理解。在输出侧,VITA-1.5拥有自己的端到端语音模块,而不是像原始VITA-1.0版本那样使用外部TTS模型。

视觉模态

视觉编码器:VITA-1.5采用InternViT-300M作为视觉编码器,输入图像大小为448×448像素,每张图像生成256个视觉标记。对于高分辨率图像,VITA-1.5采用动态分块策略捕捉局部细节,提高图像理解的准确性。

视频处理:视频被视为一种特殊的多图像输入。如果视频长度短于4秒,则均匀采样4帧;对于4到16秒的视频,每秒采样一帧;对于超过16秒的视频,均匀采样16帧。视频帧不应用动态分块,以避免过多的视觉标记影响处理效率。

视觉适配器:使用两层MLP将视觉特征映射到适合LLM理解的视觉标记。

音频模态

语音编码器:类似于[56],我们的音频编码模块由多个下采样卷积层(4倍下采样)和24个Transformer块(隐藏大小为1024)组成。下采样层有助于降低音频特征的帧率,提高LLM的处理速度。音频编码器约有350M参数,输出帧率为12.5Hz。使用Mel滤波器组特征作为音频编码器的输入,窗口大小为25ms,偏移为10ms。

语音适配器:由多个2倍下采样的卷积层组成。

语音解码器:使用TiCodec作为我们的编解码模型,定制了一个大小为1024的单码本。这种单码本设计简化了推理阶段的解码过程。编解码模型负责将连续语音信号编码为离散语音标记,频率为40Hz,同时能够将这些标记解码回采样率为24,000Hz的语音信号。

当前的LLM只能输出文本标记,语音生成能力要求LLM能够输出语音标记。为此,我们在文本标记后添加了两个语音解码器:1)非自回归(NAR)语音解码器,全局处理文本标记并建模语义特征,旨在生成语音标记的初始分布;2)自回归(AR)语音解码器,基于NAR解码器生成的语音信息逐步生成更高质量的语音标记。最终的语音标记序列通过编解码模型的语音解码器解码为连续语音信号流(波形)。我们为NAR和AR语音解码器采用了4个LLaMA解码层,隐藏大小为896,参数大小约为120M。

训练数据

如表1所示,多模态指令微调的训练数据涵盖了广泛的类别,如描述数据和问答数据,包括中文和英文。在不同的训练阶段,从整体数据集中选择性地采样子集以服务于不同的目标。具体来说,数据集分类如下:

  • 图像描述数据:使用ShareGPT4V、ALLaVA-Caption、SharedGPT4o-Image和合成数据等数据集训练模型生成图像的描述性语言。
  • 图像问答数据:使用LLaVA-150K、LLaVA-Mixture-sample、LVIS-Instruct、ScienceQA、ChatQA和从LLaVA-OV采样的子集(如通用图像问答和数学推理数据集)等数据集训练模型回答基于图像的问题和执行视觉推理任务。
  • OCR和图表数据:支持模型理解OCR和图表内容,使用Anyword-3M、ICDAR2019-LSVT、UReader、SynDOG、ICDAR2019-LSVT-QA和从LLaVA-OV采样的相应数据等数据集。
  • 视频数据:使用ShareGemini和合成数据等数据集训练模型处理视频输入并执行诸如描述和基于视频的问答等任务。
  • 纯文本数据:增强模型理解和生成语言的能力,促进基于文本的问答任务。

除了表1中列出的图像和视频数据外,还纳入了110,000小时的内部语音-转录配对ASR数据,涵盖中文和英文,用于训练音频编码器并将音频编码器与LLM对齐。此外,使用TTS系统生成的3,000小时文本-语音配对数据用于训练语音解码器。

三阶段训练策略

为了确保VITA-1.5在涉及视觉、语言和音频的任务中表现良好,我们必须面对一个关键挑战,即不同模态之间的训练冲突。例如,添加语音数据可能会对视觉数据的理解产生负面影响,因为语音的特征与视觉的特征显著不同,导致学习过程中的干扰。为了解决这一挑战,我们设计了一个三阶段训练策略,如图3所示。核心思想是逐步将不同模态引入模型,使其在增加新模态能力的同时保持现有模态的能力。

VITA-1.5的训练管道。训练过程分为三个阶段,以逐步将视觉和音频纳入LLM同时缓解了形态冲突。第一阶段的重点是视觉训练,包括视觉对齐(阶段1.1,使用表1中的20%字幕数据),视觉理解(阶段1.2,使用100%的字幕数据)以及用于Visual QA的指令调整(阶段1.3,使用20%字幕数据和100%QA数据)。阶段2引入音频输入调整,并具有音频对齐(阶段2.1,使用11,000小时的语音转录对)和语音质量检查的指令调整(阶段2.2,采样4%字幕数据和20%的QA数据)。最后,第3阶段的重点是音频输出调整,包括对编解码器模型的训练(使用3,000个小时的文本语音数据)和语音解码器培训(阶段3.2)。图像中显示的百分比对应于表1中指定的数据采样率。

阶段1:视觉训练

阶段1.1 视觉对齐:在此阶段,我们的目标是弥合视觉和语言之间的差距。前者的特征从预训练的视觉编码器InternViT-300M中提取,后者通过LLM引入。我们使用表1中20%的描述性描述数据进行训练,其中只有视觉适配器是可训练的,而其他模块是冻结的。这种方法允许LLM初步对齐视觉模态。

阶段1.2 视觉理解:在此阶段,我们的目标是教会LLM转录图像内容。为此,我们使用表1中所有的描述性描述数据。在此过程中,视觉模块的编码器和适配器以及LLM都是可训练的。重点是使模型通过学习关于图像的描述性文本,建立视觉和语言之间的强连接,使其能够通过生成自然语言描述来理解图像内容。

阶段1.3 视觉SFT:在阶段1.2之后,模型已经获得了对图像和视频的基本理解。然而,指令跟随能力仍然有限,难以应对视觉问答任务。为了实现这一点,我们使用表1中所有的问答数据,同时保留20%的描述性描述数据以增加数据集的多样性和任务的复杂性。

在训练过程中,视觉模块的编码器和适配器以及LLM都是可训练的。此阶段的关键目标是使模型不仅能够理解视觉内容,还能够根据指令回答问题。

阶段2:音频输入微调

阶段2.1 音频对齐:在完成阶段1的训练后,模型已经建立了强大的图像和视频理解基础。在此阶段,我们的目标是基于阶段1减少音频和语言之间的差异,使LLM能够理解音频输入。训练数据包括11,000小时的语音-转录对。我们采用两步方法:(a)语音编码器训练:我们采用常见语音识别系统中使用的训练框架,使用连接时序分类(CTC)损失函数[18]训练语音编码器。目的是使编码器从语音输入中预测转录文本。此步骤确保音频编码器能够提取语音特征并将其映射到文本表示空间。(b)语音适配器训练:在训练语音编码器后,我们将其与LLM集成,使用音频适配器将音频特征引入LLM的输入层。此阶段的训练目标是使LLM能够输出语音数据的转录文本。

此外,在步骤(b)中,我们引入了特殊的可训练输入标记来指导语音理解过程。这些标记提供了额外的上下文信息,指导用于问答任务的LLM执行ASR任务。

阶段2.2 音频SFT:此阶段的重点是引入语音问题和文本答案的问答功能。为此,我们从表1中采样4%的描述数据和20%的问答数据。在数据处理方面,大约一半的基于文本的问题被随机替换为其对应的语音版本,使用TTS系统生成。

在此阶段,视觉编码器和适配器、音频编码器和适配器以及LLM都是可训练的,旨在提高模型对多模态输入的适应性。此外,我们在LLM的输出中添加了一个分类头。该头用于区分输入是来自语音还是文本。结果,模型可以更准确地解释语音输入,并高效灵活地处理不同模态。

阶段3:音频输出微调

在前两个训练阶段,VITA-1.5模型已经有效地发展了其多模态理解能力。然而,一个关键的能力,即语音输出,仍然缺失,这对于其作为交互助手的角色至关重要。为了在不影响模型基本能力的情况下引入语音输出功能,我们借鉴了[56]的策略,使用3,000小时的文本-语音数据,并采用两步训练方法(见图3)。

阶段3.1 编解码训练:此步骤的目标是使用语音数据训练具有单码本的编解码模型。编解码模型的编码器能够将语音映射到离散标记,而解码器可以将离散标记映射回语音流。在VITA-1.5的推理阶段,仅使用解码器。

阶段3.2 NAR + AR解码器训练:此阶段的训练使用文本-语音配对数据,其中文本被输入到LLM的分词器和嵌入层以获得其嵌入向量,语音被输入到编解码模型的编码器以获得其语音标记。文本嵌入向量被发送到NAR语音解码器以获得全局语义特征,然后将这些特征发送到AR语音解码器,预测相应的语音标记。请注意,在此阶段LLM是冻结的,因此多模态性能不受影响。

评估

视觉-语言评估

基线:我们比较了一系列开源MLLMs,包括VILA-1.5、LLaVA-Next、CogVLM2、InternLM-XComposer2.5、Cambrian-1、MiniCPM-V-2.6、Ovis1.5、InternVL-Chat-1.5、InternVL-2、LLaVA-OV和Video-LLaVA、SilME和LongVA,以及5个闭源MLLMs,包括GPT-4V、GPT-4o、GPT-4o-mini、Gemini 1.5 Pro和Claude 3.5 Sonnet。

评估基准:为了评估VITA-1.5的图像感知和理解能力,我们使用了多个评估基准,包括MME、MMBench、MMStar、MMMU、MathVista、HallusionBench、AI2D、OCRBench和MMVet。这些基准涵盖了广泛的方面,包括通用多模态能力(如MME、MMBench和MMMU)、数学推理(MathVista)、幻觉检测(HallusionBench)、图表(AI2D)和OCR(OCRBench)理解,提供了全面的评估结果。对于视频理解,我们使用了代表性的评估基准,包括Video-MME、MVBench和TempCompass。

视觉-语言能力:表2展示了VITA-1.5的图像理解性能比较。在三个阶段的训练后,VITA-1.5的表现与最先进的开源模型相当,甚至超过了一些闭源模型,如GPT-4V和GPT-4o-mini。这一结果突显了VITA-1.5在图像-语言任务中的强大能力。如表3所示,VITA-1.5在视频理解评估中表现出与顶级开源模型相当的性能。与专有模型的显著差距表明,VITA-1.5在视频理解方面仍有显著的改进空间和潜力。请注意,在阶段2(音频输入微调)和阶段3(音频输出微调)的训练后,VITA-1.5几乎保留了其在阶段1(视觉-语言训练)中的原始视觉-语言能力。

语音评估

基线:以下三个基线模型用于比较:Wav2vec2-base、Mini-Omini2、Freeze-Omini和VITA-1.0。

评估基准普通话评估集包括三个数据集:aishell-1、test net和test meeting。这些数据集用于评估模型在普通话语音上的表现。评估指标是字符错误率(CER)。英语评估集包括四个数据集:dev-clean、dev-other、test-clean和test-other,用于评估模型在英语语音上的表现。评估指标是词错误率(WER)。

ASR性能:表4中的评估结果表明,VITA-1.5在普通话和英语ASR任务中均取得了领先的准确性。这表明VITA-1.5已成功集成了先进的语音能力,以支持多模态交互。

结论

本文介绍了VITA-1.5,这是一个通过精心设计的三阶段训练策略整合视觉和语音的多模态LLM。通过缓解模态之间的固有冲突,VITA-1.5在视觉和语音理解方面实现了强大的能力,无需依赖单独的ASR或TTS模块即可实现高效的语音到语音交互。广泛的评估表明,VITA-1.5在多模态基准测试中表现出色。我们希望VITA-1.5能够接过VITA-1.0的旗帜,继续推动开源模型在实时多模态交互领域的进步。

ASR语音识别指标计算

#coding=utf-8
import os
import sys
import re
from typing import List, Union
import jiwer
import pdb


def cal_wer(path_ref, path_hyp, metric_type, output_detail, path_output):

    ref_text, hyp_text, ref_key = _read_file(path_ref, path_hyp, metric_type)
    
    cal_wer_from_list(ref_text, hyp_text, ref_key, metric_type, output_detail, path_output)


def cal_wer_from_list(
    reference: Union[str, List[str]], 
    hypothesis: Union[str, List[str]], 
    key: Union[str, List[str]], 
    metric_type: str, 
    output_detail: bool, 
    path_output: str
):
    if isinstance(reference, str):
        reference = [reference]
    if isinstance(hypothesis, str):
        hypothesis = [hypothesis]
    if isinstance(key, str):
        key = [key]

    # 根据ref是否为空, 先分别计算wer指标再汇总
    ref_normal, hyp_normal, key_normal = [], [], []
    ref_empty, hyp_empty, key_empty = [], [], []
    for i in range(len(reference)):
        if len(reference[i]) != 0:
            ref_normal.append(reference[i])
            hyp_normal.append(hypothesis[i])
            key_normal.append(key[i])
        else:
            ref_empty.append(reference[i])
            hyp_empty.append(hypothesis[i])
            key_empty.append(key[i])

    res_normal, out_normal = _cal_wer_normal(ref_normal, hyp_normal, metric_type)
    res_empty, out_empty = _cal_wer_empty(hyp_empty, metric_type)
    _summary(ref_normal, hyp_normal, res_normal, out_normal.alignments, key_normal, 
             hyp_empty, res_empty, out_empty, key_empty, 
             metric_type, output_detail, path_output)


def _read_file(path_ref, path_hyp, metric_type):
    ref_key, ref_text = _preprocess(path_ref, '\t', metric_type)
    hyp_key, hyp_text = _preprocess(path_hyp, '\t', metric_type)

    tmp_dict = {}
    tmp_text = []
    for i in range(len(hyp_key)):
        if hyp_key[i] not in tmp_dict.keys():
            tmp_dict[hyp_key[i]] = hyp_text[i]
        else:
            print ("repeated key")
    for i in range(len(ref_key)):
        if ref_key[i] in tmp_dict.keys():
            tmp_text.append(tmp_dict[ref_key[i]])
        else:
            tmp_text.append("")

    return ref_text, tmp_text, ref_key


def _preprocess(path_in, sep, metric_type):
    res_key, res_text = [], []

    with open(path_in, "r", encoding="utf-8") as f_in:
        lines = f_in.readlines()
        for line in lines:
            line = line.strip().split(sep, 1)
            if len(line) == 2:
                key, text = line
                text = re.sub("<s>", "", text)
                text = re.sub("</s>", "", text)
                text = re.sub("<unk>", "", text)
                text = re.sub("@@ ", "", text)
                text = re.sub("@ ", "", text)
                text = re.sub("@@", "", text)
                text = re.sub("@", "", text)
                #text = re.sub(" ", "", text)
                text = text.lower()
            else:
                key = line[0]
                text = ""

            text = [x for x in text]
            text_tmp = ""
            if metric_type == "wer":
                for ch in text:
                    if '\u4e00' <= ch <= '\u9fff':
                        text_tmp += " " + ch + " "
                    else:
                        text_tmp += ch
                text = text_tmp.strip().replace("  ", " ")
            elif metric_type == "cer":
                text_tmp = "".join(text)
                text = text_tmp.strip().replace(" ", "")
            else:
                assert False

            res_key.append(key)
            res_text.append(text)

    return res_key, res_text


def _cal_wer_normal(reference, hypothesis, metric_type):
    if metric_type == "wer":
        out = jiwer.process_words(reference=reference, hypothesis=hypothesis)
        ERR = out.wer
    elif metric_type == "cer":
        out = jiwer.process_characters(reference=reference, hypothesis=hypothesis)
        ERR = out.cer
    else:
        assert False

    H = out.hits
    S = out.substitutions
    D = out.deletions
    I = out.insertions
    N = H + S + D

    res = [ERR, N, S, D, I]

    return res, out


def _cal_wer_empty(hypothesis, metric_type):
    out = []

    I = 0
    for hyp in hypothesis:
        if hyp == "":
            i = 0
        else:
            if metric_type == "wer":
                i = len(hyp.split(" "))
            elif metric_type == "cer":
                i = len(hyp)
            else:
                assert False
        I += i
        out.append(i)

    res = [0, 0, 0, 0, I]

    return res, out


def _summary(ref_normal, hyp_normal, res_normal, out_normal, key_normal,
             hyp_empty, res_empty, out_empty, key_empty, 
             metric_type, output_detail, path_output):
    # wer/cer计算
    _, N, S, D, I = res_normal
    I += res_empty[-1]
    if N != 0:
        ERR = (S + D + I) / N
        SUB = S / N
        DEL = D / N
        INS = I / N
        N_WORD = N
    else:
        if I == 0:
            ERR = 0
        else:
            ERR = 1
        SUB, DEL, INS, N_WORD = 0, 0, I, 0

    # 句准计算 + 详细错误指标 + 详细错误统计
    utt_normal, alignments_normal, statistics_normal = _analyse_normal(
        ref_normal, hyp_normal, out_normal, key_normal, metric_type)
    utt_empty, alignments_empty, statistics_empty = _analyse_empty(
        hyp_empty, out_empty, key_empty, metric_type)

    utt = utt_normal + utt_empty
    alignments = alignments_normal + alignments_empty
    for key in statistics_empty['insert'].keys():
        if key not in statistics_normal['insert'].keys():
            statistics_normal['insert'][key] = statistics_empty['insert'][key]
        else:
            statistics_normal['insert'][key] += statistics_empty['insert'][key]
    N_SENT = len(out_normal) + len(out_empty)
    ACC_UTT = utt / N_SENT
    res = [ERR, SUB, DEL, INS, N_WORD, ACC_UTT, N_SENT]

    _format_output(res, alignments, statistics_normal, metric_type, output_detail, path_output)


def _analyse_normal(ref_normal, hyp_normal, out_normal, key_normal, metric_type):
    utt_normal = 0
    alignments_normal = []
    statistics_normal = {'substitute' : {}, 'delete' : {}, 'insert' : {}}

    for i, alignment in enumerate(out_normal):
        err, n_hit, n_sub, n_del, n_ins = 0, 0, 0, 0, 0
        ref_align, hyp_align = "", ""
        sub_align, del_align, ins_align = "", "", ""
        for j, chunk in enumerate(alignment):
            if (metric_type == "wer" and (ref_align != "" or hyp_align != "")):
                ref_align += " "
                hyp_align += " "
            if chunk.type == 'equal':
                n_hit += chunk.ref_end_idx - chunk.ref_start_idx
                ref_align += _extract_string(ref_normal[i], chunk.ref_start_idx, chunk.ref_end_idx, metric_type)
                hyp_align += _extract_string(hyp_normal[i], chunk.hyp_start_idx, chunk.hyp_end_idx, metric_type)

            elif chunk.type == 'substitute':
                err += 1
                n_sub += chunk.ref_end_idx - chunk.ref_start_idx

                ref_sub = _extract_string(ref_normal[i], chunk.ref_start_idx, chunk.ref_end_idx, metric_type)
                hyp_sub = _extract_string(hyp_normal[i], chunk.hyp_start_idx, chunk.hyp_end_idx, metric_type)

                ref_align += ref_sub
                hyp_align += hyp_sub

                key_sub = "(" + ref_sub + ") --> (" + hyp_sub + ")"

                sub_align += key_sub + "\t"

                if key_sub not in statistics_normal['substitute'].keys():
                    statistics_normal['substitute'][key_sub] = 1
                else:
                    statistics_normal['substitute'][key_sub] += 1

            elif chunk.type == 'delete':
                err += 1
                n_del += chunk.ref_end_idx - chunk.ref_start_idx

                ref_del = _extract_string(ref_normal[i], chunk.ref_start_idx, chunk.ref_end_idx, metric_type)
                hyp_del = "*"

                ref_align += ref_del
                hyp_align += hyp_del

                key_del = ref_del

                del_align += key_del + "\t"

                if key_del not in statistics_normal['delete'].keys():
                    statistics_normal['delete'][key_del] = 1
                else:
                    statistics_normal['delete'][key_del] += 1

            elif chunk.type == 'insert':
                err += 1
                n_ins += chunk.hyp_end_idx - chunk.hyp_start_idx

                ref_ins = "*"
                hyp_ins = _extract_string(hyp_normal[i], chunk.hyp_start_idx, chunk.hyp_end_idx, metric_type)

                ref_align += ref_ins
                hyp_align += hyp_ins

                key_ins = hyp_ins

                ins_align += key_ins + "\t"

                if key_ins not in statistics_normal['insert'].keys():
                    statistics_normal['insert'][key_ins] = 1
                else:
                    statistics_normal['insert'][key_ins] += 1

            else:
                assert False

        if err == 0:
            utt_normal += 1
        alignments_normal.append((key_normal[i], ref_align, hyp_align, 
                                  sub_align, del_align, ins_align, 
                                  n_hit, n_sub, n_del, n_ins))

    return utt_normal, alignments_normal, statistics_normal


def _analyse_empty(hyp_empty, out_empty, key_empty, metric_type):
    utt_empty = 0
    alignments_empty = []
    statistics_empty = {'insert' : {}}

    for i, ins in enumerate(out_empty):
        ref_align, hyp_align = "", ""
        sub_align, del_align, ins_align = "", "", ""

        if ins == 0:
            utt_empty += 1
        else:
            ref_ins = "*"
            hyp_ins = _extract_string(hyp_empty[i], 0, len(hyp_empty[i]), metric_type)

            ref_align += ref_ins
            hyp_align += hyp_ins

            key_ins = hyp_ins

            ins_align += key_ins + "\t"

            if key_ins not in statistics_empty['insert'].keys():
                statistics_empty['insert'][key_ins] = 1
            else:
                statistics_empty['insert'][key_ins] += 1
        alignments_empty.append((key_empty[i], ref_align, hyp_align, 
                                sub_align, del_align, ins_align, 
                                0, 0, 0, ins))

    return utt_empty, alignments_empty, statistics_empty


def _extract_string(s, begin, end, metric_type):
    res = ""
    if metric_type == 'wer':
        res = ' '.join(s.split(' ')[begin:end])
    elif metric_type == 'cer':
        res = s[begin:end]
    else:
        assert False
    return res


def _format_output(res, alignments, statistics, metric_type, output_detail, path_output):
    with open(path_output, "w", encoding="utf-8") as f_out:
        if output_detail == True:
            f_out.write("-"*100 + "\n")
            for i, sample in enumerate(alignments):
                key, ref, hyp = sample[0:3]
                sub_align, del_align, ins_align = sample[3:6]
                n_hit, n_sub, n_del, n_ins = sample[6:]

                f_out.write("KEY: " + key + "\n")
                f_out.write("REF: " + ref + "\n")
                f_out.write("HYP: " + hyp + "\n")
                f_out.write("CNT: " + "H(" + str(n_hit) + ") " + \
                                      "S(" + str(n_sub) + ") " + \
                                      "D(" + str(n_del) + ") " + \
                                      "I(" + str(n_ins) + ")\n")
                f_out.write("SUB: " + sub_align + "\n")
                f_out.write("DEL: " + del_align + "\n")
                f_out.write("INS: " + ins_align + "\n\n")
            f_out.write("-"*100 + "\n")

            f_out.write("-"*100 + "\n")
            lst_sub = list(sorted(statistics['substitute'].items(), key = lambda x : x[1], reverse=True))
            lst_del = list(sorted(statistics['delete'].items(), key = lambda x : x[1], reverse=True))
            lst_ins = list(sorted(statistics['insert'].items(), key = lambda x : x[1], reverse=True))
            f_out.write("\n替换错误统计: \n")
            for x in lst_sub:
                f_out.write("\t" + x[0] + "(" + str(x[1]) + ")" + "\n")
            f_out.write("\n删除错误统计: \n")
            for x in lst_del:
                f_out.write("\t" + x[0] + "(" + str(x[1]) + ")" + "\n")
            f_out.write("\n插入错误统计: \n")
            for x in lst_ins:
                f_out.write("\t" + x[0] + "(" + str(x[1]) + ")" + "\n")
            f_out.write("-"*100 + "\n")

        f_out.write("-"*100 + "\n")
        f_out.write(metric_type.upper() + ": " + str(round(res[0] * 100.0, 2)) + '%\n')
        f_out.write("WORDS: " + str(res[4]) + "\t")
        f_out.write("SUB: " + str(round(res[1] * 100.0, 2)) + "%\t")
        f_out.write("DEL: " + str(round(res[2] * 100.0, 2)) + "%\t")
        f_out.write("INS: " + str(round(res[3] * 100.0, 2)) + "%\n")
        f_out.write("ACC_UTT: " + str(round(res[5] * 100.0, 2)) + '%\t')
        f_out.write("SENTS: " + str(res[6]) + '\n')
        f_out.write("-"*100 + "\n")
    
    print (metric_type + " calculation done")
    print ("saved to " + path_output)


if __name__ == '__main__':

    '''
    # example of function cal_wer_from_list
    ref = ["今 天 天 气", "hello 我 ok 的", ""]
    hyp = ["今 天 天", "halo 我 ok 的 呀", "噪 声"]
    key = ["000", "001", "002"]
    path_output = "./example.wer"
    cal_wer(ref, hyp, key, "wer", True, path_output)

    ref = ["今天天气", "hello我ok的", ""]
    hyp = ["今天天", "halo我ok的呀", "噪声"]
    key = ["000", "001", "002"]
    path_output = "./example.cer"
    cal_wer_from_list(ref, hyp, key, "cer", True, path_output)
    '''

InspireMusic–阿里通义开源音乐生成框架

InspireMusic是由通义实验室开源的音乐生成技术,旨在打造一款集音乐生成、歌曲生成、音频生成能力为一体的开源AIGC工具包。

为研究者和开发者提供音乐/歌曲/音频生成模型的训练和调优工具及模型,方便优化生成效果;同时为音乐爱好者提供一个易于使用的文本生成音乐/歌曲/音频创作工具,可通过文字描述或音频提示来控制生成内容。

目前,InspireMusic已开源了音乐生成的训练和推理代码,支持通过简单的文字描述或音频提示,快速生成多种风格的音乐作品。

InspireMusic的文生音乐创作模式涵盖了多种曲风、情感表达和复杂的音乐结构控制,提供了极大的创作自由度和灵活性。未来计划进一步开放歌唱生成和音频生成的基础模型,欢迎研究者、开发者及用户积极参与体验和研发。该开源工具包为社区开发者提供了丰富的技术资源,支持从学术研究到产品开发的广泛应用。

🎶 主要特点

  • 统一的音频生成框架:基于音频大模型技术,InspireMusic支持音乐、歌曲及音频的生成,为用户提供多样化选择;
  • 灵活可控生成:基于文本提示和音乐特征描述,用户可精准控制生成音乐的风格和结构;
  • 简单易用:简便的模型微调和推理工具,为用户提供高效的训练与调优工具。

🌟代码仓库

核心模型

InspireMusic由音频tokenizer、自回归Transformer模型、基于常微分方程的扩散模型即Conditional Flow Matching (CFM)模型、Vocoder所组成,可支持文本生成音乐、音乐续写等任务。通过具有高压缩比的单码本WavTokenizer将输入的连续音频特征转换成离散音频token,然后利用基于Qwen模型初始化的自回归Transformer模型预测音频token,再由CFM扩散模型重建音频的潜层特征,最终通过Vocoder输出高质量的音频波形。两种推理模式的设计:fast模型和高音质模型,为不同需求的用户提供了灵活的选择。

工具包安装使用指南

第一步:下载代码库

git clone --recursive https://github.com/FunAudioLLM/InspireMusic.git
# If you failed to clone submodule due to network failures, please run the following command until success
cd InspireMusic
git submodule update --init --recursive

第二步:安装代码库

conda create -n inspiremusic python=3.8
conda activate inspiremusic
cd InspireMusic
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platforms.
conda install -y -c conda-forge pynini==2.1.5
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# install flash attention to speedup training, support version 2.6.3
pip install flash-attn --no-build-isolation

第三步:下载模型

InspireMusic-Base模型(https://www.modelscope.cn/iic/InspireMusic)
# git模型下载,请确保已安装git lfs
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/InspireMusic.git pretrained_models/InspireMusic-Base

第四步:基本用法说明快速开始

cd InspireMusic/examples/music_generation/
bash run.sh

训练LLM和flow matching模型样例脚本。

torchrun --nnodes=1 --nproc_per_node=8 \
    --rdzv_id=1024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    inspiremusic/bin/train.py \
    --train_engine "torch_ddp" \
    --config conf/inspiremusic.yaml \
    --train_data data/train.data.list \
    --cv_data data/dev.data.list \
    --model llm \
    --model_dir `pwd`/exp/music_generation/llm/ \
    --tensorboard_dir `pwd`/tensorboard/music_generation/llm/ \
    --ddp.dist_backend "nccl" \
    --num_workers 8 \
    --prefetch 100 \
    --pin_memory \
    --deepspeed_config ./conf/ds_stage2.json \
    --deepspeed.save_states model+optimizer \
    --fp16

torchrun --nnodes=1 --nproc_per_node=8 \
    --rdzv_id=1024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    inspiremusic/bin/train.py \
    --train_engine "torch_ddp" \
    --config conf/inspiremusic.yaml \
    --train_data data/train.data.list \
    --cv_data data/dev.data.list \
    --model flow \
    --model_dir `pwd`/exp/music_generation/flow/ \
    --tensorboard_dir `pwd`/tensorboard/music_generation/flow/ \
    --ddp.dist_backend "nccl" \
    --num_workers 8 \
    --prefetch 100 \
    --pin_memory \
    --deepspeed_config ./conf/ds_stage2.json \
    --deepspeed.save_states model+optimizer

推理脚本

cd InspireMusic/examples/music_generation/
bash infer.sh

带有CFM的推理模式

pretrained_model_dir = "pretrained_models/InspireMusic/"
for task in 'text-to-music' 'continuation'; do
  python inspiremusic/bin/inference.py --task $task \
      --gpu 0 \
      --config conf/inspiremusic.yaml \
      --prompt_data data/test/parquet/data.list \
      --flow_model $pretrained_model_dir/flow.pt \
      --llm_model $pretrained_model_dir/llm.pt \
      --music_tokenizer $pretrained_model_dir/music_tokenizer \
      --wavtokenizer $pretrained_model_dir/wavtokenizer \
      --result_dir `pwd`/exp/inspiremusic/${task}_test \
      --chorus verse \
      --min_generate_audio_seconds 8 \
      --max_generate_audio_seconds 30 
done

不带CFM的fast推理模式

pretrained_model_dir = "pretrained_models/InspireMusic/"
for task in 'text-to-music' 'continuation'; do
  python inspiremusic/bin/inference.py --task $task \
      --gpu 0 \
      --config conf/inspiremusic.yaml \
      --prompt_data data/test/parquet/data.list \
      --flow_model $pretrained_model_dir/flow.pt \
      --llm_model $pretrained_model_dir/llm.pt \
      --music_tokenizer $pretrained_model_dir/music_tokenizer \
      --wavtokenizer $pretrained_model_dir/wavtokenizer \
      --result_dir `pwd`/exp/inspiremusic/${task}_test \
      --chorus verse \
      --fast \
      --min_generate_audio_seconds 8 \
      --max_generate_audio_seconds 30 
done

WeTextProcessing-文本[逆]正则化

Github:https://github.com/wenet-e2e/WeTextProcessing

摘自:https://mp.weixin.qq.com/s/q_11lck78qcjylHCi6wVsQ

Funasr仓库:

Motivation

文本正则化(Text Normalization,TN)和反正则化(Inverse Text Normalization,ITN)是构建一个完整的语音交互系统不可或缺的部分。前者广泛用于语音合成系统的前端处理,而后者则在语音识别系统的识别文本上屏显示时影响着字幕的观感体验。

当前学术界中被广泛研究的 TN / ITN 系统主要有三种类型:

  • 基于语法规则的 WFST [1]:这种系统由大量特定于语言的语法组成,优点是准确可控,可以快速修 bug ,缺点是对于容易产生歧义的文本不够鲁棒。
  • 基于神经网络的端到端模型 [2]:构建这种模型时,挑战从撰写更精确的语法规则变成了标注和收集覆盖范围更广的数据。端到端模型的一个主要缺点是会产生无法恢复的错误,这时经系统转换后的文字可能在语法上是合理的,但却与原始文本的语义大相径庭。此外,对于 badcase 的修复也不如规则的方式快捷。
  • 同时使用规则语法和神经网络的混合系统 [3]:在混合框架中,只有当系统没有找到匹配的语法规则才会转用神经网络。这种方式比较好地权衡了规则和 NN 的优劣,但是对计算资源提出了更高的要求。

鉴于以上三种系统的优劣,WeTextProcessing 选择实现基于语法规则的WFST 方案。在全球范围内的开源TN/ITN 项目中,目前受众最广泛的是谷歌公司推出的C++ 框架 Sparrowhawk [4] 。该框架的不足之处是它仅仅是一个规则执行引擎,谷歌公司并没有开源相关语言的语法规则。此外,Sparrowhawk 的实现依赖了许多第三方开源库(包括 OpenFst 、Thrax 、re2 、protobuf ),导致整体框架不够简便、轻量化。另一个较为成熟的项目是英伟达公司开源的 nemo_text_processing [5],该项目依旧使用Sparrowhawk 作为生产环境下的部署工具。与谷歌不同的是,该项目还开源了诸如英语、德语、俄语等多种语言的规则语法。在中文 TN / ITN 规则领域,Jiayu 等第三方个人开发者曾开源出一套定制化的中文 TN / ITN 规则库 chinese_text_normalization [6]

站在这些优秀开源项目的肩膀上,WeTextProcessing秉承 简单易用 和Production First & Production Ready 的原则,为中文专门设计和实现一款开源易用的 TN / ITN 工具,它不仅仅包含了包含一套完整的中文 TN / ITN 规则语法,同时也提供了一个可以一键 pip install 使用的 py工具包以及比Sparrowhawk 依赖项更少(生产环境下仅依赖 OpenFst )的整体更轻量化的 C++ 规则处理引擎。

快速上手

一键install,六行代码搞定文本处理!

# install
pip install WeTextProcessing

# tn usage
>>> from tn.chinese.normalizer import Normalizer
>>> normalizer = Normalizer()
>>> normalizer.normalize("2.5平方电线")

# itn usage
>>> from itn.chinese.inverse_normalizer import InverseNormalizer
>>> invnormalizer = InverseNormalizer()
>>> invnormalizer.normalize("二点五平方电线")

技术细节

TN 和 ITN 的流程都是包含三个部分:Tagger, Reorder 和 Verbalizer。Tagger 负责对输入的文本进行解析,得到结构化的信息。Reorder 负责对结构化信息进行顺序的调整。最终 Verbalizer 负责将重排序之后的结构化信息拼接起来。

TN 流程

ITN 流程

语法规则设计

WeTextProcessing 使用 pynini [7] 来编写和编译规则语法,规则语法可以将一个字符串转换为另一个字符串。规则语法通常可以表示为一个 WFST,pynini 的底层使用了 OpenFst 来实现 WFST 相关的功能。使用 pynini 编写的规则语法示例如下图所示:

  • digits = zero | digit 的 | 操作符表示 WFST 理论中的 union 操作;
  • cross(‘十’, ‘1’) 表示 WFST 理论中弧上的输入是“十”,输出是“1”,WFST 从一个状态转到另一个状态时若经过该弧则说明系统匹配到了“十”并成功将其转换为了“1”;
  • delete(‘十’) 表示弧上的输入是“十”,输出是空,即经过该弧时会删除“十”;
  • digit + delete(‘十’) 中 + 表示WFST理论中的 concat 操作,它将两个fst连起来;
  • accep(‘兆’) 表示弧的输入和输出都是“兆”,此时 WFST 相当于一个 FSA;
  • addzero**2addzero**3 分别表示将 addzero 重复两次和三次;
  • digits.ques 和 digits.plus 则分别表示将 digits 重复零到一次 和 重复一到无穷次

此外还有一些语法特性,比如下图中:

  • add_weight(Char().tagger, 100) 表示为 Char().tagger 这条路径赋予权重(路径长度)为 100。当有多条路径都可以匹配当前输入时,我们取最短路径作为终选结果。例如“一点零五分”最终会被 ITN 成 “1:05” 而不是 “1.05分”。
  • insert(‘ ‘) 表示弧上的输入和输出分别是“”和“ ”,即经过该弧时会强制插入一个空格。
  • processor @ tagger.optimize() 中 @ 表示将两个 fst 进行 compose 操作,optimize() 表示对 tagger 进行 epsilon-removal,determinization 以及 minimization [8]
  • ‘[EOS]’ 表示正则表达式中匹配到的 string 的结尾,同理这里没有列出的 ‘[BOS]’ 则表示开头 [9]

更多详尽的说明请参考pynini 的相关文档 [7]。对于本文所构建的所有WFST,我们采用 OpenFst 中默认的热带半环作为其类型,做出这个选择的原因是此类型对求网格图中的最短路径的操作有效率优势,其路径权重的计算仅需对沿路径的所有弧的权重进行简单求和。

进阶用法

如何快速修 badcase

当遇到 badcase 的时候,我们首先需要确定 badcase 属于什么类型,日期?时间?还是分数等等?是没有转换,还是转换成了其他类型。然后再去相对应的 rules 中进行修复,可能需要改代码,也可能需要改 tsv 文件。

比如若 ITN 系统将 “三心二意” 错误转成了 “3心2意” 则有两种解决方案:

  1. 在 whitelist.tsv 添加相关的映射放弃相关词汇的转换
  2. 将enable_standalone_number设置为False,此时系统对不带单位的数字不会进行转换

值得注意的是,WeTextProcessing 大多数失败案例是由于上下文歧义或特殊案例造成的长尾问题。例如,“三点五分” 可以是时间 “3:05” 也可以是量词 “3.5 分” 表示运动员得分。编写语法时若考虑更多的上下文可以一定程度上缓解这种情况,例如,如果 “三点五分” 前面有单词 “得到” ,则将其检测为运动员得分。当然,这种打补丁的方式并不能适用于所有情况。出于这个原因,如果想要设计一个能够覆盖 100% 场景的系统,语法的数量将不可避免呈指数级增长。其他常见的失败案例是由于定义不完整。例如,如果没有预定义 “千瓦时” 到 “kwh” 的度量缩写转换,系统将无法转换 “两百千瓦时” 为 “200kwh” 。这个问题相对来说容易解决,仅需在已有的量词类中添加所需的转换规则。

生产环境部署

对于想要自己对规则进行DIY的用户,可以通过以下方式获得自己的规则文件并部署到不同的环境中。

git clone https://github.com/wenet-e2e/WeTextProcessing.git
cd WeTextProcessing
# `overwrite_cache` will rebuild all rules according to
#   your modifications on tn/chinese/rules/xx.py (itn/chinese/rules/xx.py).
#   After rebuild, you can find new far files at `$PWD/tn` and `$PWD/itn`.
python normalize.py --text "2.5平方电线" --overwrite_cache
python inverse_normalize.py --text "二点五平方电线" --overwrite_cache

在已经pip安装好的工具包中使用自己的规则:

# tn usage
>>> from tn.chinese.normalizer import Normalizer
>>> normalizer = Normalizer(cache_dir="PATH_TO_GIT_CLONED_WETEXTPROCESSING/tn")
>>> normalizer.normalize("2.5平方电线")# itn usage
>>> from itn.chinese.inverse_normalizer import InverseNormalizer
>>> invnormalizer = InverseNormalizer(cache_dir="PATH_TO_GIT_CLONED_WETEXTPROCESSING/itn")
>>> invnormalizer.normalize("二点五平方电线")

在C++中使用自己的规则:

cmake -B build -S runtime -DCMAKE_BUILD_TYPE=Releasecmake --build build
# tn usage
./build/bin/processor_main --far PATH_TO_GIT_CLONED_WETEXTPROCESSING/tn/zh_tn_normalizer.far --text "2.5平方电线"
# itn usage
./build/bin/processor_main --far PATH_TO_GIT_CLONED_WETEXTPROCESSING/itn/zh_itn_normalizer.far --text "二点五平方电线"

总结和展望

未来,WeTextProcessing 的工作将聚焦在对 Corner Case 的规则修补:相比于规则撰写,设计一套合理的测试集是一件更为困难的事情,这是因为实际生产过程中总会遇到数不清的 corner case 。WeTextProcessing 中虽然提供了一个简单的单元测试和示例测试,但其覆盖场景仍未能达到 100% 。在未来,WeTextProcessing 的重点方向之一就是越来越多地投入部署到真实的线上环境中,以身试错,case by case 分析当前规则存在的可能漏洞并加以弥补。

参考资料

[1] Peter Ebden and Richard Sproat, “The kestrel TTS text normalization system,” Nat. Lang. Eng., vol. 21, no. 3, pp. 333–353, 2015.

[2] Courtney Mansfield, Ming Sun, Yuzong Liu, Ankur Gandhe, and Björn Hoffmeister, “Neural text normalization with subword units,” in Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 2 (Industry Papers), Anastassia Loukina, Michelle Morales, and Rohit Kumar, Eds. 2019, pp. 190–196, Association for Computational Linguistics.

[3] Richard Sproat and Navdeep Jaitly, “An RNN model of text normalization,” in Interspeech 2017, 18th Annual Conference of the International Speech Communication Association, Stockholm, Sweden, August 20-24, 2017, Francisco Lacerda, Ed. 2017, pp. 754–758, ISCA.

[4] Peter Ebden and Richar Sproat, “Sparrowhawk,” 2022, https://github.com/google/sparrowhawk.

[5] Yang Zhang, “nemo_text_processing,” 2022, https://github.com/NVIDIA/NeMo/tree/main/nemo_text_processing.

[6] Jiayu Du, “chinese_text_normalization,” 2022, https://github.com/speechio/chinese_text_normalization.

[7] K. Gorman. 2016. Pynini: A Python library for weighted finite-state grammar compilation. In Proceedings of the ACL Workshop on Statistical NLP and Weighted Automata, pages 75-80.

[8] https://www.opengrm.org/twiki/bin/view/GRM/PyniniOptimizeDoc

[9] https://www.openfst.org/twiki/bin/view/GRM/ThraxQuickTour

FireRedASR -小红书语音识别大模型

小红书 FireRed 团队正式发布并开源了基于大模型的语音识别模型 ——FireRedASR,在语音识别领域带来新突破。在业界广泛采用的中文普通话公开测试集上,FireRedASR 凭借卓越的性能取得了新 SOTA!FireRedASR 在字错误率(CER)这一核心技术指标上,对比此前的 SOTA Seed-ASR,错误率相对降低 8.4%,充分体现了团队在语音识别技术领域的创新能力与技术突破。

FireredAsr,旨在满足各种应用程序中出色的性能和最佳效率的各种要求。 fireredasr包括两个变体:

FireRedASR-LLM

采用Encoder-Adapter-LLM,结合了文本预训练 LLM 的能力,为极致的 ASR 准确率而生,适用于对准确率要求极高的应用场景。在公共普通话基准上,fireredasr-LLM (8.3b参数)达到3.05%的平均字符错误率(CER),超过了3.33%的最新SOTA,相对CER(CERR)8.4%。它显示出优于工业级基线的卓越概括能力,在多源普通话ASR方案(例如视频,现场和智能助理)中,达到24%-40%的CERR。

FireRedASR-AED

基于经典的 Attention-based Encoder-Decoder 架构,FireRedASR-AED 通过扩展参数至 1.1B,成功平衡了 ASR 语音识别的高准确率与推理效率。适用于资源受限的应用程序。

主要贡献

  • High-Accuracy Models with Efficiency: ASR识别准确率优于Seed-ASR[字节跳动],模型在保持效率的同时达到卓越精度的能力。
  • Robust Real-World Performance: 在各种实用的场景中,包括简短的视频,直播,字幕生成,语音输入和智能助手,我们的模型表现出了出色的功能,与相比的相对减少(CERR)相比实现了24%-40%流行的开源基线和领先的商业解决方案。
  • 多功能识别能力:支持方言/中文/英文/歌曲识别。而且在歌词识别中表现出色

模型结构:

FireRedASR-AED是基于注意的编码器-解码器 ASR模型。训练数据:包括大约70,000小时的音频数据,主要是高质量的普通话语音。与Whisper中使用的弱标记数据集不同,我们的大多数数据都是由专业注释者手动转录的,从而确保了高转录精度和可靠性。该数据集还包含大约11,000小时的英语语音数据,以增强英语ASR功能。

Input Features: 输入25ms窗口的80-dimensional  log Mel filterbank (Fbank),10ms frame shifts,然后是全局均值和方差归一化。

Encoder Structure:编码器由两个主要组件组成:一个下采样模块和Conformer  blocks堆叠。

Decoder Structure:解码器遵循Transformer 体系结构。

Tokenization:BPE编码英文文本, 1,000 English BPE tokens, 6,827 Chinese characters, and 5 special tokens.

FireRedASR-LLM: Encoder-Adapter-LLM 架构。

Input Features and Encoder: 训练数据和处理、encoder跟FireredAsr-AED相同。

Adapter Structure:一个简单但有效的线性RELU线性网络组成,该网络投射了编码器的输出维度,以匹配输入LLM。在适配器的开头合并了一个额外的框架剪接操作。此操作进一步将时间分辨率从40ms降低到每个帧的80ms,从而降低了序列长度并提高了计算效率LLM。

LLM初始化和处理:LLM用QWEN2-7B-INSTRUCT的预训练的重量初始化。训练数据格式:(prompt, speech, transcript)

Training Strategy编码器和适配器是完全训练的,LLM采用lora微调,保证LLM的文本能力。此策略可确保编码器和适配器经过充分训练,以将语音特征映射到LLM的语义空间中,同时保留其预训练能力。训练目标基于交叉熵损失,损失仅在输入的转录部分上计算,忽略提示和语音嵌入。

Evaluation

缩放定律的观察

LLMs 方面的最新研究表明,模型性能通常会随着模型尺寸的增加而提高,这称为缩放定律 。如表3所示,我们研究了具有不同模型大小的模型的缩放行为。对于 FireRedASR-AED,我们将模型大小逐步从 140M、413M、732M 扩展到 1.1B 参数。随着模型尺寸的增加,性能持续提高,从 XS 扩展到 S、从 S 扩展到 M 以及从 M 扩展到 L 配置时分别实现 6.1%、5.3% 和 5.6% 的 CERR。对于 FireRedASR-LLM,专注于扩展编码器,同时保持 LLM 主干不变。编码器大小从 86M 增加到 710M 参数,适配器参数的变化很小(17M 到 22M)。这表现出相似的扩展模式并带来一致的性能改进,从 XS(3.29%)到 L(3.05%)配置的总体 CERR 为 7.3%。这些结果证明了我们的扩展策略的有效性,并表明通过更大的模型容量可以进一步改进。

下图是 FireRedASR 和其他 ASR 大模型的对比,在业界常用的中文普通话公开测试集上,FireRedASR-LLM(8.3B 参数量)取得了最优 CER 3.05%、成为新 SOTA!FireRedASR-AED (1.1B 参数量)紧随其后取得 3.18%,两者均比 Seed-ASR(12+B 参数量)的 3.33% 低、并且参数量更小。FireRedASR 也比 Qwen-Audio、SenseVoice、Whisper、Paraformer 取得了更优的 CER。

FireRedASR 不仅在公开测试集上表现优异,在多种日常场景下,也展现了卓越的语音识别效果。
如下图所示,在由短视频、直播、语音输入和智能助手等多种来源组成的 Speech 测试集上,与业内领先的 ASR 服务提供商(ProviderA)和 Paraformer-Large 相比, FireRedASR-LLM 的 CER 相对降低 23.7%~40.0%,优势十分明显。
值得一提的是,在需要歌词识别能力的场景中,FireRedASR-LLM 也表现出极强的适配能力,CER 实现了 50.2%~66.7% 的相对降低,这一成果进一步拓宽了 FireRedASR 的应用范围,使其不仅能胜任传统语音识别需求,还能在创新性的多媒体场景中大放异彩。

值得一提的是,FireRedASR 在中文方言和英语场景中同样表现不俗。在 KeSpeech(中文方言)和 LibriSpeech(英语)测试集上,FireRedASR 的 CER 显著优于此前的开源 SOTA 模型,使其在支持好普通话 ASR 的前提下,在中文方言和英语上也足够通用,进一步凸显了其鲁棒的语言适配能力。

Discussion:

FireredAsr模型优于竞争模型的原因:

高质量和多样化的训练数据:语料库主要由从现实世界情景中收集的专业转录音频组成,该音频比在受控环境中提供的传统阅读式录音相比,它提供的训练信号明显更高。该数据集包括声音条件,扬声器,重音和内容域的广泛差异,总计数万小时。这种多样性和规模使我们的模型能够学习强大的语音表征和语言模式。

实证研究表明,一千小时的高质量,人工标注的数据比一万小时的弱标记数据(例如,来自视频标题,OCR结果或其他ASR模型的输出)更好的结果,这解释了我们比Whisper的优势 。此外,在我们的语料库中包含唱歌数据为处理音乐内容时的基线模型的显着改进做出了贡献。

优化的训练策略:将FireredAsr-A的扩展为140m到1.1b参数时,我们将正则化和学习率确定为影响模型收敛的关键因素。我们制定了一种渐进式正则化训练策略:最初没有正则化技术以实现快速收敛,然后逐渐引入更强的正则化,因为出现了过度拟合的趋势。此外,较大的模型需要降低学习率,这对于调整此参数的最佳性能至关重要。

高效的ASR框架

总结:提出了fireredasr-LLM和FireredAsr-AED,两种针对普通话优化的高性能ASR模型。通过全面的评估,我们证明了他们的体系结构,培训策略和高质量的数据集可以在保持计算效率的同时达到最先进的性能。

DeepSeek-R1 技术报告

摘自:https://zhuanlan.zhihu.com/p/19744278380

Github: https://github.com/deepseek-ai/DeepSeek-R1

DeepSeek-R1:通过强化学习提升LLM的推理能力

R1训练流程:

•冷启动 •基于推理的强化学习 •Rejection Sampling •SFT •全场景强化学习

DeepSeek-R1-Zero 采用大规模强化学习(RL)进行训练,无需预先进行监督微调(SFT),表现出显著的推理能力。在强化学习过程中,DeepSeek-R1-Zero 展现出多种卓越且新颖的推理特性。但该模型仍面临可读性不足语言混杂等问题。

为解决这些问题并进一步增强推理性能,研究团队开发了 DeepSeek-R1,该模型在进行强化学习前引入了多阶段训练和冷启动数据。

DeepSeek-R1 在推理任务上实现了与 OpenAI-o1-1217 相当的性能水平

为促进学术研究发展,研究团队开源了 DeepSeek-R1-Zero、DeepSeek-R1,以及基于 Qwen 和 Llama 架构从 DeepSeek-R1 知识蒸馏获得的六个稠密模型(1.5B、7B、8B、14B、32B、70B)。

引言

近年来,LLM技术发展迅速,不断缩小与AGI的差距。后训练技术已成为完整训练流程中的关键环节,证实能够提升推理任务准确率,实现社会价值观对齐,适应用户偏好,同时相较于预训练所需计算资源较少。在推理能力方面,OpenAI的o1系列模型首次通过延长Chain-of-Thought(CoT)推理过程引入了推理时扩展机制,在数学、编程和科学推理等多个推理任务中取得显著进展。

然而,如何实现有效的测试时扩展仍是学术界面临的重要课题。前期研究探索了多种方法,包括过程型奖励模型、强化学习以及蒙特卡洛树搜索和束搜索等算法。但这些方法均未能达到与OpenAI的o1系列模型相当的通用推理水平。

本研究采用纯RL方法提升语言模型的推理能力。研究旨在探索LLM在无监督数据条件下通过纯RL过程实现自我进化的推理能力潜力。

具体而言,研究选用DeepSeek-V3-Base作为基础模型,采用群组相对策略优化(GRPO)作为RL框架提升模型推理性能。在训练过程中,DeepSeek-R1-Zero自然形成了多种高效且创新的推理特征。经过数千轮RL迭代,DeepSeek-R1-Zero在推理基准测试中展现出优异性能。例如,在AIME 2024测试中,pass@1得分从15.6%提升至71.0%,采用majority voting机制后,得分进一步提高到86.7%,达到OpenAI-o1-0912的性能水平。

然而,DeepSeek-R1-Zero仍面临可读性不足、语言混杂等挑战。

为解决这些问题并进一步提升推理性能,研究团队开发了DeepSeek-R1模型,该模型整合了初始训练数据和多阶段训练流程。具体实施步骤包括:首先收集数千条初始训练数据用于DeepSeek-V3-Base模型的微调;随后进行推理强化学习训练;在RL过程接近收敛时,通过拒绝采样(rejection sampling)方法从RL检查点生成新的SFT数据,并结合DeepSeek-V3在写作、事实QA和自我认知等领域的监督数据重新训练DeepSeek-V3-Base模型;最后,使用新数据完成微调后的检查点进行额外的RL训练,综合考虑各类场景的提示词。

经过上述步骤,最终获得的DeepSeek-R1模型达到了与OpenAI-o1-1217相当的性能水平。

研究进一步探索了从DeepSeek-R1到较小dense模型的知识蒸馏。以Qwen2.5 32B为基础模型,直接从DeepSeek-R1进行知识蒸馏的效果优于直接应用RL训练,表明大型基础模型所发现的推理模式对提升推理能力具有关键作用。研究团队已开源蒸馏后的Qwen和Llama系列模型。

值得注意的是,14B蒸馏模型的性能显著超越了当前最先进的开源模型QwQ-32B-Preview,而32B和70B蒸馏模型则在稠密模型推理基准测试中创造了新的记录

主要贡献

后训练:基础模型的大规模强化学习应用

  • 本研究直接将RL应用于基础模型,无需将SFT作为前置步骤。这种方法使模型能够通过CoT探索复杂问题的解决方案,最终开发出DeepSeek-R1-Zero模型。DeepSeek-R1-Zero具备自我验证、反思和生成长CoT等能力,为学术界提供了重要研究成果。这是首个验证LLM推理能力可纯粹通过RL提升而无需SFT的开放研究,为该领域未来发展奠定基础
  • 研究提出了DeepSeek-R1的开发流程,包含两个RL阶段用于优化推理模式和人类偏好对齐,以及两个SFT阶段用于构建模型的推理和非推理基础能力。该流程将有助于行业开发更高性能的模型。

知识蒸馏:小型模型的性能提升

  • 研究表明大型模型的推理模式可通过知识蒸馏迁移至小型模型,其效果优于直接对小型模型进行RL训练。开源的DeepSeek-R1及其API将支持学术界开发更优秀的小型模型
  • 利用DeepSeek-R1生成的推理数据,研究团队对学术界广泛使用的多个稠密模型进行了微调。评估结果显示,经过知识蒸馏的小型dense模型在基准测试中表现优异。DeepSeek-R1-Distill-Qwen-7B在AIME 2024上达到55.5%的性能,超越QwQ-32B-Preview。DeepSeek-R1-Distill-Qwen-32B在AIME 2024、MATH-500和LiveCodeBench上分别达到72.6%、94.3%和57.2%的成绩,显著优于现有开源模型,达到与o1-mini相当的水平。研究团队已向学术界开源基于Qwen2.5和Llama3系列的1.5B、7B、8B、14B、32B和70B蒸馏检查点

研究方法

概述

传统研究主要依赖大规模监督数据提升模型性能。本研究证实,即使在无需监督微调(SFT)作为初始训练的情况下,通过大规模强化学习(RL)也能显著提升推理能力。此外,引入适量初始训练数据可进一步优化性能。后续章节将介绍:(1)DeepSeek-R1-Zero:直接对基础模型应用RL,无需任何SFT数据;(2)DeepSeek-R1基于经数千个长CoT样例微调的检查点进行RL训练;(3)将DeepSeek-R1的推理能力通过知识蒸馏迁移至小型稠密模型

DeepSeek-R1-Zero:基础模型的强化学习应用

前期相关研究表明强化学习在推理任务中具有显著效果。然而,这些研究高度依赖耗时的监督数据采集。本节探索LLM在无监督数据条件下通过纯强化学习实现推理能力自我进化的潜力。研究首先概述强化学习算法,随后展示实验结果,以期为学术界提供研究参考。

强化学习算法

群组相对策略优化(GRPO): 为优化RL训练成本,研究采用GRPO算法,摒弃了通常与策略模型规模相当的评论家模型,转而通过群组评分估计基线。具体而言,对每个问题 q ,GRPO从旧策略 πθold 采样输出组{ o1,o2,…,oG },通过最大化以下目标优化策略模型 πθ :

其中 ε 和 β 是超参数, Ai 是优势函数,使用组内每个输出对应的奖励组{ r1,r2,…,rG }计算得到:

奖励建模

奖励机制作为训练信号来源,决定RL的优化方向。DeepSeek-R1-Zero采用基于规则的双重奖励系统

  • 准确性奖励:评估响应正确性。如对确定性数学问题,要求模型以特定格式(如方框内)提供最终答案,实现基于规则的可靠验证。对LeetCode问题,则通过编译器基于预设测试用例生成反馈。
  • 格式奖励:要求模型将推理过程置于指定标签对内。研究未采用结果或过程神经奖励模型,原因在于神经奖励模型可能在大规模RL过程中产生奖励欺骗问题,且重训奖励模型需额外资源,增加训练流程复杂度。

训练模板

DeepSeek-R1-Zero的训练始于简洁指令模板的设计。

如表1所示,模板要求模型首先生成推理过程,随后给出最终答案。研究刻意将约束限定于结构格式,避免引入内容偏见(如强制反思推理或特定问题解决策略),以准确观测模型在RL过程中的自然演化。

DeepSeek-R1-Zero的性能分析、演化过程及关键突破

性能分析 图2记录了DeepSeek-R1-Zero在RL训练过程中AIME 2024基准测试的性能变化轨迹。

图2 | DeepSeek-R1-Zero训练过程中的AIME准确率变化。为确保评估稳定性,对每个问题采样16个响应并计算总体平均准确率。

数据显示,随着RL训练的深入,模型性能呈现稳定上升趋势。在AIME 2024测试中,平均pass@1得分从初始的15.6%显著提升至71.0%,达到OpenAI-o1-0912的性能水平,充分证实了RL算法在模型性能优化方面的有效性。

表2 | DeepSeek-R1-Zero与OpenAI o1模型在推理相关基准测试上的性能对比。

表2详细对比了DeepSeek-R1-Zero与OpenAI o1-0912模型在各类推理基准测试上的表现。结果表明,纯RL训练使DeepSeek-R1-Zero获得了出色的推理能力,无需借助监督微调数据,这证实了模型通过单一RL机制实现有效学习和泛化的能力。通过引入majority voting机制,模型性能得到进一步提升。例如,在AIME基准测试中,采用majority voting后性能从71.0%提升至86.7%,超越OpenAI-o1-0912。这种优异表现凸显了模型的基础能力和推理潜力。

演化过程分析 DeepSeek-R1-Zero的演化过程展示了RL在推理能力自主优化方面的显著效果。通过直接对基础模型实施RL训练,研究得以在无监督微调影响下观测模型进展。

图3 | 展示DeepSeek-R1-Zero在RL训练过程中训练集的平均响应长度变化,反映模型自主习得延长推理时间的能力。

如图3所示,模型的推理时长在训练过程中持续优化,这种进展源于模型的内生发展而非外部干预。DeepSeek-R1-Zero通过扩展测试计算时间,自然形成了解决复杂推理任务的能力。其计算规模从数百到数千个推理token不等,实现了深度的思维探索和优化。随着测试计算时间的延长,模型展现出复杂的行为特征,包括反思机制(重新评估先前推理步骤)和多元问题解决策略的探索。这些行为模式并非预设,而是源于模型与RL环境的交互作用,显著增强了其处理高难度任务的效率和准确性。

关键突破与局限性 研究过程中观察到模型出现重要突破,如表3所示,体现在中期版本中。

表3:记录DeepSeek-R1-Zero中期版本的重要突破,展示模型获得自主思考复核能力的过程,体现RL在模型能力提升方面的有效性。

此阶段,DeepSeek-R1-Zero习得了重新评估初始方法并延长思考时间的能力。这一进展不仅体现了模型推理能力的提升,也展示了RL在实现复杂学习成果方面的潜力。这种现象验证了RL的核心优势:通过适当的激励机制,促使模型自主发展高级问题解决策略。

然而,DeepSeek-R1-Zero仍存在若干局限性。尽管具备强大的推理能力和创新的推理行为,但在可读性和语言一致性方面仍面临挑战。为提高推理过程的可读性并促进开放社区交流,研究团队开发了DeepSeek-R1模型,该模型结合了RL和用户友好的初始训练数据。

DeepSeek-R1:基于冷启动的强化学习方法

基于DeepSeek-R1-Zero的成功实践,研究聚焦两个核心问题:

  1. 通过引入少量高质量数据作为冷启动,是否能够进一步提升推理性能或加速收敛?
  2. 如何开发既能生成清晰连贯的CoT,又具备强大通用能力的用户友好型模型?

为解决上述问题,研究团队设计了四阶段训练流程

冷启动机制

区别于DeepSeek-R1-Zero,DeepSeek-R1采用少量长CoT数据对模型进行预微调作为初始RL策略网络,以避免基础模型RL训练早期的不稳定性。数据收集采用多种方法:

  • 基于长CoT示例的少样本提示
  • 直接提示生成包含反思验证的详细答案
  • 整理DeepSeek-R1-Zero的规范化输出
  • 人工标注后处理优化

研究收集数千条冷启动数据用于DeepSeek-V3-Base的预训练。相较于DeepSeek-R1-Zero,冷启动数据具有以下优势:

  • 可读性增强:克服了DeepSeek-R1-Zero输出内容可读性差的局限。通过设计标准化输出模式,包括响应末尾的总结性内容,并筛除不符合阅读友好性要求的输出。输出采用|special_token|<reasoning_process>|special_token|<summary>格式,包含查询的推理过程和结果摘要。
  • 性能提升:基于人类认知模式优化的冷启动数据设计,展现出优于DeepSeek-R1-Zero的性能表现,验证了迭代训练对推理模型的优越性。

推理强化学习优化

完成冷启动数据预训练后,采用与DeepSeek-R1-Zero类似的大规模RL训练流程,重点提升模型在编码、数学、科学和逻辑等明确定义问题域的推理能力。在训练过程中发现Chain-of-Thought存在语言混杂现象,尤其是多语言提示场景下。为此引入语言一致性奖励机制,基于目标语言词占比计算。尽管消融实验显示该机制略微影响模型性能,但提升了人类使用体验。最终将任务准确率和语言一致性奖励合并计算总体奖励,持续RL训练直至模型在推理任务上收敛。

拒绝采样与监督微调

推理RL收敛后,利用检查点生成后续SFT数据。不同于专注推理的冷启动阶段,此阶段整合多领域数据以增强模型的写作、角色扮演等通用能力。具体实施如下:

推理数据构建 通过对RL训练检查点执行拒绝采样生成推理轨迹。扩展了评估机制,除规则型奖励外,引入基于DeepSeek-V3判断的生成式奖励模型。优化输出质量,过滤混杂语言、冗长段落和代码块。对每个提示词进行多样本采样,保留正确结果。最终获得约60万条推理训练样本。

非推理数据整合 在写作、事实QA、自我认知和翻译等领域,采用DeepSeek-V3流程和部分SFT数据。对复杂非推理任务,通过提示DeepSeek-V3生成前置CoT;对简单查询则直接响应。累计获取约20万条非推理训练样本。使用总计约80万样本数据对DeepSeek-V3-Base执行两轮微调。

全场景强化学习

优化人类偏好对齐,实施第二阶段RL训练,着重提升模型实用性、安全性和推理能力。采用多元奖励信号和多样化提示分布:

  • 推理数据:延续DeepSeek-R1-Zero方法,在数理逻辑领域应用规则型奖励
  • 通用数据:采用奖励模型捕捉复杂场景下的人类偏好
  • 实用性评估:专注于响应摘要,确保输出的实用性和相关性
  • 安全性保障:全面评估推理过程和摘要,识别并降低潜在风险

通过奖励信号和数据分布的系统整合,实现了推理能力和用户体验的均衡发展。

知识蒸馏:增强小型模型的推理能力

本研究采用DeepSeek-R1生成的80万训练样本,对Qwen和Llama等开源模型进行直接SFT微调,旨在将DeekSeek-R1的推理能力迁移至计算效率更高的小型模型。

实验结果表明,这种直接知识蒸馏方法显著提升小型模型的推理性能

研究选用的基础模型包括:Qwen2.5-Math-1.5B、Qwen2.5-Math-7B、Qwen2.5-14B、Qwen2.5-32B、Llama-3.1-8BLlama-3.3-70B-Instruct

选择Llama-3.3的原因在于其推理能力较Llama-3.1略有优势

蒸馏过程中仅采用SFT,未纳入RL阶段,尽管引入RL可能带来显著的性能提升。研究重点在于验证知识蒸馏技术的有效性,为后续学术界对RL优化的深入研究奠定基础。

实验设计与评估

研究采用多维度基准测试体系评估模型性能:

标准评估基准 8类16个评估标准如下所示:

  • 知识理解类:MMLU、MMLU-Redux、MMLU-Pro
  • 跨语言评估:C-Eval、CMMLU
  • 格式理解:IFEval
  • 长文本处理:FRAMES
  • 专业知识:GPQA Diamond
  • 事实问答:SimpleQA、C-SimpleQA
  • 编程能力评估: SWE-Bench Verified、Aider、LiveCodeBench、Codeforces
  • 数学能力测试: CNMO 2024、AIME 2024

除标准基准测试外,研究还使用LLM作为评估器评估模型在开放式生成任务上的表现。具体而言,遵循AlpacaEval 2.0Arena-Hard的原始配置,使用GPT-4-Turbo-1106作为成对比较的评估器。评估时仅输入最终摘要以避免长度偏差。对于蒸馏模型,报告其在AIME 2024、MATH-500、GPQA Diamond、Codeforces和LiveCodeBench上的代表性结果。

评估用prompt 不同的评估标准采用不同的prompt,具体如下所示:

  • 基础评估:采用simple evals框架标准prompt评估MMLU、DROP、GPQA Diamond和SimpleQA
  • 特殊处理: MMLU-Redux采用Zero-Eval prompt格式实现零样本评估,MMLU-Pro、C-Eval、CLUE-WSC将原少样本prompt改造为零样本形式
  • 编程评估: HumanEval-Mul覆盖8种主流编程语言,LiveCodeBench采用CoT格式,Codeforces基于10个Div.2竞赛题目与专家测试用例,SWE-Bench通过无代理框架验证

值得注意的是,DeepSeek-R1的输出在每个基准测试上限制为最多32,768个token。

基准模型 研究与多个强基准模型进行全面对比,包括DeepSeek-V3、Claude-Sonnet-3.5-1022、GPT-4o-0513、OpenAI-o1-miniOpenAI-o1-1217。鉴于在中国大陆访问OpenAI-o1-1217 API的限制,其性能数据来源于官方报告。对于蒸馏模型,额外与开源模型QwQ-32B-Preview进行比较。

生成配置 所有模型的最大生成长度设置为32K token。对需要采样的基准测试,采用0.6的temperature参数、0.95的top-p值,并为每个查询生成64个响应以估算pass@1。

DeepSeek-R1评估结果

表4 | DeepSeek-R1与其他代表性模型的比较。

在面向教育的知识基准测试(如MMLU、MMLU-Pro和GPQA Diamond)中,DeepSeek-R1相较于DeepSeek-V3展现出优越性能。这一进步主要归因于STEM相关问题准确率的提升,这得益于大规模RL带来的显著进步。

此外,DeepSeek-R1在依赖长文本理解的问答任务FRAMES上表现卓越,展示了其强大的文档分析能力。这凸显了推理模型在AI驱动的搜索和数据分析任务中的潜力

在事实性基准测试SimpleQA上,DeepSeek-R1的表现优于DeepSeek-V3,证明了其处理基于事实查询的能力。类似地,在该基准测试中也观察到OpenAI-o1超越GPT-4o的趋势。

然而,DeepSeek-R1在中文SimpleQA基准测试中的表现不如DeepSeek-V3,主要是由于安全性RL后倾向于拒绝回答某些查询。若不考虑安全性RL,DeepSeek-R1可以达到超过70%的准确率。

DeepSeek-R1在IF-Eval(一个用于评估模型遵循格式指令能力的基准测试)上也取得了令人瞩目的成果。这些改进可归因于在最终阶段的SFT和RL训练中引入了指令遵循数据。

此外,在AlpacaEval 2.0和ArenaHard上的出色表现表明DeepSeek-R1在写作任务和开放域问答方面具有优势。其显著优于DeepSeek-V3的表现凸显了大规模RL的泛化效益,不仅提升了推理能力,还改善了各个领域的性能。

而且DeepSeek-R1生成的摘要长度简洁,在ArenaHard上平均为689个token,在AlpacaEval 2.0上平均为2,218个字符。这表明DeepSeek-R1在基于GPT的评估中避免了引入长度偏差,进一步证实了其在多任务场景下的稳健性。

数学任务上,DeepSeek-R1展现出与OpenAI-o1-1217相当的性能,大幅超越其他模型。在LiveCodeBench和Codeforces等编码算法任务上也观察到类似趋势,其中注重推理的模型在这些基准测试中占据主导地位。

在面向工程的编码任务上,OpenAI-o1-1217在Aider上优于DeepSeek-R1,但在SWE Verified上表现相当。考虑到目前相关RL训练数据量仍然非常有限,研究团队认为DeepSeek-R1的工程性能将在下一版本中得到改善。

蒸馏模型评估

表5 | DeepSeek-R1蒸馏模型与其他可比模型在推理相关基准测试上的比较。

如表5所示,仅通过蒸馏DeepSeek-R1的输出,高效的DeepSeek-R1-7B(即DeepSeek-R1-Distill-Qwen-7B,以下类似缩写)就能在各方面超越GPT-4o-0513等非推理模型。

DeepSeek-R1-14B在所有评估指标上超越QwQ-32B-Preview,而DeepSeek-R1-32B和DeepSeek-R1-70B在大多数基准测试中显著超越o1-mini。这些结果展示了知识蒸馏的巨大潜力

此外,研究发现对这些蒸馏模型应用RL能带来显著的进一步提升。考虑到这值得进一步探索,此处仅呈现简单SFT蒸馏模型的结果。

讨论

蒸馏与强化学习对比

通过蒸馏DeepSeek-R1,小型模型能够取得出色的结果。然而,仍有一个问题待解答:

模型是否可以通过本文讨论的大规模RL训练而不依赖蒸馏来达到相当的性能?

为回答这个问题,研究团队对Qwen-32B-Base使用数学、代码和STEM数据进行了超过10K步的大规模RL训练,得到DeepSeek-R1-Zero-Qwen-32B。

如表6所示的实验结果表明,32B基础模型经过大规模RL训练后,达到了与QwQ-32B-Preview相当的性能。然而,从DeepSeek-R1蒸馏得到的DeepSeek-R1-Distill-Qwen-32B在所有基准测试中的表现都显著优于DeepSeek-R1-Zero-Qwen-32B

因此,可以得出两个结论:

首先,将更强大的模型蒸馏到较小的模型中可以产生优异的结果,而较小的模型依靠本文提到的大规模RL需要巨大的计算力,甚至可能无法达到蒸馏的性能水平

其次,虽然蒸馏策略既经济又有效,但要突破智能的边界可能仍需要更强大的基础模型和更大规模的强化学习

未成功的尝试

在开发DeepSeek-R1的早期阶段,研究也遇到了失败和挫折。在此分享这些失败经验以提供见解,但这并不意味着这些方法无法开发出有效的推理模型。

过程奖励模型(PRM)

PRM是一种合理的方法,可以引导模型采用更好的方法解决推理任务。然而,在实践中,PRM有三个主要限制可能阻碍其最终成功。

首先,在一般推理中明确定义细粒度步骤具有挑战性。其次,确定当前中间步骤是否正确是一项具有挑战性的任务。使用模型的自动标注可能无法产生令人满意的结果,而手动标注不利于规模化。第三,一旦引入基于模型的PRM,必然导致奖励欺骗,重新训练奖励模型需要额外的训练资源,并使整个训练流程变得复杂。

总之,虽然PRM在对模型生成的前N个响应重新排序或辅助引导搜索方面表现良好,但在实验中,相比其在大规模强化学习过程中引入的额外计算开销,其优势有限

蒙特卡洛树搜索(MCTS)

AlphaGoAlphaZero的启发,研究探索使用MCTS来增强测试时计算的可扩展性。这种方法包括将答案分解为更小的部分,使模型能够系统地探索解决方案空间。为此,提示模型生成多个标签,对应搜索所需的具体推理步骤。在训练方面,首先使用收集的提示通过预训练值模型引导的MCTS寻找答案。随后,使用产生的问答对来训练actor模型和值模型,不断改进过程。

然而,这种方法在扩大训练规模时遇到几个挑战。首先,与搜索空间相对明确的象棋不同,token生成呈现指数级更大的搜索空间。为解决这个问题,为每个节点设置最大扩展限制,但这可能导致模型陷入局部最优。其次,值模型直接影响生成质量,因为它指导搜索过程的每个步骤。训练细粒度值模型本质上是困难的,这使得模型难以迭代改进。虽然AlphaGo的核心成功依赖于训练值模型来逐步提升性能,但由于token生成的复杂性,这一原则在团队的设置中难以复制。

总之,虽然MCTS在与预训练值模型配对时可以改善推理性能,但通过自搜索迭代提升模型性能仍然是一个重大挑战

结论、局限性和未来工作

本文分享了通过RL增强模型推理能力的探索历程。DeepSeek-R1-Zero代表了一种不依赖冷启动数据的纯RL方法,在各种任务中取得了出色的表现。DeepSeek-R1通过结合冷启动数据和迭代RL微调展现出更强的性能,最终在多个任务上达到与OpenAI-o1-1217相当的水平。

研究进一步探索了将推理能力蒸馏到小型稠密模型的可能性。以DeepSeek-R1作为教师模型生成80万条数据,并对多个小型稠密模型进行微调。

结果令人鼓舞:DeepSeek-R1-Distill-Qwen-1.5B在数学基准测试中超越GPT-4o和Claude-3.5-Sonnet,在AIME上达到28.9%,在MATH上达到83.9%的成绩。其他稠密模型也取得了显著成果,大幅超越基于相同基础检查点的其他指令微调模型。

未来,计划在以下方向继续推进DeepSeek-R1的研究:

  • 通用能力:目前DeepSeek-R1在函数调用、多轮对话、复杂角色扮演和json输出等任务上的能力仍不及DeepSeek-V3。后续研究将探索如何利用长CoT增强这些领域的任务表现。
  • 语言混杂:DeepSeek-R1当前针对中文和英文进行了优化,在处理其他语言的查询时可能出现语言混杂问题。例如,即使查询使用非英文或中文的语言,DeepSeek-R1可能使用英语进行推理和响应。未来更新将着力解决这一限制。
  • 提示词工程:在评估DeepSeek-R1时发现,模型对prompt较为敏感。少样本提示会持续降低其性能。因此,建议用户直接描述问题并使用零样本设置指定输出格式以获得最佳结果
  • 软件工程任务:由于评估时间较长影响RL过程效率,大规模RL尚未在软件工程任务中广泛应用。因此,DeepSeek-R1在软件工程基准测试上相比DeepSeek-V3未显示出显著改进。未来版本将通过对软件工程数据实施拒绝采样或在RL过程中引入异步评估来提高效率。