LLM解码之惩罚参数详谈

2025-11-19 13:59:43
文章摘要
本文来详谈一下惩罚参数(重复惩罚、频率惩罚、存在惩罚)这三者的区别以及代码实现。

之前在 一文搞懂大模型生成文本的解码策略  中简单介绍过惩罚参数(重复惩罚、频率惩罚、存在惩罚),本文来详谈一下这三者的区别以及代码实现。

重复惩罚(repetition_penalty)

工作机制: 直接针对当前上下文中(包括输入+已生成的token)已经出现过的 token(词元)进行惩罚。它的逻辑非常简单粗暴:在计算下一个 token 的概率时,模型会查看这个 token 是否在之前的文本中出现过。如果出现过,就将其概率值除以 repeat_penalty 系数(一个大于1的值,例如 1.1)。

  1. 惩罚值 = 1.0:无惩罚。
  2. 惩罚值 > 1.0(例如 1.1-1.5):惩罚重复,值越大,惩罚越重。

设计目标: 专门为了解决模型“卡住”、不断重复同一句话或同一个词的问题。常用在小型大模型场景。

优缺点

  1. 优点: 对于抑制重复非常有效。能迅速打破“I'm sorry, I'm sorry, I'm sorry...”或“The cat sat on the mat. The cat sat on the mat...”这类循环。
  2. 缺点: 可能过于“暴力”,有时会抑制合理的、非重复性的用词,导致文本生硬或不自然。

使用建议

  1. 何时使用:当你发现模型输出开始出现明显的词语或句子重复时。
  2. 典型值:1.1 是一个温和的起点,1.2 是常见值。对于严重的重复问题,可以尝试 1.3 或 1.5。不建议设置得过高(如 > 2.0),否则会严重损害文本的连贯性。

数学表达: 对于每个在上下文中出现过的 token_id:

  1. 如果 logit(token_id) > 0:logit'(token_id) = logit(token_id) / penalty
  2. 如果 logit(token_id) < 0:logit'(token_id) = logit(token_id) * penalty

示例

假设 token "the" 的原始 logit = 2.0,penalty = 1.2:

惩罚后:2.0 / 1.2 = 1.67

概率从 softmax(2.0) 降低到 softmax(1.67)

核心算法逻辑

def apply_repetition_penalty(logits, context_tokens, penalty):
    """
    logits: 模型输出的原始分数向量 [vocab_size]
    context_tokens: 当前上下文中的所有token ID列表
    penalty: 惩罚系数 (通常 > 1.0)
    """
    for token_id in set(context_tokens):  # 去重处理
        if logits[token_id] > 0:
            # 对正logits进行除法惩罚
            logits[token_id] = logits[token_id] / penalty
        else:
            # 对负logits进行乘法惩罚(效果相同)
            logits[token_id] = logits[token_id] * penalty
    
    return logits

频率惩罚(frequency_penalty)

工作机制:根据 token 在之前文本中出现的总次数进行惩罚。出现次数越多的 token,受到的惩罚越大。它惩罚的是“高频词”。

设计目标: 降低已出现过的词的使用,特别是高频词,同时鼓励模型使用新词汇,增加输出的多样性和新颖性,防止文本内容过于围绕几个核心词汇展开。

使用建议:

  1. 当你希望模型避免反复使用同一个词时使用。对于抑制轻微的重复有效。
  2. 典型值:0.0 到 1.0。正值惩罚高频词,负值鼓励高频词。

数学表达

对于每个 token_id:

  1. logit'(token_id) = logit(token_id) - (penalty × frequency_count)

效果示例

假设 token "the" 出现3次,penalty = 0.5:

原始 logit = 2.0

惩罚后:2.0 - (0.5 × 3) = 0.5

概率显著降低。

核心算法逻辑

def apply_frequency_penalty(logits, generated_tokens, penalty):
    """
    generated_tokens: 本次生成过程中已生成的所有token
    penalty: 频率惩罚系数
    """
    # 统计每个token的出现次数
    token_counts = {}
    for token_id in generated_tokens:
        token_counts[token_id] = token_counts.get(token_id, 0) + 1
    
    # 应用频率惩罚
    for token_id, count in token_counts.items():
        logits[token_id] = logits[token_id] - (penalty * count)
    
    return logits

存在惩罚(presence_penalty)

工作机制:要一个 token 在之前文本中出现过至少一次,就会受到固定值的惩罚,与它出现了多少次无关。它惩罚的是“已出现的词”。

设计目标: 鼓励模型使用新词汇,增加输出的多样性和新颖性,防止文本内容过于围绕几个核心词汇展开。

使用建议:

  1. 当你希望模型在后续输出中避免使用已经用过的词,但不太关心用了多少次时使用。
  2. 典型值:0.0 到 1.0。正值鼓励新词,负值则鼓励使用已出现的词。

数学表达

对于每个出现过的 token_id:

  1. logit'(token_id) = logit(token_id) - penalty

效果示例

假设 token "the" 出现过,penalty = 0.7:

原始 logit = 2.0

惩罚后:2.0 - 0.7 = 1.3

无论出现多少次,惩罚值固定。

核心算法逻辑

def apply_presence_penalty(logits, generated_tokens, penalty):
    """
    只要出现过就施加固定惩罚
    """
    appeared_tokens = set(generated_tokens)
    
    for token_id in appeared_tokens:
        logits[token_id] = logits[token_id] - penalty
    
    return logits

三者的详细对比

特性

repeat_penalty

frequency_penalty

presence_penalty

惩罚对象

当前上下文中的所有已出现 token

