跳到内容

用于微调自定义检索和重排序模型的合成数据生成

GenerateSentencePair pipeline overview

注意

有关优化 RAG pipeline 中检索性能的全面概述,请查看与 ZenML 合作的指南,ZenML 是一个开源 MLOps 框架,旨在构建可移植且可用于生产环境的机器学习 pipelines。

开始使用

安装依赖项

要完成本教程,您需要通过 pip 安装 distilabel SDK 和一些第三方库。在本教程中,我们将使用免费但速率受限的 Hugging Face serverless Inference API,因此我们需要将其作为额外的 distilabel 依赖项安装。您可以通过运行以下命令来安装它们

!pip install "distilabel[hf-inference-endpoints]"
!pip install "sentence-transformers~=3.0"

让我们进行所需的导入

from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.steps import LoadDataFromHub

from sentence_transformers import SentenceTransformer, CrossEncoder
import torch

您需要一个 HF_TOKEN 才能使用 HF Inference Endpoints。登录以在本 notebook 中直接使用它。

import os
from huggingface_hub import login

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

(可选)部署 Argilla

您可以跳过此步骤或将其替换为任何其他数据评估工具,但是您的模型质量会因缺乏数据质量而受到影响,因此我们建议您查看您的数据。如果您已经部署了 Argilla,则可以跳过此步骤。否则,您可以按照本指南快速部署 Argilla。

与此同时,您将需要安装 Argilla 作为 distilabel 的额外项。

!pip install "distilabel[argilla, hf-inference-endpoints]"

让我们进行额外的所需导入

import argilla as rg

数据集

在开始任何项目之前,始终重要的是查看您的数据。我们的数据在 Hugging Face Hub 上公开可用,因此我们可以通过嵌入式 iFrame 中的数据集查看器快速浏览一下。

正如我们所见,我们的数据集包含一个名为 chunks 的列,该列是从 Argilla 文档中获得的。通常,您需要下载数据并进行分块,但本教程中我们不会介绍这一点。要阅读有关如何生成此数据集的完整说明,请参考我们如何利用 distilabel 创建 Argilla 2.0 聊天机器人

或者,我们可以使用 datasets.load_dataset 将整个数据集加载到磁盘。

合成数据生成

distilabel 中的 GenerateSentencePair 组件可用于为 embeddings 模型生成训练数据集。

这是一个预定义的 Task,给定一个 anchor 句子,为特定的 action 生成数据。支持的操作包括:"paraphrase""semantically-similar""query""answer"。在我们的例子中,chunks 列对应于 anchor。这意味着我们将使用 query 为微调检索模型生成潜在的查询,并且我们将使用 semantically-similar 生成与初始 anchor 相似的文本,以微调重排序模型。

我们将 triplet=True 以生成正面和负面示例,这应该有助于模型在微调期间更好地泛化,并且我们将设置 hard_negative=True 以生成更具挑战性的示例,这些示例更接近 anchor 和讨论的主题。

最后,我们可以使用 context 为 LLM 提供种子,以生成更相关的示例。

context = (
"""
The text is a chunk from technical Python SDK documentation of Argilla.
Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets.
Along with prose explanations, the text chunk may include code snippets and Python references.
"""
)

检索

对于检索,我们将生成与 chunks 列相似的查询。我们将使用 query 操作为微调检索模型生成潜在的查询。

generate_sentence_pair = GenerateSentencePair(
    triplet=True,  
    hard_negative=True,
    action="query",
    llm=llm,
    input_batch_size=10,
    context=context,
)

重排序

对于重排序,我们将生成与初始 anchor 相似的文本。我们将使用 semantically-similar 操作生成与初始 anchor 相似的文本,以微调重排序模型。在这种情况下,我们将 hard_negative=False 设置为生成更多样化且可能错误的示例,这些示例可以用作相似性微调的负面示例,因为 rerankers 无法使用 triplets 进行微调

generate_sentence_pair = GenerateSentencePair(
    triplet=True,
    hard_negative=False,
    action="semantically-similar",
    llm=llm,
    input_batch_size=10,
    context=context,
)

组合 pipeline

现在,我们将使用 GenerateSentencePair task 在单个 pipeline 中为检索和重排序模型生成合成数据。请注意,我们将 chunks 列映射到 anchor 参数。

llm = InferenceEndpointsLLM(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
    tokenizer_id="mistralai/Mistral-7B-Instruct-v0.2",
)

