In an earlier post, 一文搞懂大模型生成文本的解码策略, I briefly introduced the penalty parameters (repetition penalty, frequency penalty, presence penalty). This post takes a closer look at how the three differ and how they are implemented in code.
Repetition penalty (repetition_penalty)
How it works: it directly penalizes tokens that have already appeared in the current context (the input plus the tokens generated so far). The logic is deliberately blunt: when computing the next token's probabilities, the model checks whether each token has appeared in the preceding text; if it has, its score is divided by the repetition_penalty coefficient (a value greater than 1, e.g. 1.1).
- Penalty = 1.0: no penalty.
- Penalty > 1.0 (e.g. 1.1-1.5): penalizes repetition; the larger the value, the stronger the penalty.
Design goal: built specifically to fix the failure mode where the model gets "stuck" and keeps repeating the same word or sentence. It is most often needed with smaller models.
Pros and cons:
- Pros: very effective at suppressing repetition. It quickly breaks loops like "I'm sorry, I'm sorry, I'm sorry..." or "The cat sat on the mat. The cat sat on the mat...".
- Cons: it can be too heavy-handed, sometimes suppressing legitimate, non-repetitive word choices and leaving the text stiff or unnatural.
Usage advice:
- When to use: when the output starts to show obvious word- or sentence-level repetition.
- Typical values: 1.1 is a gentle starting point and 1.2 is a common choice. For severe repetition, try 1.3 or 1.5. Avoid very high values (e.g. > 2.0), which badly damage coherence.
Math: for each token_id that has appeared in the context:
- If logit(token_id) > 0: logit'(token_id) = logit(token_id) / penalty
- If logit(token_id) ≤ 0: logit'(token_id) = logit(token_id) × penalty
Example:
Suppose token "the" has an original logit of 2.0 and penalty = 1.2:
After the penalty: 2.0 / 1.2 ≈ 1.67
Its probability drops from softmax(2.0) to softmax(1.67).
Core algorithm logic:
def apply_repetition_penalty(logits, context_tokens, penalty):
    """
    logits: raw score vector from the model, shape [vocab_size]
    context_tokens: list of all token IDs in the current context
    penalty: penalty coefficient (usually > 1.0)
    """
    for token_id in set(context_tokens):  # deduplicate
        if logits[token_id] > 0:
            # divide positive logits by the penalty
            logits[token_id] = logits[token_id] / penalty
        else:
            # multiply negative logits by the penalty (same effect)
            logits[token_id] = logits[token_id] * penalty
    return logits
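As a quick sanity check, the logic above can be exercised on the worked example. This is a self-contained sketch; the toy four-token vocabulary and token IDs are made up for illustration:

```python
def apply_repetition_penalty(logits, context_tokens, penalty):
    # Same logic as above: divide positive logits, multiply negative ones,
    # so repeated tokens are always pushed toward lower probability.
    for token_id in set(context_tokens):
        if logits[token_id] > 0:
            logits[token_id] = logits[token_id] / penalty
        else:
            logits[token_id] = logits[token_id] * penalty
    return logits

# Toy 4-token vocabulary; token 2 stands in for "the" with logit 2.0.
logits = [0.5, -1.0, 2.0, 0.3]
penalized = apply_repetition_penalty(logits, context_tokens=[2, 2, 1], penalty=1.2)
print(round(penalized[2], 2))  # 2.0 / 1.2 -> 1.67
print(round(penalized[1], 2))  # -1.0 * 1.2 -> -1.2: negative logits are pushed further down
print(penalized[3])            # 0.3: tokens not in the context are untouched
```

Note that even though token 2 appears twice in the context, it is penalized only once: the `set()` deduplication makes the penalty independent of the count, which is exactly what distinguishes it from the frequency penalty below.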
Frequency penalty (frequency_penalty)
How it works: penalizes each token according to the total number of times it has appeared in the preceding text. The more often a token has appeared, the heavier the penalty. It targets high-frequency tokens.
Design goal: discourage reuse of tokens that have already appeared, especially frequent ones, while encouraging new vocabulary, increasing the diversity and novelty of the output, and keeping the text from revolving around a handful of core words.
Usage advice:
- Use it when you want the model to stop leaning on the same word over and over; effective against mild repetition.
- Typical values: 0.0 to 1.0. Positive values penalize frequent tokens; negative values encourage them.
Math:
For each token_id:
- logit'(token_id) = logit(token_id) - (penalty × frequency_count)
Example:
Suppose token "the" has appeared 3 times and penalty = 0.5:
Original logit = 2.0
After the penalty: 2.0 - (0.5 × 3) = 0.5
The probability drops substantially.
Core algorithm logic:
def apply_frequency_penalty(logits, generated_tokens, penalty):
    """
    logits: raw score vector from the model, shape [vocab_size]
    generated_tokens: all tokens generated so far in this run
    penalty: frequency penalty coefficient
    """
    # Count the occurrences of each token
    token_counts = {}
    for token_id in generated_tokens:
        token_counts[token_id] = token_counts.get(token_id, 0) + 1
    # Apply the frequency penalty
    for token_id, count in token_counts.items():
        logits[token_id] = logits[token_id] - (penalty * count)
    return logits
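A self-contained run of the same logic, matching the worked example above (the toy vocabulary and token IDs are again made up; `collections.Counter` replaces the manual counting loop):

```python
from collections import Counter

def apply_frequency_penalty(logits, generated_tokens, penalty):
    # Subtract penalty * count for every token already generated,
    # so the deduction grows with how often a token has appeared.
    for token_id, count in Counter(generated_tokens).items():
        logits[token_id] = logits[token_id] - penalty * count
    return logits

# Token 2 stands in for "the", generated 3 times with original logit 2.0.
logits = [0.5, -1.0, 2.0, 0.3]
penalized = apply_frequency_penalty(logits, generated_tokens=[2, 2, 2, 1], penalty=0.5)
print(penalized[2])  # 2.0 - 0.5 * 3 = 0.5
print(penalized[1])  # -1.0 - 0.5 * 1 = -1.5
print(penalized[3])  # 0.3: never generated, untouched
```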
Presence penalty (presence_penalty)
How it works: as soon as a token has appeared at least once in the preceding text, it receives a fixed penalty, regardless of how many times it occurred. It targets "tokens that have already appeared".
Design goal: encourage the model to use new vocabulary, increasing the diversity and novelty of the output and keeping the text from revolving around a handful of core words.
Usage advice:
- Use it when you want the model to avoid words it has already used, without caring how often it used them.
- Typical values: 0.0 to 1.0. Positive values encourage new tokens; negative values encourage reusing tokens that have already appeared.
Math:
For each token_id that has appeared:
- logit'(token_id) = logit(token_id) - penalty
Example:
Suppose token "the" has appeared and penalty = 0.7:
Original logit = 2.0
After the penalty: 2.0 - 0.7 = 1.3
The deduction is the same no matter how many times the token appeared.
Core algorithm logic:
def apply_presence_penalty(logits, generated_tokens, penalty):
    """
    Apply a fixed penalty to every token that has appeared at least once.
    """
    appeared_tokens = set(generated_tokens)
    for token_id in appeared_tokens:
        logits[token_id] = logits[token_id] - penalty
    return logits
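To make the contrast with the frequency penalty concrete, the sketch below applies the presence penalty to one token generated three times and one generated once: both receive the same flat deduction (toy token IDs made up for illustration):

```python
def apply_presence_penalty(logits, generated_tokens, penalty):
    # Flat subtraction for every token that has appeared at least once;
    # the occurrence count does not matter.
    for token_id in set(generated_tokens):
        logits[token_id] = logits[token_id] - penalty
    return logits

# Token 2 was generated three times, token 1 only once.
logits = [0.5, -1.0, 2.0, 0.3]
penalized = apply_presence_penalty(logits, generated_tokens=[2, 2, 2, 1], penalty=0.7)
print(round(penalized[2], 2))  # 2.0 - 0.7 -> 1.3, one deduction despite three occurrences
print(round(penalized[1], 2))  # -1.0 - 0.7 -> -1.7, the same single deduction
```

Under the frequency penalty, token 2 would instead lose 3 × penalty while token 1 lost only 1 × penalty; here both lose exactly one penalty unit.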
A detailed comparison of the three

| | repetition_penalty | frequency_penalty | presence_penalty |
|---|---|---|---|
| Trigger | Token appeared at least once | Per-occurrence count | Token appeared at least once |
| Penalty form | Multiplicative (divide positive logits, multiply negative ones) | Subtractive, scales with count | Subtractive, fixed amount |
| Formula | logit / penalty (or logit × penalty if ≤ 0) | logit - penalty × count | logit - penalty |
| "No penalty" value | 1.0 | 0.0 | 0.0 |
| Typical values | 1.1-1.5 | 0.0-1.0 | 0.0-1.0 |
| Scope | Usually the whole context (system prompt, user input, and generated output) | Usually only tokens generated in the current run | Usually only tokens generated in the current run |
| Best for | Breaking word/sentence loops | Reducing overused high-frequency words | Encouraging new vocabulary |
Framework support for the penalty parameters
All the major inference frameworks (HF Transformers, vLLM, llama.cpp, SGLang) implement penalty functionality, though to different degrees.
vLLM supports all three; the relevant implementation is shown below.
def apply_penalties(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(
        prompt_tokens_tensor, vocab_size, num_seqs
    )
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs
    )

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties

    apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


def apply_repetition_penalties_torch(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, logits.size(1)
    )
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling


def apply_repetition_penalties_cuda(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    torch.ops._C.apply_repetition_penalties_(
        logits, prompt_mask, output_mask, repetition_penalties
    )


def apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Apply repetition penalties to logits in-place.
    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    if logits.is_cuda and logits.is_contiguous():
        apply_repetition_penalties_cuda(
            logits, prompt_mask, output_mask, repetition_penalties
        )
    else:
        apply_repetition_penalties_torch(
            logits, prompt_mask, output_mask, repetition_penalties
        )
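For reference, the order of operations in the vLLM code above can be condensed into a dependency-free, single-sequence sketch. This is an illustrative reimplementation in pure Python, not vLLM code: the repetition penalty sees prompt + output tokens, while the frequency and presence penalties, following the OpenAI definition, see only output counts.

```python
from collections import Counter

def apply_all_penalties(logits, prompt_tokens, output_tokens,
                        repetition_penalty, frequency_penalty, presence_penalty):
    # Repetition penalty: multiplicative, over prompt + output tokens.
    for token_id in set(prompt_tokens) | set(output_tokens):
        if logits[token_id] > 0:
            logits[token_id] /= repetition_penalty
        else:
            logits[token_id] *= repetition_penalty
    # Frequency/presence penalties: subtractive, over output counts only.
    for token_id, count in Counter(output_tokens).items():
        logits[token_id] -= frequency_penalty * count  # scales with count
        logits[token_id] -= presence_penalty           # flat, once per token
    return logits

logits = [2.0, -0.5, 1.0]
penalized = apply_all_penalties(
    logits, prompt_tokens=[0], output_tokens=[0, 0, 2],
    repetition_penalty=1.2, frequency_penalty=0.5, presence_penalty=0.3,
)
print(round(penalized[0], 2))  # 2.0/1.2 - 0.5*2 - 0.3 -> 0.37
print(round(penalized[2], 2))  # 1.0/1.2 - 0.5*1 - 0.3 -> 0.03
print(penalized[1])            # -0.5: never appeared, untouched
```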
SGLang currently implements the frequency and presence penalties. To stay consistent with the OpenAI API, SGLang removed its earlier repetition-penalty implementation; see the GitHub discussion of that feature. Note that SGLang's SamplingParams object still exposes a repetition_penalty field, so setting it silently has no effect.
If you need repetition_penalty in SGLang, you can start from the implementation in https://github.com/sgl-project/sglang/pull/5703. That implementation only considers the generated content, not the original input, so it needs a corresponding modification, for example:
class BatchedRepetitionPenalizer(_BatchedPenalizer):
    def _prepare(self):
        batch_cumulated_repetition_penalties = []
        for req in self.orchestrator.reqs():
            cumulated_repetition_penalties_lst = [1] * self.orchestrator.vocab_size
            for idx in req.origin_input_ids:
                cumulated_repetition_penalties_lst[idx] = req.sampling_params.repetition_penalty
            batch_cumulated_repetition_penalties.append(cumulated_repetition_penalties_lst)

        self.cumulated_repetition_penalties = torch.tensor(
            data=batch_cumulated_repetition_penalties,
            dtype=torch.float32,
            device=self.orchestrator.device,
        )
        self.repetition_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.repetition_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
        ).unsqueeze_(1)
HF Transformers supports the repetition penalty; the relevant source and a usage example follow:
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.

    In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
    1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
            tokens. Between 0.0 and 1.0 rewards previously generated tokens.
        prompt_ignore_length (`int`, *optional*):
            The length of the prompt to ignore, so that the penalty only applies to tokens generated afterwards.
    """

    def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        if prompt_ignore_length is not None and (
            not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
        ):
            raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")
        self.penalty = penalty
        self.prompt_ignore_length = prompt_ignore_length

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.prompt_ignore_length:
            input_ids = input_ids[:, self.prompt_ignore_length :]
        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        scores_processed = scores.scatter(1, input_ids, score)
        return scores_processed
使用示例:
from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor
# Initializing the model and tokenizer for it
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer(["I'm not going to"], return_tensors="pt")
# This shows a normal generate without any specific parameters
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
# This generates a penalty for repeated tokens
penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
# We can also exclude the input prompt by creating an instance of this class
# with a `prompt_ignore_length` and passing it as a custom logit processor
rep_pen_processor = RepetitionPenaltyLogitsProcessor(
    penalty=1.1,
    prompt_ignore_length=inputs["input_ids"].shape[-1]
)
penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
Summary
In short, repetition_penalty, frequency_penalty, and presence_penalty differ fundamentally in design philosophy and penalty mechanism: they solve different kinds of problems. For suppressing repetition and text loops, repetition_penalty is usually the more direct and effective tool. frequency_penalty and presence_penalty offer finer-grained control over vocabulary diversity and topical breadth, but are comparatively weak at stopping hard loops. Also note the scope difference: repetition_penalty usually acts on the whole context (system prompt, user input, and generated output), while frequency_penalty and presence_penalty usually act only on the tokens generated in the current run.
On platforms that support all three parameters, the best practice is to combine them: set a mild repetition_penalty (e.g. 1.05-1.15) as a safety net against catastrophic loops, plus a light presence_penalty (e.g. 0.1-0.3) to gently encourage vocabulary diversity. This blocks repetition effectively while keeping the text natural and rich.
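For example, with a server that accepts all three, a request body might look like the following. The parameter names here follow vLLM's OpenAI-compatible endpoint; repetition_penalty is a vLLM extension rather than part of the OpenAI spec, and the model name and prompt are placeholders:

```python
# Hypothetical request body for an OpenAI-compatible completions endpoint.
payload = {
    "model": "my-model",               # placeholder model name
    "prompt": "Write a short story.",  # placeholder prompt
    "repetition_penalty": 1.1,         # mild safety net against loops (vLLM extension)
    "presence_penalty": 0.2,           # gentle push toward new vocabulary
    "frequency_penalty": 0.0,          # left off here
}
```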
Keep in mind that every penalty parameter is an after-the-fact fix. The best remedy is still a better prompt: give explicit instructions such as "avoid repeating earlier points" or "keep the content novel and varied", or provide high-quality few-shot examples that steer the model toward the output pattern you want.