跳到内容

FormatChatGenerationSFT

格式化 ChatGeneration 任务的输出,用于监督式微调 (SFT)。

FormatChatGenerationSFT 是一个 Step,它格式化 ChatGeneration 任务的输出,用于监督式微调 (SFT),遵循来自 axolotlalignment-handbook 等框架的标准格式。ChatGeneration 任务的输出被格式化为类似聊天的对话,其中 instruction 作为用户消息,generation 作为助手消息。可选地,如果 system_prompt 可用,则将其包含为对话中的第一条消息。

输入 & 输出列

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[system_prompt]
            ICOL1[instruction]
            ICOL2[generation]
        end
        subgraph New columns
            OCOL0[prompt]
            OCOL1[prompt_id]
            OCOL2[messages]
        end
    end

    subgraph FormatChatGenerationSFT
        StepInput[Input Columns: system_prompt, instruction, generation]
        StepOutput[Output Columns: prompt, prompt_id, messages]
    end

    ICOL0 --> StepInput
    ICOL1 --> StepInput
    ICOL2 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepInput --> StepOutput

输入

  • system_prompt (str, 可选): 在 LLM 中用于生成 generation 的系统提示(如果可用)。

  • instruction (str): 用于使用 LLM 生成 generation 的指令。

  • generation (str): 由 LLM 生成的内容。

输出

  • prompt (str): 用于使用 LLM 生成 generation 的指令。

  • prompt_id (str): promptSHA256 哈希值。

  • messages (List[Dict[str, str]]): 类似聊天的对话,其中 instruction 作为用户消息,generation 作为助手消息。

示例

为 SFT 格式化您的数据集

from distilabel.steps import FormatChatGenerationSFT

format_sft = FormatChatGenerationSFT()
format_sft.load()

# NOTE: "system_prompt" can be added optionally.
result = next(
    format_sft.process(
        [
            {
                "messages": [{"role": "user", "content": "What's 2+2?"}],
                "generation": "4"
            }
        ]
    )
)
# >>> result
# [
#     {
#         'messages': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
#         'generation': '4',
#         'prompt': 'What's 2+2?',
#         'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
#     }
# ]