with Pipeline(name="generate") as pipeline:
    load_dataset = LoadDataFromHub(
        num_examples=15,
        output_mappings={"chunks": "anchor"},
    )
    generate_retrieval_pairs = GenerateSentencePair(
        name="generate_retrieval_pairs",
        triplet=True,
        hard_negative=True,
        action="query",
        llm=llm,
        input_batch_size=10,
        context=context,
    )
    generate_reranking_pairs = GenerateSentencePair(
        name="generate_reranking_pairs",
        triplet=True,
        hard_negative=False,  # to potentially generate non-relevant pairs
        action="semantically-similar",
        llm=llm,
        input_batch_size=10,
        context=context,
    )

    load_dataset.connect(generate_retrieval_pairs, generate_reranking_pairs)

接下来,我们可以使用 pipeline.run 执行此操作。我们将为 pipeline 中的特定组件提供一些 parameters

generation_kwargs = {
    "llm": {
        "generation_kwargs": {
            "temperature": 0.7,
            "max_new_tokens": 512,
        }
    }
}

distiset = pipeline.run(  
    parameters={
        load_dataset.name: {
            "repo_id": "plaguss/argilla_sdk_docs_raw_unstructured",
            "split": "train",
        },
        generate_retrieval_pairs.name: generation_kwargs,
        generate_reranking_pairs.name: generation_kwargs,
    },
    use_cache=False,  # False for demo
)

数据生成可能很昂贵,因此建议将数据存储在某处。目前,我们将使用我们的 push_to_hub 方法将其存储在 Hugging Face Hub 上。

distiset.push_to_hub("[your-owner-name]/example-retrieval-reranking-dataset")

我们有 2 个不同的 leaf/end 节点,因此我们有可以访问的 distil 配置,一个用于检索数据,另一个用于重排序数据。

查看这些初始示例,我们可以看到它们很好地捕捉了 chunks 列的本质,但在我们可以将其用于微调之前,我们需要进一步评估数据质量。

数据质量评估

数据永远不可能像它可能的那样干净,合成生成的数据也是如此,因此,花一些时间查看您的数据始终是好的。

特征工程

为了评估我们数据的质量,我们将使用我们打算微调的模型的特征作为数据质量的代理。然后,我们可以使用这些特征来过滤掉最佳示例。

为了选择一个好的默认模型,我们将使用 Massive Text Embedding Benchmark (MTEB) Leaderboard。我们希望优化大小和速度,因此我们将模型大小设置为 <100M,然后根据最高平均分过滤 RetrievalReranking,最终得到 Snowflake/snowflake-arctic-embed-ssentence-transformers/all-MiniLM-L12-v2

检索

对于检索,我们将计算 anchor-positivepositive-negativeanchor-negative 对的当前 embeddings 的相似度。我们假设这些相似度的重叠将导致模型难以泛化,因此我们可以使用这些特征来评估我们数据的质量。

model_id = "Snowflake/snowflake-arctic-embed-m"  # Hugging Face model ID

model_retrieval = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

接下来,我们将编码生成的文本对并计算相似度。

from sklearn.metrics.pairwise import cosine_similarity

def get_embeddings(texts):
    vectors = model_retrieval.encode(texts)
    return [vector.tolist() for vector in vectors]


def get_similarities(vector_batch_a, vector_batch_b):
    similarities = []
    for vector_a, vector_b in zip(vector_batch_a, vector_batch_b):
        similarity = cosine_similarity([vector_a], [vector_b])[0][0]
        similarities.append(similarity)
    return similarities

def format_data_retriever(batch):# -&gt; Any:
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    batch["positive-vector"] = get_embeddings(batch["positive"])
    batch["negative-vector"] = get_embeddings(batch["negative"])    
    batch["similarity-positive-negative"] = get_similarities(batch["positive-vector"], batch["negative-vector"])
    batch["similarity-anchor-positive"] = get_similarities(batch["anchor-vector"], batch["positive-vector"])
    batch["similarity-anchor-negative"] = get_similarities(batch["anchor-vector"], batch["negative-vector"])
    return batch

dataset_generate_retrieval_pairs = distiset["generate_retrieval_pairs"]["train"].map(format_data_retriever, batched=True, batch_size=250)

重排序

对于重排序,我们将计算来自现有重排序模型的 anchor-positivepositive-negativeanchor-negative 对的相关性得分,并做出与检索模型类似的假设。

model_id = "sentence-transformers/all-MiniLM-L12-v2"

model = CrossEncoder(model_id)

接下来,我们将使用重排序器计算生成的文本对的相似度。最重要的是,我们将计算一个 anchor-vector 以允许进行语义搜索。

def format_data_retriever(batch):# -&gt; Any:
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    batch["similarity-positive-negative"] = model.predict(zip(batch["positive-vector"], batch["negative-vector"]))
    batch["similarity-anchor-positive"] = model.predict(zip(batch["anchor-vector"], batch["positive-vector"]))
    batch["similarity-anchor-negative"] = model.predict(zip(batch["anchor-vector"], batch["negative-vector"]))
    return batch

