跳到内容

FormatTextGenerationDPO

为直接偏好优化 (DPO) 格式化您的 LLM 的输出。

FormatTextGenerationDPO 是一个 Step,它格式化 TextGeneration 任务与偏好 Task(即生成 ratings 的任务)的组合输出,以便使用这些输出对现有生成进行排名,并根据 ratings 提供 chosenrejected 生成。使用此步骤转换 TextGeneration + 偏好任务(如 UltraFeedback)组合的输出,遵循来自 axolotlalignment-handbook 等框架的标准格式。

注意

generations 列应至少包含两个生成,ratings 列应包含与 generations 相同数量的评分。

输入 & 输出列

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[system_prompt]
            ICOL1[instruction]
            ICOL2[generations]
            ICOL3[generation_models]
            ICOL4[ratings]
        end
        subgraph New columns
            OCOL0[prompt]
            OCOL1[prompt_id]
            OCOL2[chosen]
            OCOL3[chosen_model]
            OCOL4[chosen_rating]
            OCOL5[rejected]
            OCOL6[rejected_model]
            OCOL7[rejected_rating]
        end
    end

    subgraph FormatTextGenerationDPO
        StepInput[Input Columns: system_prompt, instruction, generations, generation_models, ratings]
        StepOutput[Output Columns: prompt, prompt_id, chosen, chosen_model, chosen_rating, rejected, rejected_model, rejected_rating]
    end

    ICOL0 --> StepInput
    ICOL1 --> StepInput
    ICOL2 --> StepInput
    ICOL3 --> StepInput
    ICOL4 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepOutput --> OCOL3
    StepOutput --> OCOL4
    StepOutput --> OCOL5
    StepOutput --> OCOL6
    StepOutput --> OCOL7
    StepInput --> StepOutput

输入

  • system_prompt (str, 可选): 在 LLM 中用于生成 generations 的系统提示(如果可用)。

  • instruction (str): 用于使用 LLM 生成 generations 的指令。

  • generations (List[str]): 由 LLM 生成的生成结果。

  • generation_models (List[str], 可选): 用于生成 generations 的模型名称,仅当 TextGeneration 任务的模型名称组合成名为此名称的单个列时才可用,否则将被忽略。

  • ratings (List[float]): 每个 generations 的评分,由偏好任务(如 UltraFeedback)生成。

输出

  • prompt (str): 用于使用 LLM 生成 generations 的指令。

  • prompt_id (str): promptSHA256 哈希值。

  • chosen (List[Dict[str, str]]): 基于 ratingschosen (选定的) 生成结果。

  • chosen_model (str, 可选): 用于生成 chosen (选定的) 生成结果的模型名称(如果 generation_models 可用)。

  • chosen_rating (float): chosen (选定的) 生成结果的评分。

  • rejected (List[Dict[str, str]]): 基于 ratingsrejected (拒绝的) 生成结果。

  • rejected_model (str, 可选): 用于生成 rejected (拒绝的) 生成结果的模型名称(如果 generation_models 可用)。

  • rejected_rating (float): rejected (拒绝的) 生成结果的评分。

示例

为 DPO 微调格式化您的数据集

from distilabel.steps import FormatTextGenerationDPO

format_dpo = FormatTextGenerationDPO()
format_dpo.load()

# NOTE: Both "system_prompt" and "generation_models" can be added optionally.
result = next(
    format_dpo.process(
        [
            {
                "instruction": "What's 2+2?",
                "generations": ["4", "5", "6"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# >>> result
# [
#    {   'instruction': "What's 2+2?",
#        'generations': ['4', '5', '6'],
#        'ratings': [1, 0, -1],
#        'prompt': "What's 2+2?",
#        'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
#        'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
#        'chosen_rating': 1,
#        'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
#        'rejected_rating': -1
#    }
# ]