transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读

这个类对外提供的方法是 generate(),通过调参能完成以下事情:

  • greedy decoding:当 num_beams=1 而且 do_sample=False 时,调用 greedy_search()方法,每个step生成条件概率最高的词,因此生成单条文本。
  • multinomial sampling:当 num_beams=1 且 do_sample=True 时,调用 sample() 方法,对词表做一个采样,而不是选条件概率最高的词,增加多样性。
  • beam-search decoding:当 num_beams>1 且 do_sample=False 时,调用 beam_search() 方法,做一个 num_beams 的柱搜索,每次都是贪婪选择top N个柱。
  • beam-search multinomial sampling:当 num_beams>1 且 do_sample=True 时,调用 beam_sample() 方法,相当于每次不再是贪婪选择top N个柱,而是加了一些采样。
  • diverse beam-search decoding:当 num_beams>1 且 num_beam_groups>1 时,调用 group_beam_search() 方法。
  • constrained beam-search decoding:当 constraints!=None 或者 force_words_ids!=None,实现可控文本生成。

参数列表

核心代码详见:generate()入口函数定义, GenerationConfig类

1.控制生成长度的参数

参数类型缺省值含义
max_lengthint20表示 prompt + max_new_tokens 累加的最大长度,如果max_new_tokens也设置了,会覆盖这个参数
max_new_tokensint生成部分的tokens的最大长度 (忽略prompt部分的长度)
min_length0表示 prompt + min_new_tokens 累加的最小长度,如果min_new_tokens也设置了,会覆盖这个参数
min_new_tokensint生成部分的tokens的最小长度 (忽略prompt部分的长度)
early_stoppingbool, strFalse对于beam search方法的控制终止的配置。
False: 当有’num_beams’个候选生成,则终止
True: 应用一些启发式规则判断不能找到更好的生成候选,来提前终止生成
“never”: 当判断没有更好的可生成的candidate, beam search 过程终止
max_timefloat执行生成的最大时间(s秒数)
stop_stringsstr, array[str]配置模型生成的终止字符串,当模型生成参数配置的字符串,则终止生成。

2. 控制生成策略的参数

参数类型缺省值含义
do_sampleboolFalseTrue: 生成过程使用采样逻辑
False: 使用greedy做生成
num_beamsint1设置beam search 束的数量。如果是1不做beam search 搜索
num_beam_groupsint1为了保证生成的多样性,将num_beams 设置成多组。参考文献: https://arxiv.org/pdf/1610.02424.pdf
penalty_alphafloatcontrastive search decoding的配置项,用于平衡生成置信度和衰减的惩罚
dola_layersstr, List[int]str :
“None”: 不使用dola
“low” : 较低的一半layers, 最多20层使用dola
“high”: 较高的一半layers, 最多20层使用dola
List[int] : 通过指定一个index数组,指定dola 层
“low”: 提升长答案的task,
“high”:提升短答案的task

3.cache配置参数

参数类型缺省值含义
use_cacheboolTrue是否使用KV cache 加速推理速度
cache_implementationstr指定cache实现的name,在调用generate()时,实例化cache。
”static”: [StaticCache]
“offloaded_static”: [OffloadedStaticCache]
”sliding_window”: [SlidingWindowCache]
“hybrid”: [HybridCache]
“mamba”: [MambaCache]
”quantized”:[QuantizedCache]
cache_configCacheConfig , dictNonecache类使用的参数
return_legacy_cacheboolTrue当DynamicCache 被使用时,是否返回历史的和新格式的cache

4.操作模型输出logit的配置参数

