跳到内容

生成合成文本分类数据

开始使用

安装依赖项

要完成本教程,您需要通过 pip 安装 distilabel SDK 和一些第三方库。在本教程中,我们将使用免费但速率受限的 Hugging Face 无服务器推理 API,因此我们需要将其作为额外的 distilabel 依赖项安装。您可以通过运行以下命令来安装它们

!pip install "distilabel[hf-inference-endpoints]"
!pip install "transformers~=4.40" "torch~=2.0" "setfit~=1.0"

让我们进行所需的导入

import random
from collections import Counter

from datasets import load_dataset, Dataset
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import (
    GenerateTextClassificationData,
)
from setfit import SetFitModel, Trainer, sample_dataset

您需要一个 HF_TOKEN 才能使用 HF Inference Endpoints。登录以在本笔记本中直接使用它。

import os
from huggingface_hub import login

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

(可选)部署 Argilla

您可以跳过此步骤或将其替换为任何其他数据评估工具,但是您的模型质量会因缺乏数据质量而受到影响,因此我们强烈建议您查看您的数据。如果您已经部署了 Argilla,则可以跳过此步骤。否则,您可以按照 本指南 快速部署 Argilla。

同时,您还需要安装 Argilla 作为 distilabel 的额外组件。

!pip install "distilabel[argilla, hf-inference-endpoints]"

数据集

我们将使用 Hugging Face Hub 中的 fancyzhx/ag_news 数据集作为我们的原始数据源。为了模拟真实世界中数据不平衡且有限的场景,我们将仅从该数据集加载 20 个样本。

hf_dataset = load_dataset("fancyzhx/ag_news", split="train[-20:]")

现在,我们可以检索数据集中可用的标签并检查当前的数据分布。

labels_topic = hf_dataset.features["label"].names
id2str = {i: labels_topic[i] for i in range(len(labels_topic))}
print(id2str)
print(Counter(hf_dataset["label"]))
{0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
Counter({0: 12, 1: 6, 2: 2})

正如观察到的,数据集是不平衡的,大多数样本属于 World 类别,而 Sci/Tech 类别完全缺失。此外,样本数量不足以有效地训练主题分类模型。

我们还将为新的分类任务定义标签。

labels_fact_opinion = ["Fact-based", "Opinion-based"]

定义文本分类任务

为了生成数据,我们将使用 GenerateTextClassificationData 任务。此任务将使用分类任务作为输入,我们可以定义生成数据的语言、难度和清晰度要求。

task = GenerateTextClassificationData(
    language="English",
    difficulty="college",
    clarity="clear",
    num_generations=1,
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        generation_kwargs={"max_new_tokens": 512, "temperature": 0.4},
    ),
    input_batch_size=5,
)
task.load()
result = next(
    task.process([{"task": "Classify the news article as fact-based or opinion-based"}])
)
print(result[0]["distilabel_metadata"]["raw_input_generate_text_classification_data_0"])
[{'role': 'user', 'content': 'You have been assigned a text classification task: Classify the news article as fact-based or opinion-based\n\nYour mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:\n - "input_text": a string, the input text specified by the classification task.\n - "label": a string, the correct label of the input text.\n - "misleading_label": a string, an incorrect label that is related to the task.\n\nPlease adhere to the following guidelines:\n - The "input_text" should be diverse in expression.\n - The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the "input_text".\n - The values for all fields should be in English.\n - Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make the task too easy.\n - The "input_text" is clear and requires college level education to comprehend.\n\nYour output must always be a JSON object only, do not explain yourself or output anything else. Be creative!'}]

对于我们的用例,我们只需要为两个任务生成数据:主题分类任务和事实与观点分类任务。因此,我们将相应地定义任务。由于我们将使用较小的模型进行生成,因此我们将为每个主题分类任务选择 2 个随机标签,并更改事实与观点分类任务的顺序,以确保生成的数据具有更高的多样性。

task_templates = [
    "Determine the news article as {}",
    "Classify news article as {}",
    "Identify the news article as {}",
    "Categorize the news article as {}",
    "Label the news article using {}",
    "Annotate the news article based on {}",
    "Determine the theme of a news article from {}",
    "Recognize the topic of the news article as {}",
]

classification_tasks = [
    {"task": action.format(" or ".join(random.sample(labels_topic, 2)))}
    for action in task_templates for _ in range(4)
] + [
    {"task": action.format(" or ".join(random.sample(labels_fact_opinion, 2)))}
    for action in task_templates
]

运行 Pipeline

现在,是时候定义和运行 Pipeline 了。如前所述,我们将加载编写的任务并将它们输入到 GenerateTextClassificationData 任务中。对于我们的用例,我们将通过 InferenceEndpointsLLM 使用 Meta-Llama-3.1-8B-Instruct,具有不同的难度和清晰度。

