跳到内容

基于人工智能修订的对比学习 (CLAIR)

“锚定偏好优化和对比修订:解决对齐中的欠规范问题” 介绍了基于人工智能修订的对比学习 (CLAIR),这是一种数据创建方法,可以产生更具对比性的偏好对,以及锚定偏好优化 (APO),这是一种可控且更稳定的对齐目标。虽然可以在 TRL 中找到 APO,但我们在 distilabel 中为 CLAIR 实现了任务。

CLAIR 是一种创建偏好对的方法,它通过最小程度地修订一个输出以表达偏好,从而产生更精确的学习信号,而不是使用判断者选择首选响应的传统方法。

CLAIR overview

原始论文的作者分享了来自 CLAIR 和 APO 的数据集集合,其中 ContextualAI/ultrafeedback_clair_32k 对应于 CLAIR 的实现。

复现

注意

本节名为“复现”,但在本例中,我们将展示如何使用 CLAIR 任务使用 distilabel 为您的生成创建修订。

为了展示 CLAIR,我们将使用在 distilabel 中实现的 CLAIR 任务,并且我们将重用 ContextualAI ContextualAI/ultrafeedback_clair_32k 已生成数据集的一个小样本进行测试。

安装

要重现以下代码,需要按如下方式安装 distilabel

pip install "distilabel>=1.4.0"

根据您想要使用的大型语言模型提供商,要求可能会有所不同,请查看相关依赖项。在本例中,我们使用 Hugging Face 的免费推理端点,但这不适用于更大的数据集。

构建模块

在这种情况下,我们已经有了指令及其生成结果,我们只需要加载数据和相应的 CLAIR 任务来进行修订

  • CLAIR 来生成修订。

代码

让我们看看应用于 distilabelContextualAI/ultrafeedback_clair_32k 的完整 pipeline

from typing import Any, Dict

from datasets import load_dataset

from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import CLAIR
from distilabel.models import InferenceEndpointsLLM


def transform_ultrafeedback(example: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "task": example["prompt"],
        "student_solution": example["rejected"][1]["content"],
    }

dataset = (
    load_dataset("ContextualAI/ultrafeedback_clair_32k", split="train")
    .select(range(10))             # We collect just 10 examples
    .map(transform_ultrafeedback)  # Apply the transformation to get just the text
)

with Pipeline(name="CLAIR UltraFeedback sample") as pipeline:
    clair = CLAIR(  # (1)
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={
                "temperature": 0.7,
                "max_new_tokens": 4096
            }
        )
    )


if __name__ == "__main__":
    distiset = pipeline.run(dataset=dataset)  # (2)
    distiset.push_to_hub(repo_id="username/clair-test", include_script=True)  # (3)
  1. 此 Pipeline 仅使用 CLAIR,因为我们已经有了生成结果,但是可以包含第一个任务来从指令创建生成结果,然后使用 CLAIR 进行修订。

  2. 为简单起见,将数据集直接包含在 run 方法中。

  3. 使用脚本将 distiset 推送到 hub 以实现可重现性。

示例数据集可以在以下位置找到:distilabel-internal-testing/clair-test