这个类对外提供的方法是 generate()
,通过调参能完成以下事情:
- greedy decoding:当
num_beams=1
而且do_sample=False
时,调用greedy_search()
方法,每个step生成条件概率最高的词,因此生成单条文本。 - multinomial sampling:当
num_beams=1
且do_sample=True
时,调用sample()
方法,对词表做一个采样,而不是选条件概率最高的词,增加多样性。 - beam-search decoding:当
num_beams>1
且do_sample=False
时,调用beam_search()
方法,做一个 num_beams 的柱搜索,每次都是贪婪选择top N个柱。 - beam-search multinomial sampling:当
num_beams>1
且do_sample=True
时,调用beam_sample()
方法,相当于每次不再是贪婪选择top N个柱,而是加了一些采样。 - diverse beam-search decoding:当
num_beams>1
且num_beam_groups>1
时,调用group_beam_search()
方法。 - constrained beam-search decoding:当
constraints!=None
或者force_words_ids!=None
,实现可控文本生成。
参数列表
核心代码详见:generate()入口函数定义, GenerationConfig类
1.控制生成长度的参数
参数 | 类型 | 缺省值 | 含义 |
max_length | int | 20 | 表示 prompt + max_new_tokens 累加的最大长度,如果max_new_tokens也设置了,会覆盖这个参数 |
max_new_tokens | int | 生成部分的tokens的最大长度 (忽略prompt部分的长度) | |
min_length | 0 | 表示 prompt + min_new_tokens 累加的最小长度,如果min_new_tokens也设置了,会覆盖这个参数 | |
min_new_tokens | int | 生成部分的tokens的最小长度 (忽略prompt部分的长度) | |
early_stopping | bool, str | False | 对于beam search方法的控制终止的配置。 False: 当有’num_beams’个候选生成,则终止 True: 应用一些启发式规则判断不能找到更好的生成候选,来提前终止生成 “never”: 当判断没有更好的可生成的candidate, beam search 过程终止 |
max_time | float | 执行生成的最大时间(s秒数) | |
stop_strings | str, array[str] | 配置模型生成的终止字符串,当模型生成参数配置的字符串,则终止生成。 |
2. 控制生成策略的参数
参数 | 类型 | 缺省值 | 含义 |
do_sample | bool | False | True: 生成过程使用采样逻辑 False: 使用greedy做生成 |
num_beams | int | 1 | 设置beam search 束的数量。如果是1不做beam search 搜索 |
num_beam_groups | int | 1 | 为了保证生成的多样性,将num_beams 设置成多组。参考文献: https://arxiv.org/pdf/1610.02424.pdf |
penalty_alpha | float | contrastive search decoding的配置项,用于平衡生成置信度和衰减的惩罚 | |
dola_layers | str, List[int] | str : “None”: 不使用dola “low” : 较低的一半layers, 最多20层使用dola “high”: 较高的一半layers, 最多20层使用dola List[int] : 通过指定一个index数组,指定dola 层 “low”: 提升长答案的task, “high”:提升短答案的task |
3.cache配置参数
参数 | 类型 | 缺省值 | 含义 |
use_cache | bool | True | 是否使用KV cache 加速推理速度 |
cache_implementation | str | 指定cache实现的name,在调用generate()时,实例化cache。 ”static”: [StaticCache] “offloaded_static”: [OffloadedStaticCache] ”sliding_window”: [SlidingWindowCache] “hybrid”: [HybridCache] “mamba”: [MambaCache] ”quantized”:[QuantizedCache] | |
cache_config | CacheConfig , dict | None | cache类使用的参数 |
return_legacy_cache | bool | True | 当DynamicCache 被使用时,是否返回历史的和新格式的cache |
4.操作模型输出logit的配置参数
参数 | 类型 | 缺省值 | 含义 |
temperature | float | 1.0 | 这个值用于建模下一个token的概率, 这个值被设置在generation_config.json文件中 |
top_k | int | 50 | 筛选最高概率的top k个词, 这个值被设置在generation_config.json文件中 |
top_p | float | 1.0 | 当设置<1时,筛选概率最高的token,累加概率不超过top_p的token |
min_p | float | 配置筛选概率最低的一批token, 累加概率不超过min_p,裁剪掉,该配置相当于top_p的反向操作 | |
typical_p | float | 1.0 | 测量两个分布的相似性: 预测下一个目标token的概率 and 预测下一个随机Token的条件概率期望。如果设置<1,则筛选最典型的token。 |
epsilon_cutoff | float | 0.0 | 按设置的值,卡掉低概率值的token,一般设置为:3e-4 to 9e-4 |
eta_cutoff | float | 0.0 | 混合局部典型性采样和epsilon采样方法 |
diversity_penalty | float | 0.0 | 只对group beam search方法生效,如果在某个特定时间生成的token与任何beam 组生成的token一致,则beam的score减去这个值 |
repetition_penalty | float | 1.0 | 1.0 默认不惩罚 |
encoder_repetition_penalty | float | 1.0 | 对于不在原始输入的token,指数级的惩罚 |
length_penalty | float | 1.0 | 对于beam 类的生成方法的长度惩罚,由于序列score是 log likelihood , > 0 倾向于更长的 <0 倾向于更短的 |
no_repeat_ngram_size | int | 0 | 如果大于0, 则对应的size的ngram只能出现1次 |
bad_words_ids | List[List[int]] | 列出不允许生成的tokens_id | |
force_words_ids | List[List[int]] or List[List[List[int]]] | 必须被生成的words_ids。 如果配置List[List[List[int]]] 设置对于每个token的约束 | |
renormalize_logits | bool | False | 对于所有的logits做后处理后,是否要再做下normalize |
constraints | List[Constraint] | 通过定义一个List[Constraint] 对象数组,来确保输出是在某些限制的场景下。一般用于安全的场景 | |
forced_bos_token_id | int | model.config.forced_bos_token_id | 强制跟在decoder_start_token_id之后的第一个token,对多语言模型是有用的 |
forced_eos_token_id | int or List[int] | model.config.forced_eos_token_id | 当生成的token达到max_length上限时,最后一位输出的token |
remove_invalid_values | bool | model.config.remove_invalid_values | 是否移出可能生成的nan and inf 值,配置这个会减慢生成速度 |
exponential_decay_length_penalty | tuple(int, float) | 指数级增加长度的惩罚,tuple(start_index, decay_factor) start index 指示惩罚的开始i,decay_factor 指数衰减的惩罚因子 | |
suppress_tokens | List[int] | 通过设置禁止的token的logit为-inf,来禁止token被sample | |
begin_suppress_tokens | List[int] | 通过设置首位禁止的token的logit为-inf,来禁止首位这部分token被采样到,进而导致被生成 | |
forced_decoder_ids | List[List[int]] | 一个整数pair的数组,格式[生成index, token_index]指示固定位置强制生成某个token,例如[[1, 123]] 第二个位置总是生成token 123 | |
sequence_bias | Dict[Tuple[int], float] | token list -> bias的映射,正的bias提升几率,负的bias降低几率 | |
token_healing | bool | False | 对prompt尾部的token做相似替换,以提升生成质量 |
guidance_scale | float | 是一个缩放因子,当>1时,这个因子越高,越鼓励模型生成与prompt接近的samples 。 | |
watermarking_config | BaseWatermarkingConfig or dict | 对输出结果增加水印 |
5.输出结果配置参数
参数 | 类型 | 缺省值 | 含义 |
num_return_sequences | int | 1 | 对于batch中的每个元素,设置独立计算的返回的sequence的数量 |
output_attentions | bool | False | 是否返回所有的attention的向量 |
output_hidden_states | bool | False | 是否返回所有网络层的隐层状态 |
output_scores | bool | False | 是否返回prediction scores |
output_logits | bool | 是否返回未处理过的的logit score | |
return_dict_in_generate | bool | False | 除了返回生成序列,是否还返回a [`~utils.ModelOutput`] |
6.生成时使用的特殊token的配置参数
参数 | 类型 | 缺省值 | 含义 |
pad_token_id | int | padding token ID | |
bos_token_id | int | beginning -of – sequence token ID | |
eos_token_id | Union[int, List[int]] | end-of-sequence token ID |
6.辅助生成的配置参数(投机采样)
参数 | 类型 | 缺省值 | 含义 |
is_assistant | bool | False | 指定是否模型是一个assistant(draft) model |
num_assistant_tokens | int | 20 | 投机采样过程,每次迭代 assistant model 要输出多少个token,给到目标模型做check。配置更高的值,如果assistant model 效果好 能带来更好的加速比 |
num_assistant_tokens_schedule | str | constant | “heuristic” : 当所有投机采样的token都正确时,将num_assistant_tokens增加2,否则减少1。 “constant”: num_assistant_tokens 保持固定不变 “heuristic_transient”: 类似于启发式方法,每次生成调用,都置成初始化的num_assistant_tokens值 |
assistant_confidence_threshold | float | 0.4 | 当assistant model预估当前token的置信度 小于 阈值时,提前终止assistant model的生成 |
prompt_lookup_num_tokens | int | 作为候选token 要输出的token的数量 | |
max_matching_ngram_size | int | 2 | match prompt的最大ngram的数量 |
assistant_early_exit | int | ||
assistant_lookbehind | int | 10 | 如果设置为正整数,则重新编码过程将额外考虑最后的assistant_lookbehind个辅助标记,以正确对齐标记。此设置仅可在推测解码中使用不同的分词器时使用。 |
target_lookbehind | int | 10 | 如果设置为正整数,则重新编码过程将额外考虑最后的target_lookbehind个辅助标记,以正确对齐标记。此设置仅可在推测解码中使用不同的分词器时使用。 |
如有整理错误,欢迎指正~