difficulties = ["college", "high school", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]

with Pipeline("texcat-generation-pipeline") as pipeline:

    tasks_generator = LoadDataFromDicts(data=classification_tasks)

    generate_data = []
    for difficulty in difficulties:
        for clarity_level in clarity:
            task = GenerateTextClassificationData(
                language="English",
                difficulty=difficulty,
                clarity=clarity_level,
                num_generations=2,
                llm=InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
                ),
                input_batch_size=5,
            )
            generate_data.append(task)

    for task in generate_data:
        tasks_generator.connect(task)

现在让我们运行 Pipeline 并生成合成数据。

distiset = pipeline.run()
distiset["generate_text_classification_data_0"]["train"][0]
{'task': 'Determine the news article as Business or World',
 'input_text': "The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone's economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.",
 'label': 'Business',
 'misleading_label': 'World',
 'distilabel_metadata': {'raw_output_generate_text_classification_data_0': '{\n  "input_text": "The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone\'s economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.",\n  "label": "Business",\n  "misleading_label": "World"\n}'},
 'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct'}

您可以将数据集推送到 Hub 以与社区共享,并 嵌入它以探索数据

distiset.push_to_hub("[your-owner-name]/example-texcat-generation-dataset")

通过检查 distiset 分布,我们可以确认它至少包含每个标签所需的 8 个样本,以使用 SetFit 训练我们的分类模型。

all_labels = [
    entry["label"]
    for dataset_name in distiset
    for entry in distiset[dataset_name]["train"]
]

Counter(all_labels)
Counter({'Sci/Tech': 275,
         'Business': 130,
         'World': 86,
         'Fact-based': 86,
         'Sports': 64,
         'Opinion-based': 54,
         None: 20,
         'Opinion Based': 1,
         'News/Opinion': 1,
         'Science': 1,
         'Environment': 1,
         'Opinion': 1})

我们将使用所需的标签和数据为我们的用例创建两个数据集。

def extract_rows(distiset, labels):
    return [
        {
            "text": entry["input_text"],
            "label": entry["label"],
            "id": i
        }
        for dataset_name in distiset
        for i, entry in enumerate(distiset[dataset_name]["train"])
        if entry["label"] in labels
    ]

data_topic = extract_rows(distiset, labels_topic)
data_fact_opinion = extract_rows(distiset, labels_fact_opinion)

(可选)使用 Argilla 评估

Argilla 入门

如果您不熟悉 Argilla,我们建议您查看 Argilla 快速入门文档。或者,您可以使用您的 Hugging Face 帐户登录到 Argilla 演示 Space

为了充分利用我们的数据,我们将使用 Argilla。首先,我们需要连接到 Argilla 实例。

import argilla as rg

# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

我们将为每个任务创建一个 Dataset,其中包含用于文本分类文本的 TextField 输入和用于确保生成的标签正确的 LabelQuestion

def create_texcat_dataset(dataset_name, labels):
    settings = rg.Settings(
        fields=[rg.TextField("text")],
        questions=[
            rg.LabelQuestion(
                name="label",
                title="Classify the texts according to the following labels",
                labels=labels,
            ),
        ],
    )
    return rg.Dataset(name=dataset_name, settings=settings).create()


rg_dataset_topic = create_texcat_dataset("topic-classification", labels_topic)
rg_dataset_fact_opinion = create_texcat_dataset(
    "fact-opinion-classification", labels_fact_opinion
)

现在,我们可以将生成的数据上传到 Argilla 并进行评估。我们将使用生成的标签作为建议。

rg_dataset_topic.records.log(data_topic)
rg_dataset_fact_opinion.records.log(data_fact_opinion)

现在,我们可以开始标注过程。只需在 Argilla UI 中打开数据集并开始标注记录。如果建议是正确的,您只需单击“提交”。否则,您可以选择正确的标签。

注意

查看此 操作指南 以了解有关在 UI 中进行标注的更多信息。

获得标注后,让我们继续从 Argilla 检索数据,并将其格式化为包含所需数据的数据集。

rg_dataset_topic = client.datasets("topic-classification")
rg_dataset_fact_opinion = client.datasets("fact-opinion-classification")
status_filter = rg.Query(filter=rg.Filter(("response.status", "==", "submitted")))

submitted_topic = rg_dataset_topic.records(status_filter).to_list(flatten=True)
submitted_fact_opinion = rg_dataset_fact_opinion.records(status_filter).to_list(
    flatten=True
)
def format_submitted(submitted):
    return [
        {
            "text": r["text"],
            "label": r["label.responses"][0],
            "id": i,
        }
        for i, r in enumerate(submitted)
    ]

data_topic = format_submitted(submitted_topic)
data_fact_opinion = format_submitted(submitted_fact_opinion)

训练您的模型

