FormatChatGenerationDPO¶
格式化 ChatGeneration
和直接偏好优化 (DPO) 的偏好任务的组合输出。
FormatChatGenerationDPO
是一个 Step
,它格式化 ChatGeneration
任务与偏好 Task
(即生成 ratings
的任务,例如 UltraFeedback
,遵循 axolotl
或 alignment-handbook
等框架的标准格式)的组合输出,以便用于对现有生成结果进行排名,并根据 ratings
提供 chosen
和 rejected
生成结果。
注意¶
messages
列应至少包含一条来自用户的消息,generations
列应包含至少两个生成结果,ratings
列应包含与生成结果相同数量的评分。
输入 & 输出列¶
graph TD
subgraph Dataset
subgraph Columns
ICOL0[messages]
ICOL1[generations]
ICOL2[generation_models]
ICOL3[ratings]
end
subgraph New columns
OCOL0[prompt]
OCOL1[prompt_id]
OCOL2[chosen]
OCOL3[chosen_model]
OCOL4[chosen_rating]
OCOL5[rejected]
OCOL6[rejected_model]
OCOL7[rejected_rating]
end
end
subgraph FormatChatGenerationDPO
StepInput[Input Columns: messages, generations, generation_models, ratings]
StepOutput[Output Columns: prompt, prompt_id, chosen, chosen_model, chosen_rating, rejected, rejected_model, rejected_rating]
end
ICOL0 --> StepInput
ICOL1 --> StepInput
ICOL2 --> StepInput
ICOL3 --> StepInput
StepOutput --> OCOL0
StepOutput --> OCOL1
StepOutput --> OCOL2
StepOutput --> OCOL3
StepOutput --> OCOL4
StepOutput --> OCOL5
StepOutput --> OCOL6
StepOutput --> OCOL7
StepInput --> StepOutput
输入¶
-
messages (
List[Dict[str, str]]
): 对话消息。 -
generations (
List[str]
): 由LLM
生成的生成结果。 -
generation_models (
List[str]
, 可选): 用于生成generations
的模型名称,仅当ChatGeneration
任务的模型名称组合成以此命名的单个列时可用,否则将被忽略。 -
ratings (
List[float]
): 每个generations
的评分,由偏好任务(如UltraFeedback
)生成。
输出¶
-
prompt (
str
): 用于使用LLM
生成generations
的用户消息。 -
prompt_id (
str
):prompt
的SHA256
哈希值。 -
chosen (
List[Dict[str, str]]
): 基于ratings
的chosen
生成结果。 -
chosen_model (
str
, 可选): 用于生成chosen
生成结果的模型名称(如果generation_models
可用)。 -
chosen_rating (
float
):chosen
生成结果的评分。 -
rejected (
List[Dict[str, str]]
): 基于ratings
的rejected
生成结果。 -
rejected_model (
str
, 可选): 用于生成rejected
生成结果的模型名称(如果generation_models
可用)。 -
rejected_rating (
float
):rejected
生成结果的评分。
示例¶
为 DPO 微调格式化数据集¶
from distilabel.steps import FormatChatGenerationDPO
format_dpo = FormatChatGenerationDPO()
format_dpo.load()
# NOTE: "generation_models" can be added optionally.
result = next(
format_dpo.process(
[
{
"messages": [{"role": "user", "content": "What's 2+2?"}],
"generations": ["4", "5", "6"],
"ratings": [1, 0, -1],
}
]
)
)
# >>> result
# [
# {
# 'messages': [{'role': 'user', 'content': "What's 2+2?"}],
# 'generations': ['4', '5', '6'],
# 'ratings': [1, 0, -1],
# 'prompt': "What's 2+2?",
# 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
# 'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
# 'chosen_rating': 1,
# 'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
# 'rejected_rating': -1
# }
# ]