FormatChatGenerationDPO¶
格式化 ChatGeneration 和直接偏好优化 (DPO) 的偏好任务的组合输出。
FormatChatGenerationDPO 是一个 Step,它格式化 ChatGeneration 任务与偏好 Task(即生成 ratings 的任务,例如 UltraFeedback,遵循 axolotl 或 alignment-handbook 等框架的标准格式)的组合输出,以便用于对现有生成结果进行排名,并根据 ratings 提供 chosen 和 rejected 生成结果。
注意¶
messages 列应至少包含一条来自用户的消息,generations 列应包含至少两个生成结果,ratings 列应包含与生成结果相同数量的评分。
输入 & 输出列¶
graph TD
subgraph Dataset
subgraph Columns
ICOL0[messages]
ICOL1[generations]
ICOL2[generation_models]
ICOL3[ratings]
end
subgraph New columns
OCOL0[prompt]
OCOL1[prompt_id]
OCOL2[chosen]
OCOL3[chosen_model]
OCOL4[chosen_rating]
OCOL5[rejected]
OCOL6[rejected_model]
OCOL7[rejected_rating]
end
end
subgraph FormatChatGenerationDPO
StepInput[Input Columns: messages, generations, generation_models, ratings]
StepOutput[Output Columns: prompt, prompt_id, chosen, chosen_model, chosen_rating, rejected, rejected_model, rejected_rating]
end
ICOL0 --> StepInput
ICOL1 --> StepInput
ICOL2 --> StepInput
ICOL3 --> StepInput
StepOutput --> OCOL0
StepOutput --> OCOL1
StepOutput --> OCOL2
StepOutput --> OCOL3
StepOutput --> OCOL4
StepOutput --> OCOL5
StepOutput --> OCOL6
StepOutput --> OCOL7
StepInput --> StepOutput
输入¶
-
messages (
List[Dict[str, str]]): 对话消息。 -
generations (
List[str]): 由LLM生成的生成结果。 -
generation_models (
List[str], 可选): 用于生成generations的模型名称,仅当ChatGeneration任务的模型名称组合成以此命名的单个列时可用,否则将被忽略。 -
ratings (
List[float]): 每个generations的评分,由偏好任务(如UltraFeedback)生成。
输出¶
-
prompt (
str): 用于使用LLM生成generations的用户消息。 -
prompt_id (
str):prompt的SHA256哈希值。 -
chosen (
List[Dict[str, str]]): 基于ratings的chosen生成结果。 -
chosen_model (
str, 可选): 用于生成chosen生成结果的模型名称(如果generation_models可用)。 -
chosen_rating (
float):chosen生成结果的评分。 -
rejected (
List[Dict[str, str]]): 基于ratings的rejected生成结果。 -
rejected_model (
str, 可选): 用于生成rejected生成结果的模型名称(如果generation_models可用)。 -
rejected_rating (
float):rejected生成结果的评分。
示例¶
为 DPO 微调格式化数据集¶
from distilabel.steps import FormatChatGenerationDPO
format_dpo = FormatChatGenerationDPO()
format_dpo.load()
# NOTE: "generation_models" can be added optionally.
result = next(
format_dpo.process(
[
{
"messages": [{"role": "user", "content": "What's 2+2?"}],
"generations": ["4", "5", "6"],
"ratings": [1, 0, -1],
}
]
)
)
# >>> result
# [
# {
# 'messages': [{'role': 'user', 'content': "What's 2+2?"}],
# 'generations': ['4', '5', '6'],
# 'ratings': [1, 0, -1],
# 'prompt': "What's 2+2?",
# 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
# 'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
# 'chosen_rating': 1,
# 'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
# 'rejected_rating': -1
# }
# ]