参数类型缺省值含义
temperaturefloat1.0这个值用于建模下一个token的概率, 这个值被设置在generation_config.json文件中
top_kint50筛选最高概率的top k个词, 这个值被设置在generation_config.json文件中
top_pfloat1.0当设置<1时,筛选概率最高的token,累加概率不超过top_p的token
min_pfloat配置筛选概率最低的一批token, 累加概率不超过min_p,裁剪掉,该配置相当于top_p的反向操作
typical_pfloat1.0测量两个分布的相似性: 预测下一个目标token的概率 and 预测下一个随机Token的条件概率期望。如果设置<1,则筛选最典型的token。
epsilon_cutofffloat0.0按设置的值,卡掉低概率值的token,一般设置为:3e-4 to 9e-4
eta_cutofffloat0.0混合局部典型性采样和epsilon采样方法
diversity_penaltyfloat0.0只对group beam search方法生效,如果在某个特定时间生成的token与任何beam 组生成的token一致,则beam的score减去这个值
repetition_penaltyfloat1.01.0 默认不惩罚
encoder_repetition_penaltyfloat1.0对于不在原始输入的token,指数级的惩罚
length_penaltyfloat1.0对于beam 类的生成方法的长度惩罚,由于序列score是 log likelihood , > 0 倾向于更长的 <0 倾向于更短的
no_repeat_ngram_sizeint0如果大于0, 则对应的size的ngram只能出现1次
bad_words_idsList[List[int]]列出不允许生成的tokens_id
force_words_idsList[List[int]] or List[List[List[int]]]必须被生成的words_ids。 如果配置List[List[List[int]]] 设置对于每个token的约束
renormalize_logitsboolFalse对于所有的logits做后处理后,是否要再做下normalize
constraintsList[Constraint]通过定义一个List[Constraint] 对象数组,来确保输出是在某些限制的场景下。一般用于安全的场景
forced_bos_token_idintmodel.config.forced_bos_token_id强制跟在decoder_start_token_id之后的第一个token,对多语言模型是有用的
forced_eos_token_idint or List[int]model.config.forced_eos_token_id当生成的token达到max_length上限时,最后一位输出的token
remove_invalid_valuesboolmodel.config.remove_invalid_values是否移出可能生成的nan and inf 值,配置这个会减慢生成速度
exponential_decay_length_penaltytuple(int, float)指数级增加长度的惩罚,tuple(start_index, decay_factor) start index 指示惩罚的开始i,decay_factor 指数衰减的惩罚因子
suppress_tokensList[int]通过设置禁止的token的logit为-inf,来禁止token被sample
begin_suppress_tokensList[int]通过设置首位禁止的token的logit为-inf,来禁止首位这部分token被采样到,进而导致被生成
forced_decoder_idsList[List[int]]一个整数pair的数组,格式[生成index, token_index]指示固定位置强制生成某个token,例如[[1, 123]] 第二个位置总是生成token 123
sequence_biasDict[Tuple[int], float]token list -> bias的映射,正的bias提升几率,负的bias降低几率
token_healingboolFalse对prompt尾部的token做相似替换,以提升生成质量
guidance_scalefloat是一个缩放因子,当>1时,这个因子越高,越鼓励模型生成与prompt接近的samples 。
watermarking_configBaseWatermarkingConfig or dict对输出结果增加水印

5.输出结果配置参数

参数类型缺省值含义
num_return_sequencesint1对于batch中的每个元素,设置独立计算的返回的sequence的数量
output_attentionsboolFalse是否返回所有的attention的向量
output_hidden_statesboolFalse是否返回所有网络层的隐层状态
output_scoresboolFalse是否返回prediction scores
output_logitsbool是否返回未处理过的的logit score
return_dict_in_generateboolFalse除了返回生成序列,是否还返回a [`~utils.ModelOutput`]

6.生成时使用的特殊token的配置参数

参数类型缺省值含义
pad_token_idintpadding token ID
bos_token_idintbeginning -of – sequence token ID
eos_token_idUnion[int, List[int]]end-of-sequence token ID

6.辅助生成的配置参数(投机采样)

参数类型缺省值含义
is_assistantboolFalse指定是否模型是一个assistant(draft) model
num_assistant_tokensint20投机采样过程,每次迭代 assistant model 要输出多少个token,给到目标模型做check。配置更高的值,如果assistant model 效果好 能带来更好的加速比
num_assistant_tokens_schedulestrconstant“heuristic” : 当所有投机采样的token都正确时,将num_assistant_tokens增加2,否则减少1。
“constant”: num_assistant_tokens 保持固定不变
“heuristic_transient”: 类似于启发式方法,每次生成调用,都置成初始化的num_assistant_tokens值
assistant_confidence_thresholdfloat0.4当assistant model预估当前token的置信度 小于 阈值时,提前终止assistant model的生成
prompt_lookup_num_tokensint作为候选token 要输出的token的数量
max_matching_ngram_sizeint2match prompt的最大ngram的数量
assistant_early_exitint
assistant_lookbehindint10如果设置为正整数,则重新编码过程将额外考虑最后的assistant_lookbehind个辅助标记,以正确对齐标记。此设置仅可在推测解码中使用不同的分词器时使用。
target_lookbehindint10如果设置为正整数,则重新编码过程将额外考虑最后的target_lookbehind个辅助标记,以正确对齐标记。此设置仅可在推测解码中使用不同的分词器时使用。


如有整理错误,欢迎指正~

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注