dataset_generate_reranking_pairs = distiset["generate_reranking_pairs"]["train"].map(format_data_retriever, batched=True, batch_size=250)

瞧,我们有了用于质量评估的代理,我们可以使用它们来过滤掉最好和最差的示例。

(可选)Argilla

为了充分利用您的数据并实际查看我们的数据,我们将使用 Argilla。如果您不熟悉 Argilla,我们建议您查看 Argilla 快速入门文档。或者,您可以使用您的 Hugging Face 帐户登录到 Argilla 演示 Space

要开始探索数据,我们首先需要定义一个 argilla.Dataset。我们将创建一个基本数据集,其中包含一些用于 anchor 的输入 TextFields 和用于 positivenegative 对的输出 TextQuestions。此外,我们将使用 file_name 作为 MetaDataProperty。最后,我们将重新使用从上一步获得的向量以允许进行语义搜索,并且我们将添加相似度分数以进行一些基本过滤和排序。

首先,我们需要定义 Argilla 数据集的设置。我们将创建两个不同的数据集,一个用于检索数据,另一个用于重排序数据,以确保我们的标注者可以专注于手头的任务。

import argilla as rg
from argilla._exceptions import ConflictError

api_key = "ohh so secret"
api_url = "https://[your-owner-name]-[your-space-name].hf.space"

client = rg.Argilla(api_url=api_url, api_key=api_key)

settings = rg.Settings(
    fields=[
        rg.TextField("anchor")
    ],
    questions=[
        rg.TextQuestion("positive"),
        rg.TextQuestion("negative"),
        rg.LabelQuestion(
            name="is_positive_relevant",
            title="Is the positive query relevant?",
            labels=["yes", "no"],
        ),
        rg.LabelQuestion(
            name="is_negative_irrelevant",
            title="Is the negative query irrelevant?",
            labels=["yes", "no"],
        )
    ],
    metadata=[
        rg.TermsMetadataProperty("filename"),
        rg.FloatMetadataProperty("similarity-positive-negative"),
        rg.FloatMetadataProperty("similarity-anchor-positive"),
        rg.FloatMetadataProperty("similarity-anchor-negative"),
    ],
    vectors=[
        rg.VectorField("anchor-vector", dimensions=model.get_sentence_embedding_dimension())
    ]
)
rg_datasets = []
for dataset_name in ["generate_retrieval_pairs", "generate_reranking_pairs"]:
    ds = rg.Dataset(
        name=dataset_name,
        settings=settings
    )
    try:
        ds.create()
    except ConflictError:
        ds = client.datasets(dataset_name)
    rg_datasets.append(ds)

现在,我们在 Argilla 中设置了数据集定义,我们可以将我们的数据上传到 Argilla。

ds_datasets = [dataset_generate_retrieval_pairs, dataset_generate_reranking_pairs]

records = []

for rg_dataset, ds_dataset in zip(rg_datasets, ds_datasets):
    for idx, entry in enumerate(ds_dataset):
        records.append(
            rg.Record(
                id=idx,
                fields={"anchor": entry["anchor"]},
                suggestions=[
                    rg.Suggestion("positive", value=entry["positive"], agent="gpt-4o", type="model"),
                    rg.Suggestion("negative", value=entry["negative"], agent="gpt-4o", type="model"),
                ],
                metadata={
                    "filename": entry["filename"],
                    "similarity-positive-negative": entry["similarity-positive-negative"],
                    "similarity-anchor-positive": entry["similarity-anchor-positive"],
                    "similarity-anchor-negative": entry["similarity-anchor-negative"]
                },
                vectors={"anchor-vector": entry["anchor-vector"]}
            )
        )
    rg_dataset.records.log(records)

现在,我们可以探索 UI 并添加最后的人工润色,以充分利用我们的数据集。

微调

最后,我们可以微调我们的模型。我们将使用 sentence-transformers 库来微调我们的模型。

检索

对于检索,我们创建了一个脚本,用于基于生成的数据微调模型,生成的数据基于 https://github.com/argilla-io/argilla-sdk-chatbot/blob/main/train_embedding.ipynb。您也可以直接在 Google Colab 中打开它

重排序

对于重排序,sentence-transformers 提供了一个脚本,展示了如何微调 CrossEncoder 模型。目前,对于使用 triplets 微调 CrossEncoder 模型存在一些不确定性,但您仍然可以使用 positiveanchor

结论

在本教程中,我们展示了一个用于 RAG 的检索器和重排序器微调的端到端示例。这为您优化和维护您的数据和模型提供了一个良好的起点,但需要根据您的特定用例进行调整。

我们从 Argilla 文档中的一些种子数据开始,为检索和重排序模型生成合成数据,评估了数据质量,并展示了如何微调模型。我们还使用了 Argilla 来获得对数据的人工润色。