跳到内容

GenerateTextClassificationData

使用 LLM 生成文本分类数据,以便后续训练 embedding 模型。

GenerateTextClassificationData 是一个 Task,它使用 LLM 生成文本分类数据,以便后续训练 embedding 模型。该任务基于论文“Improving Text Embeddings with Large Language Models”,并且数据基于提供的属性生成,如果未提供,则随机抽样。

注意

理想情况下,此任务应与 EmbeddingTaskGeneratorflatten_tasks=True 以及 category="text-classification" 一起使用;以便 LLM 生成一个任务列表,这些任务被展平,从而使每行包含一个用于文本分类类别的任务。

属性

  • language: 要生成的数据的语言,可以是 https://aclanthology.org/2020.acl-main.747.pdf 附录 A 中 XLM-R 列表中检索到的任何语言。

  • difficulty: 要生成的查询的难度,可以是 high schoolcollegePhD。默认为 None,表示将随机抽样。

  • clarity: 要生成的查询的清晰度,可以是 clearunderstandable with some effortambiguous。默认为 None,表示将随机抽样。

  • seed: 要设置的随机种子,以防在 format_input 方法中进行任何抽样。

输入和输出列

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[task]
        end
        subgraph New columns
            OCOL0[input_text]
            OCOL1[label]
            OCOL2[misleading_label]
            OCOL3[model_name]
        end
    end

    subgraph GenerateTextClassificationData
        StepInput[Input Columns: task]
        StepOutput[Output Columns: input_text, label, misleading_label, model_name]
    end

    ICOL0 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepOutput --> OCOL3
    StepInput --> StepOutput

输入

  • task (str): 生成中使用的任务描述。

输出

  • input_text (str): 由 LLM 生成的输入文本。

  • label (str): 由 LLM 生成的标签。

  • misleading_label (str): 由 LLM 生成的误导性标签。

  • model_name (str): 用于生成文本分类数据的模型的名称。

示例

生成用于训练 embedding 模型的合成文本分类数据

from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextClassificationData

with Pipeline("my-pipeline") as pipeline:
    task = EmbeddingTaskGenerator(
        category="text-classification",
        flatten_tasks=True,
        llm=...,  # LLM instance
    )

    generate = GenerateTextClassificationData(
        language="English",
        difficulty="high school",
        clarity="clear",
        llm=...,  # LLM instance
    )

    task >> generate

参考