跳到内容

GenerateTextRetrievalData

使用 LLM 生成文本检索数据,以便稍后训练 embedding 模型。

GenerateTextRetrievalData 是一个 Task,它使用 LLM 生成文本检索数据,以便稍后训练 embedding 模型。此任务基于论文“Improving Text Embeddings with Large Language Models”,数据基于提供的属性生成,如果未提供属性,则随机采样。

注意

理想情况下,此任务应与 EmbeddingTaskGenerator 结合使用,并将 flatten_tasks=True 设置为 category="text-retrieval";以便 LLM 生成扁平化的任务列表,从而使每行包含一个用于文本检索类别的任务。

属性

  • language: 要生成的数据的语言,可以是 https://aclanthology.org/2020.acl-main.747.pdf 附录 A 中 XLM-R 列表中检索到的任何语言。

  • query_type: 要生成的查询类型,可以是 extremely long-taillong-tailcommon。默认为 None,表示将随机采样。

  • query_length: 要生成的查询长度,可以是 less than 5 words5 to 15 wordsat least 10 words。默认为 None,表示将随机采样。

  • difficulty: 要生成的查询难度,可以是 high schoolcollegePhD。默认为 None,表示将随机采样。

  • clarity: 要生成的查询清晰度,可以是 clearunderstandable with some effortambiguous。默认为 None,表示将随机采样。

  • num_words: 要生成的查询中的字数,可以是 50100200300400500。默认为 None,表示将随机采样。

  • seed: 要设置的随机种子,以防 format_input 方法中存在任何采样。

输入和输出列

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[task]
        end
        subgraph New columns
            OCOL0[user_query]
            OCOL1[positive_document]
            OCOL2[hard_negative_document]
            OCOL3[model_name]
        end
    end

    subgraph GenerateTextRetrievalData
        StepInput[Input Columns: task]
        StepOutput[Output Columns: user_query, positive_document, hard_negative_document, model_name]
    end

    ICOL0 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepOutput --> OCOL3
    StepInput --> StepOutput

输入

  • task (str): 生成中要使用的任务描述。

输出

  • user_query (str): 由 LLM 生成的用户查询。

  • positive_document (str): 由 LLM 生成的正向文档。

  • hard_negative_document (str): 由 LLM 生成的难负例文档。

  • model_name (str): 用于生成文本检索数据的模型名称。

示例

生成用于训练 embedding 模型的合成文本检索数据

from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextRetrievalData

with Pipeline("my-pipeline") as pipeline:
    task = EmbeddingTaskGenerator(
        category="text-retrieval",
        flatten_tasks=True,
        llm=...,  # LLM instance
    )

    generate = GenerateTextRetrievalData(
        language="English",
        query_type="common",
        query_length="5 to 15 words",
        difficulty="high school",
        clarity="clear",
        num_words=100,
        llm=...,  # LLM instance
    )

    task >> generate

参考