跳到内容

FormatChatGenerationDPO

格式化 ChatGeneration 和直接偏好优化 (DPO) 的偏好任务的组合输出。

FormatChatGenerationDPO 是一个 Step,它格式化 ChatGeneration 任务与偏好 Task(即生成 ratings 的任务,例如 UltraFeedback,遵循 axolotlalignment-handbook 等框架的标准格式)的组合输出,以便用于对现有生成结果进行排名,并根据 ratings 提供 chosenrejected 生成结果。

注意

messages 列应至少包含一条来自用户的消息,generations 列应包含至少两个生成结果,ratings 列应包含与生成结果相同数量的评分。

输入 & 输出列

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[messages]
            ICOL1[generations]
            ICOL2[generation_models]
            ICOL3[ratings]
        end
        subgraph New columns
            OCOL0[prompt]
            OCOL1[prompt_id]
            OCOL2[chosen]
            OCOL3[chosen_model]
            OCOL4[chosen_rating]
            OCOL5[rejected]
            OCOL6[rejected_model]
            OCOL7[rejected_rating]
        end
    end

    subgraph FormatChatGenerationDPO
        StepInput[Input Columns: messages, generations, generation_models, ratings]
        StepOutput[Output Columns: prompt, prompt_id, chosen, chosen_model, chosen_rating, rejected, rejected_model, rejected_rating]
    end

    ICOL0 --> StepInput
    ICOL1 --> StepInput
    ICOL2 --> StepInput
    ICOL3 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepOutput --> OCOL3
    StepOutput --> OCOL4
    StepOutput --> OCOL5
    StepOutput --> OCOL6
    StepOutput --> OCOL7
    StepInput --> StepOutput

输入

  • messages (List[Dict[str, str]]): 对话消息。

  • generations (List[str]): 由 LLM 生成的生成结果。

  • generation_models (List[str], 可选): 用于生成 generations 的模型名称,仅当 ChatGeneration 任务的模型名称组合成以此命名的单个列时可用,否则将被忽略。

  • ratings (List[float]): 每个 generations 的评分,由偏好任务(如 UltraFeedback)生成。

输出

  • prompt (str): 用于使用 LLM 生成 generations 的用户消息。

  • prompt_id (str): promptSHA256 哈希值。

  • chosen (List[Dict[str, str]]): 基于 ratingschosen 生成结果。

  • chosen_model (str, 可选): 用于生成 chosen 生成结果的模型名称(如果 generation_models 可用)。

  • chosen_rating (float): chosen 生成结果的评分。

  • rejected (List[Dict[str, str]]): 基于 ratingsrejected 生成结果。

  • rejected_model (str, 可选): 用于生成 rejected 生成结果的模型名称(如果 generation_models 可用)。

  • rejected_rating (float): rejected 生成结果的评分。

示例

为 DPO 微调格式化数据集

from distilabel.steps import FormatChatGenerationDPO

format_dpo = FormatChatGenerationDPO()
format_dpo.load()

# NOTE: "generation_models" can be added optionally.
result = next(
    format_dpo.process(
        [
            {
                "messages": [{"role": "user", "content": "What's 2+2?"}],
                "generations": ["4", "5", "6"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# >>> result
# [
#     {
#         'messages': [{'role': 'user', 'content': "What's 2+2?"}],
#         'generations': ['4', '5', '6'],
#         'ratings': [1, 0, -1],
#         'prompt': "What's 2+2?",
#         'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
#         'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
#         'chosen_rating': 1,
#         'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
#         'rejected_rating': -1
#     }
# ]