生成合成文本分类数据¶
- 目标:生成合成文本分类数据,以扩充不平衡且有限的数据集,用于训练主题分类器。此外,生成新数据以训练基于事实与基于观点的分类器,从而添加新标签。
- 库:argilla,hf-inference-endpoints,SetFit
- 组件:LoadDataFromDicts,EmbeddingTaskGenerator,GenerateTextClassificationData
开始使用¶
安装依赖项¶
要完成本教程,您需要通过 pip 安装 distilabel SDK 和一些第三方库。在本教程中,我们将使用免费但速率受限的 Hugging Face 无服务器推理 API,因此我们需要将其作为额外的 distilabel 依赖项安装。您可以通过运行以下命令来安装它们
让我们进行所需的导入
import random
from collections import Counter
from datasets import load_dataset, Dataset
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import (
GenerateTextClassificationData,
)
from setfit import SetFitModel, Trainer, sample_dataset
您需要一个 HF_TOKEN
才能使用 HF Inference Endpoints。登录以在本笔记本中直接使用它。
数据集¶
我们将使用 Hugging Face Hub 中的 fancyzhx/ag_news
数据集作为我们的原始数据源。为了模拟真实世界中数据不平衡且有限的场景,我们将仅从该数据集加载 20 个样本。
现在,我们可以检索数据集中可用的标签并检查当前的数据分布。
正如观察到的,数据集是不平衡的,大多数样本属于 World
类别,而 Sci/Tech
类别完全缺失。此外,样本数量不足以有效地训练主题分类模型。
我们还将为新的分类任务定义标签。
定义文本分类任务¶
为了生成数据,我们将使用 GenerateTextClassificationData
任务。此任务将使用分类任务作为输入,我们可以定义生成数据的语言、难度和清晰度要求。
task = GenerateTextClassificationData(
language="English",
difficulty="college",
clarity="clear",
num_generations=1,
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
generation_kwargs={"max_new_tokens": 512, "temperature": 0.4},
),
input_batch_size=5,
)
task.load()
result = next(
task.process([{"task": "Classify the news article as fact-based or opinion-based"}])
)
print(result[0]["distilabel_metadata"]["raw_input_generate_text_classification_data_0"])
对于我们的用例,我们只需要为两个任务生成数据:主题分类任务和事实与观点分类任务。因此,我们将相应地定义任务。由于我们将使用较小的模型进行生成,因此我们将为每个主题分类任务选择 2 个随机标签,并更改事实与观点分类任务的顺序,以确保生成的数据具有更高的多样性。
task_templates = [
"Determine the news article as {}",
"Classify news article as {}",
"Identify the news article as {}",
"Categorize the news article as {}",
"Label the news article using {}",
"Annotate the news article based on {}",
"Determine the theme of a news article from {}",
"Recognize the topic of the news article as {}",
]
classification_tasks = [
{"task": action.format(" or ".join(random.sample(labels_topic, 2)))}
for action in task_templates for _ in range(4)
] + [
{"task": action.format(" or ".join(random.sample(labels_fact_opinion, 2)))}
for action in task_templates
]
运行 Pipeline¶
现在,是时候定义和运行 Pipeline 了。如前所述,我们将加载编写的任务并将它们输入到 GenerateTextClassificationData
任务中。对于我们的用例,我们将通过 InferenceEndpointsLLM
使用 Meta-Llama-3.1-8B-Instruct
,具有不同的难度和清晰度。
difficulties = ["college", "high school", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]
with Pipeline("texcat-generation-pipeline") as pipeline:
tasks_generator = LoadDataFromDicts(data=classification_tasks)
generate_data = []
for difficulty in difficulties:
for clarity_level in clarity:
task = GenerateTextClassificationData(
language="English",
difficulty=difficulty,
clarity=clarity_level,
num_generations=2,
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
),
input_batch_size=5,
)
generate_data.append(task)
for task in generate_data:
tasks_generator.connect(task)
现在让我们运行 Pipeline 并生成合成数据。
您可以将数据集推送到 Hub 以与社区共享,并 嵌入它以探索数据。
通过检查 distiset 分布,我们可以确认它至少包含每个标签所需的 8 个样本,以使用 SetFit 训练我们的分类模型。
我们将使用所需的标签和数据为我们的用例创建两个数据集。
def extract_rows(distiset, labels):
return [
{
"text": entry["input_text"],
"label": entry["label"],
"id": i
}
for dataset_name in distiset
for i, entry in enumerate(distiset[dataset_name]["train"])
if entry["label"] in labels
]
data_topic = extract_rows(distiset, labels_topic)
data_fact_opinion = extract_rows(distiset, labels_fact_opinion)
(可选)使用 Argilla 评估¶
Argilla 入门
如果您不熟悉 Argilla,我们建议您查看 Argilla 快速入门文档。或者,您可以使用您的 Hugging Face 帐户登录到 Argilla 演示 Space。
为了充分利用我们的数据,我们将使用 Argilla。首先,我们需要连接到 Argilla 实例。
import argilla as rg
# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
api_url="https://[your-owner-name]-[your_space_name].hf.space",
api_key="[your-api-key]",
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
我们将为每个任务创建一个 Dataset
,其中包含用于文本分类文本的 TextField
输入和用于确保生成的标签正确的 LabelQuestion
。
def create_texcat_dataset(dataset_name, labels):
settings = rg.Settings(
fields=[rg.TextField("text")],
questions=[
rg.LabelQuestion(
name="label",
title="Classify the texts according to the following labels",
labels=labels,
),
],
)
return rg.Dataset(name=dataset_name, settings=settings).create()
rg_dataset_topic = create_texcat_dataset("topic-classification", labels_topic)
rg_dataset_fact_opinion = create_texcat_dataset(
"fact-opinion-classification", labels_fact_opinion
)
现在,我们可以将生成的数据上传到 Argilla 并进行评估。我们将使用生成的标签作为建议。
现在,我们可以开始标注过程。只需在 Argilla UI 中打开数据集并开始标注记录。如果建议是正确的,您只需单击“提交
”。否则,您可以选择正确的标签。
注意
查看此 操作指南 以了解有关在 UI 中进行标注的更多信息。
获得标注后,让我们继续从 Argilla 检索数据,并将其格式化为包含所需数据的数据集。
训练您的模型¶
在我们的例子中,我们将使用 SetFit 进行微调。但是,您可以选择最适合您需求的模型。
格式化数据¶
下一步是将数据格式化为与 SetFit 兼容。在主题分类的情况下,我们将需要将合成数据与原始数据结合起来。
如果我们现在检查数据分布,我们可以看到我们有足够的样本用于每个标签来训练我们的模型。
现在,让我们创建我们的训练和验证数据集。训练数据集将收集每个标签 8 个样本。在这种情况下,验证数据集将包含未包含在训练数据集中的剩余样本。
def sample_and_split(dataset, label_column, num_samples):
train_dataset = sample_dataset(
dataset, label_column=label_column, num_samples=num_samples
)
eval_dataset = dataset.filter(lambda x: x["id"] not in set(train_dataset["id"]))
return train_dataset, eval_dataset
dataset_topic_full = Dataset.from_list(data_topic)
dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)
train_dataset_topic, eval_dataset_topic = sample_and_split(
dataset_topic_full, "label", 8
)
train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(
dataset_fact_opinion_full, "label", 8
)
实际训练¶
让我们为每个任务训练我们的模型!我们将使用 Hugging Face Hub 中提供的 TaylorAI/bge-micro-v2。您可以查看 MTEB 排行榜,以选择最适合您用例的模型。
model_fact_opinion = train_model(
model_name="TaylorAI/bge-micro-v2",
dataset=train_dataset_fact_opinion,
eval_dataset=eval_dataset_fact_opinion,
)
model_fact_opinion.save_pretrained("fact_opinion_classification_model")
model_fact_opinion = SetFitModel.from_pretrained("fact_opinion_classification_model")
瞧!模型现在已训练完成并可以使用了。您可以开始进行预测以检查模型的性能并添加新标签。可选地,您可以继续使用 distilabel 生成更多数据,或使用 Argilla 验证预测的质量。
结论¶
在本教程中,我们展示了构建用于使用 distilabel 生成文本分类数据的 Pipeline 的详细步骤。您可以为自己的用例自定义此 Pipeline,并通过 Hugging Face Hub 与社区分享您的数据集。
我们定义了两个文本分类任务——主题分类任务和事实与观点分类任务——并使用无服务器 Hugging Face Inference API 通过各种模型生成了新数据。然后,我们使用 Argilla 管理了生成的数据。最后,我们使用 SetFit 以及原始数据和合成数据训练了模型。