GenerateTextClassificationData¶
使用 LLM 生成文本分类数据,以便后续训练 embedding 模型。
GenerateTextClassificationData
是一个 Task
,它使用 LLM
生成文本分类数据,以便后续训练 embedding 模型。该任务基于论文“Improving Text Embeddings with Large Language Models”,并且数据基于提供的属性生成,如果未提供,则随机抽样。
注意¶
理想情况下,此任务应与 EmbeddingTaskGenerator
和 flatten_tasks=True
以及 category="text-classification"
一起使用;以便 LLM
生成一个任务列表,这些任务被展平,从而使每行包含一个用于文本分类类别的任务。
属性¶
-
language: 要生成的数据的语言,可以是 https://aclanthology.org/2020.acl-main.747.pdf 附录 A 中 XLM-R 列表中检索到的任何语言。
-
difficulty: 要生成的查询的难度,可以是
high school
、college
或PhD
。默认为None
,表示将随机抽样。 -
clarity: 要生成的查询的清晰度,可以是
clear
、understandable with some effort
或ambiguous
。默认为None
,表示将随机抽样。 -
seed: 要设置的随机种子,以防在
format_input
方法中进行任何抽样。
输入和输出列¶
graph TD
subgraph Dataset
subgraph Columns
ICOL0[task]
end
subgraph New columns
OCOL0[input_text]
OCOL1[label]
OCOL2[misleading_label]
OCOL3[model_name]
end
end
subgraph GenerateTextClassificationData
StepInput[Input Columns: task]
StepOutput[Output Columns: input_text, label, misleading_label, model_name]
end
ICOL0 --> StepInput
StepOutput --> OCOL0
StepOutput --> OCOL1
StepOutput --> OCOL2
StepOutput --> OCOL3
StepInput --> StepOutput
输入¶
- task (
str
): 生成中使用的任务描述。
输出¶
-
input_text (
str
): 由 LLM 生成的输入文本。 -
label (
str
): 由 LLM 生成的标签。 -
misleading_label (
str
): 由 LLM 生成的误导性标签。 -
model_name (
str
): 用于生成文本分类数据的模型的名称。
示例¶
生成用于训练 embedding 模型的合成文本分类数据¶
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextClassificationData
with Pipeline("my-pipeline") as pipeline:
task = EmbeddingTaskGenerator(
category="text-classification",
flatten_tasks=True,
llm=..., # LLM instance
)
generate = GenerateTextClassificationData(
language="English",
difficulty="high school",
clarity="clear",
llm=..., # LLM instance
)
task >> generate