在上下文中出现次数多的 token

在上下文中出现过的所有 token

惩罚依据

是否出现过(布尔判断)

出现的频率(次数)

是否出现过(布尔判断)

主要目标

防止循环和逐字重复

降低高频词的权重

鼓励使用全新词汇

抑制重复效果

非常强效、直接

中等,对轻微重复有效

较弱,主要鼓励词汇多样性

粗暴程度

较高,可能影响连贯性

中等

较低,更精细

作用范围

通常作用于整个上下文(包括系统提示词、用户输入和已生成内容)

通常只作用于本次生成过程中已生成的部分

通常只作用于本次生成过程中已生成的部分

类比

“不准说已经说过的话!”

“别老提那几个词!”

“换个新词说说!”

主流框架对于惩罚参数的支持

目前,各大推理框架中(如:HF Transformers/vLLM/Llama.cpp/SGLang)都有对于惩罚功能的实现,只是支持的程度不一。

特性

重复惩罚

频率惩罚

存在惩罚

HF

×

×

SGLang

x

vLLM

其中,vLLM 对三者都进行了支持,具体实现代码如下。

def apply_penalties(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(
        prompt_tokens_tensor, vocab_size, num_seqs
    )
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs
    )

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties

    apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits

def apply_repetition_penalties_torch(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, logits.size(1)
    )
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling


def apply_repetition_penalties_cuda(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    torch.ops._C.apply_repetition_penalties_(
        logits, prompt_mask, output_mask, repetition_penalties
    )


def apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Apply repetition penalties to logits in-place.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    if logits.is_cuda and logits.is_contiguous():
        apply_repetition_penalties_cuda(
            logits, prompt_mask, output_mask, repetition_penalties
        )
    else:
        apply_repetition_penalties_torch(
            logits, prompt_mask, output_mask, repetition_penalties
        )

当前 SGLang 实现了频率惩罚与存在惩罚。 SGLang 与 OpenAI 接口保持一致移除了之前实现的重复惩罚功能,详见GitHub关于该特性的讨论。需要注意的是当前 SGLang 的 SamplingParams 对象仍然保留了repetition_penalty, 因此,设置该参数不生效。

如果需要在 SGLang 中使用 repetition_penalty,可参考https://github.com/sgl-project/sglang/pull/5703 的实现,不过该实现仅考虑了生成内容,未考虑原始输入内容。需要进行相应的修改,具体如下:

class BatchedRepetitionPenalizer(_BatchedPenalizer):

    def _prepare(self):
        batch_cumulated_repetition_penalties = []
        for req in self.orchestrator.reqs():
            cumulated_repetition_penalties_lst = [1] * self.orchestrator.vocab_size
            for idx in req.origin_input_ids:
                cumulated_repetition_penalties_lst[idx] = req.sampling_params.repetition_penalty
            batch_cumulated_repetition_penalties.append(cumulated_repetition_penalties_lst)

        self.cumulated_repetition_penalties = torch.tensor(
            data=batch_cumulated_repetition_penalties,
                dtype=torch.float32,
                device=self.orchestrator.device,
        )

        self.repetition_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.repetition_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
        ).unsqueeze_(1)

HF Transformers 支持重复惩罚,具体源码和使用示例如下:

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.

    In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
    1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
            tokens. Between 0.0 and 1.0 rewards previously generated tokens.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        if prompt_ignore_length is not None and (
            not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
        ):
            raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")

        self.penalty = penalty
        self.prompt_ignore_length = prompt_ignore_length

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    
        if self.prompt_ignore_length:
            input_ids = input_ids[:, self.prompt_ignore_length :]
            
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores_processed = scores.scatter(1, input_ids, score)
        return scores_processed

使用示例:

from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor

# Initializing the model and tokenizer for it
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer(["I'm not going to"], return_tensors="pt")

# This shows a normal generate without any specific parameters
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])

# This generates a penalty for repeated tokens
penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])

# We can also exclude the input prompt by creating an instance of this class
# with a `prompt_ignore_length` and passing it as a custom logit processor
rep_pen_processor = RepetitionPenaltyLogitsProcessor(
    penalty=1.1,
    prompt_ignore_length=inputs["input_ids"].shape[-1]
)
penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])

总结

总之,repeat_penalty、frequency_penalty, presence_penalty 的设计哲学和惩罚机制有根本区别,它们解决的是不同类型的问题。 对于抑制重复和循环文本,repeat_penalty 通常是更有效、更直接的工具。而frequency_penalty 、 presence_penalty 它们提供了对词汇多样性和主题广度的更精细控制,但抑制强力重复的能力相对较弱。此外,repetition_penalty 通常作用于整个上下文(包括系统提示词、用户输入和已生成内容),而frequency_penalty/presence_penalty通常只作用于本次生成过程中已生成的部分。

在一些支持所有这三个参数的平台上,最佳实践是组合使用。比如:设置一个温和的 repeat_penalty(如 1.05 ~ 1.15)作为“安全网”,防止灾难性的循环。再设置一个轻微的 presence_penalty(如 0.1 ~ 0.3)来温和地鼓励词汇多样性。这样既能有效阻止重复,又能让文本保持自然和丰富。

需要注意的是任何惩罚参数都是“事后补救”。最好的方法是优化你的 Prompt。给出更明确的指令,如“避免重复之前的观点”、“确保内容新颖多样”,或者提供高质量的示例(Few-shot),引导模型产生你期望的输出模式。


声明:该内容由作者自行发布,观点内容仅供参考,不代表平台立场;如有侵权,请联系平台删除。
标签:
大模型