在我们的例子中,我们将使用 SetFit 进行微调。但是,您可以选择最适合您需求的模型。

格式化数据

下一步是将数据格式化为与 SetFit 兼容。在主题分类的情况下,我们将需要将合成数据与原始数据结合起来。

hf_topic = hf_dataset.to_list()
num = len(data_topic)

data_topic.extend(
    [
        {
            "text": r["text"],
            "label": id2str[r["label"]],
            "id": num + i,
        }
        for i, r in enumerate(hf_topic)
    ]
)

如果我们现在检查数据分布,我们可以看到我们有足够的样本用于每个标签来训练我们的模型。

labels = [record["label"] for record in data_topic]
Counter(labels)
Counter({'Sci/Tech': 275, 'Business': 132, 'World': 98, 'Sports': 70})
labels = [record["label"] for record in data_fact_opinion]
Counter(labels)
Counter({'Fact-based': 86, 'Opinion-based': 54})

现在,让我们创建我们的训练和验证数据集。训练数据集将收集每个标签 8 个样本。在这种情况下,验证数据集将包含未包含在训练数据集中的剩余样本。

def sample_and_split(dataset, label_column, num_samples):
    train_dataset = sample_dataset(
        dataset, label_column=label_column, num_samples=num_samples
    )
    eval_dataset = dataset.filter(lambda x: x["id"] not in set(train_dataset["id"]))
    return train_dataset, eval_dataset


dataset_topic_full = Dataset.from_list(data_topic)
dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)

train_dataset_topic, eval_dataset_topic = sample_and_split(
    dataset_topic_full, "label", 8
)
train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(
    dataset_fact_opinion_full, "label", 8
)

实际训练

让我们为每个任务训练我们的模型!我们将使用 Hugging Face Hub 中提供的 TaylorAI/bge-micro-v2。您可以查看 MTEB 排行榜,以选择最适合您用例的模型。

def train_model(model_name, dataset, eval_dataset):
    model = SetFitModel.from_pretrained(model_name)

    trainer = Trainer(
        model=model,
        train_dataset=dataset,
    )
    trainer.train()
    metrics = trainer.evaluate(eval_dataset)
    print(metrics)

    return model
model_topic = train_model(
    model_name="TaylorAI/bge-micro-v2",
    dataset=train_dataset_topic,
    eval_dataset=eval_dataset_topic,
)
model_topic.save_pretrained("topic_classification_model")
model_topic = SetFitModel.from_pretrained("topic_classification_model")
***** Running training *****
  Num unique pairs = 768
  Batch size = 16
  Num epochs = 1
  Total optimization steps = 48

{'embedding_loss': 0.1873, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.02}

***** Running evaluation *****

{'train_runtime': 4.9767, 'train_samples_per_second': 154.318, 'train_steps_per_second': 9.645, 'epoch': 1.0}
{'accuracy': 0.8333333333333334}

model_fact_opinion = train_model(
    model_name="TaylorAI/bge-micro-v2",
    dataset=train_dataset_fact_opinion,
    eval_dataset=eval_dataset_fact_opinion,
)
model_fact_opinion.save_pretrained("fact_opinion_classification_model")
model_fact_opinion = SetFitModel.from_pretrained("fact_opinion_classification_model")
***** Running training *****
  Num unique pairs = 144
  Batch size = 16
  Num epochs = 1
  Total optimization steps = 9

{'embedding_loss': 0.2985, 'learning_rate': 2e-05, 'epoch': 0.11}

***** Running evaluation *****

{'train_runtime': 0.8327, 'train_samples_per_second': 172.931, 'train_steps_per_second': 10.808, 'epoch': 1.0}
{'accuracy': 0.9090909090909091}

瞧!模型现在已训练完成并可以使用了。您可以开始进行预测以检查模型的性能并添加新标签。可选地,您可以继续使用 distilabel 生成更多数据,或使用 Argilla 验证预测的质量。

def predict(model, input, labels):
    model.labels = labels
    prediction = model.predict([input])
    return prediction[0]
predict(
    model_topic, "The new iPhone is expected to be released next month.", labels_topic
)
'Sci/Tech'
predict(
    model_fact_opinion,
    "The new iPhone is expected to be released next month.",
    labels_fact_opinion,
)
'Opinion-based'

结论

在本教程中,我们展示了构建用于使用 distilabel 生成文本分类数据的 Pipeline 的详细步骤。您可以为自己的用例自定义此 Pipeline,并通过 Hugging Face Hub 与社区分享您的数据集。

我们定义了两个文本分类任务——主题分类任务和事实与观点分类任务——并使用无服务器 Hugging Face Inference API 通过各种模型生成了新数据。然后,我们使用 Argilla 管理了生成的数据。最后,我们使用 SetFit 以及原始数据和合成数据训练了模型。