
Create a social network with FinePersonas

In this example, we'll explore how to use Hugging Face's FinePersonas-v0.1 dataset to create specialized user personas for social network interactions. The final dataset will be ready for fine-tuning a chat model with specific traits and characteristics.

Introduction

We will delve into the process of fine-tuning different LoRA (Low-Rank Adaptation) models to imbue these personas with specific traits and characteristics.

This approach draws inspiration from Michael Sayman's work on SocialAI (visit the profile to see some examples), leveraging FinePersonas-v0.1 to build models that can simulate bots with specific behaviors.

By fine-tuning these adapters, we can potentially create AI personas with distinct characteristics, communication styles, and areas of expertise. The result? AI interactions that feel more natural and tailored to specific contexts or user needs. For those interested in the technical aspects of this approach, we recommend the insightful blog post on Multi-LoRA serving. It provides a clear and comprehensive explanation of the technology behind this innovative technique.

Let's jump to the demo.

Creating our SocialAI Task

Building on the new TextGeneration, creating custom tasks is easier than ever. This powerful tool opens up a world of possibilities for crafting tailored text-based content with ease and precision. We will create a SocialAI task that will be in charge of generating responses to user interactions, taking into account a given follower_type, and answering from the perspective of a given persona.

from typing import Literal

from distilabel.steps.tasks import TextGeneration

class SocialAI(TextGeneration):
    follower_type: Literal["supporter", "troll", "alarmist"] = "supporter"
    system_prompt: str = (
        "You are an AI assistant expert at simulating user interactions. "
        "You must answer as if you were a '{follower_type}', be concise answer with no more than 200 characters, nothing else."
        "Here are some traits to use for your personality:\n\n"
        "{traits}"
    )  # (1)
    template: str = "You are the folowing persona:\n\n{{ persona }}\n\nWhat would you say to the following?\n\n {{ post }}"  # (2)
    columns: str | list[str] = ["persona", "post"]  # (3)

    _follower_traits: dict[str, str] = {
        "supporter": (
            "- Encouraging and positive\n"
            "- Tends to prioritize enjoyment and relaxation\n"
            "- Focuses on the present moment and short-term pleasure\n"
            "- Often uses humor and playful language\n"
            "- Wants to help others feel good and have fun\n"
        ),
        "troll": (
            "- Provocative and confrontational\n"
            "- Enjoys stirring up controversy and conflict\n"
            "- Often uses sarcasm, irony, and mocking language\n"
            "- Tends to belittle or dismiss others' opinions and feelings\n"
            "- Seeks to get a rise out of others and create drama\n"
        ),
        "alarmist": (
            "- Anxious and warning-oriented\n"
            "- Focuses on potential risks and negative consequences\n"
            "- Often uses dramatic or sensational language\n"
            "- Tends to be serious and stern in tone\n"
            "- Seeks to alert others to potential dangers and protect them from harm (even if it's excessive or unwarranted)\n"
        ),
    }

    def load(self) -> None:
        super().load()
        self.system_prompt = self.system_prompt.format(
            follower_type=self.follower_type,
            traits=self._follower_traits[self.follower_type]
        )  # (4)
  1. We have a custom system prompt that will depend on the follower_type we decide for our model.

  2. The base template or prompt will answer the post we have from the point of view of the persona.

  3. We will need our dataset to have both persona and post columns to fill the prompt.

  4. In the load method we place the traits specific to our follower type in the system prompt (a quick check of this is sketched right after this list).
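
To get a feel for what happens at load time, the following minimal sketch (not part of the original example) prints the fully formatted system prompt for one follower type. Note that loading the task also loads the LLM, so a valid Hugging Face token is required for InferenceEndpointsLLM:

from distilabel.models import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct")

troll = SocialAI(llm=llm, follower_type="troll")
troll.load()  # fills {follower_type} and {traits} in the system prompt
print(troll.system_prompt)
# You are an AI assistant expert at simulating user interactions. ...
# - Provocative and confrontational
# - Enjoys stirring up controversy and conflict
# ...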

Data preparation

This is an example, so let's keep it short. We will use 3 posts and 3 different types of personas. While there's potential to enhance this process (perhaps by implementing random persona selection, sketched after the code below, or by leveraging semantic similarity), we will opt for a straightforward approach in this demonstration.

Our goal is to create a set of nine examples, each pairing a post with a persona. To achieve this, we will use an LLM to respond to each post from the perspective of a specific persona, effectively simulating how different characters might engage with the content.

from datasets import load_dataset

posts = [
    {
        "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?"
    },
    {
        "post": "I need to develop a training course for my company on communication skills. Need to decide how deliver it remotely."
    },
    {
        "post": "I'm always 10 minutes late to meetups but no one's complained. Could this be annoying to them?"
    },
]

personas = (
    load_dataset("argilla/FinePersonas-v0.1-clustering-100k", split="train")
    .shuffle()
    .select(range(3))
    .select_columns("persona")
    .to_list()
)

data = []
for post in posts:
    for persona in personas:
        data.append({"post": post["post"], "persona": persona["persona"]})
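
As an aside, the random persona selection mentioned above could look roughly like this minimal sketch, which pairs each post with a single randomly drawn persona instead of the full cross product (the data_random name is only for illustration; the rest of the example keeps the nine-example cross product built above):

import random

data_random = [
    {"post": post["post"], "persona": random.choice(personas)["persona"]}
    for post in posts
]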

Each row will have the following format:

import json
print(json.dumps(data[0], indent=4))
{
    "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?",
    "persona": "A high school or college environmental science teacher or an ecology student specializing in biogeography and ecosystem dynamics."
}

This will be our dataset, which we can ingest using LoadDataFromDicts:

from distilabel.steps import LoadDataFromDicts

loader = LoadDataFromDicts(data=data)

Simulating from different types of followers

With our data in hand, we're ready to explore the capabilities of our SocialAI task. For this demonstration, we will use meta-llama/Meta-Llama-3.1-70B-Instruct. While this model has recently become a popular choice, it's worth noting that experimenting with a variety of models could yield even more interesting results:

from distilabel.models import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    generation_kwargs={
        "temperature": 0.7,
        "max_new_tokens": 256,
    },
)
follower_type = "supporter"

follower = SocialAI(
    llm=llm,
    follower_type=follower_type,
    name=f"{follower_type}_user",
)

This setup streamlines the process: we only need to supply the follower type and the system takes care of the rest. We could also update this setup to default to a random follower type and simulate from a wider variety of personas.
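
Before plugging it into a pipeline, we can sanity-check the task on a single row, following the usual pattern for running a distilabel task standalone (a minimal sketch; calling load also loads the LLM, so a valid Hugging Face token is needed):

follower.load()

result = next(follower.process([data[0]]))
print(result[0]["generation"])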

Building our Pipeline

The foundation of our pipeline is now in place. At its core is a single, powerful LLM. This versatile model will be repurposed to drive three distinct SocialAI tasks, one per follower type, and each of them will be prepared for supervised fine-tuning (SFT) using FormatTextGenerationSFT:

from distilabel.pipeline import Pipeline
from distilabel.steps import FormatTextGenerationSFT

with Pipeline(name="Social AI Personas") as pipeline:
    loader = LoadDataFromDicts(data=data, batch_size=1)

    llm = InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 256,
        },
    )

    for follower_type in ["supporter", "troll", "alarmist"]:
        follower = SocialAI(
            llm=llm,
            follower_type=follower_type,
            name=f"{follower_type}_user",  # (1)
            output_mappings={
                "generation": f"interaction_{follower_type}"  # (2)
            }
        )
        format_sft = FormatTextGenerationSFT(
            name=f"format_sft_{follower_type}",
            input_mappings={
                "instruction": "post",
                "generation": f"interaction_{follower_type}"  # (3)
            },
        )
        loader >> follower >> format_sft  # (4)
  1. We update the name of the step to keep track of it in the pipeline.

  2. The generation column from each LLM is mapped so the outputs don't overwrite each other, since we are reusing the same task.

  3. As we modified the output column of SocialAI, we redirect each of the follower_type responses accordingly.

  4. Connect the loader to each follower task and its format_sft step to obtain 3 different subsets.

The outcome of this pipeline will be three specialized models, each fine-tuned to a distinct follower type crafted by the SocialAI task. These models will be trained on SFT-formatted datasets, where each post is paired with the corresponding interaction for a specific follower type. This setup enables seamless fine-tuning using your preferred framework, such as TRL, or any other training framework of your choice.
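
As a quick reference, once the pipeline is run, the resulting Distiset should contain one leaf subset per follower type, each with the SFT-formatted columns produced by FormatTextGenerationSFT. A minimal sketch of how to inspect it (the actual run is shown in the full script below):

distiset = pipeline.run(use_cache=False)

# One subset per leaf step: format_sft_supporter, format_sft_troll, format_sft_alarmist
print(distiset)
print(distiset["format_sft_troll"]["train"][0]["messages"])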

Script and final dataset

All the pieces of our script are now in place; the full pipeline can be checked here:

Run
python examples/finepersonas_social_ai.py
finepersonas_social_ai.py
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

from datasets import load_dataset

from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import FormatTextGenerationSFT, LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration


class SocialAI(TextGeneration):
    follower_type: Literal["supporter", "troll", "alarmist"] = "supporter"
    system_prompt: str = (
        "You are an AI assistant expert at simulating user interactions. "
        "You must answer as if you were a '{follower_type}', be concise answer with no more than 200 characters, nothing else."
        "Here are some traits to use for your personality:\n\n"
        "{traits}"
    )
    template: str = "You are the folowing persona:\n\n{{ persona }}\n\nWhat would you say to the following?\n\n {{ post }}"
    columns: str | list[str] = ["persona", "post"]

    _follower_traits: dict[str, str] = {
        "supporter": (
            "- Encouraging and positive\n"
            "- Tends to prioritize enjoyment and relaxation\n"
            "- Focuses on the present moment and short-term pleasure\n"
            "- Often uses humor and playful language\n"
            "- Wants to help others feel good and have fun\n"
        ),
        "troll": (
            "- Provocative and confrontational\n"
            "- Enjoys stirring up controversy and conflict\n"
            "- Often uses sarcasm, irony, and mocking language\n"
            "- Tends to belittle or dismiss others' opinions and feelings\n"
            "- Seeks to get a rise out of others and create drama\n"
        ),
        "alarmist": (
            "- Anxious and warning-oriented\n"
            "- Focuses on potential risks and negative consequences\n"
            "- Often uses dramatic or sensational language\n"
            "- Tends to be serious and stern in tone\n"
            "- Seeks to alert others to potential dangers and protect them from harm (even if it's excessive or unwarranted)\n"
        ),
    }

    def load(self) -> None:
        super().load()
        self.system_prompt = self.system_prompt.format(
            follower_type=self.follower_type,
            traits=self._follower_traits[self.follower_type],
        )


posts = [
    {
        "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?"
    },
    {
        "post": "I need to develop a training course for my company on communication skills. Need to decide how deliver it remotely."
    },
    {
        "post": "I'm always 10 minutes late to meetups but no one's complained. Could this be annoying to them?"
    },
]

personas = (
    load_dataset("argilla/FinePersonas-v0.1-clustering-100k", split="train")
    .shuffle()
    .select(range(3))
    .select_columns("persona")
    .to_list()
)

data = []
for post in posts:
    for persona in personas:
        data.append({"post": post["post"], "persona": persona["persona"]})


with Pipeline(name="Social AI Personas") as pipeline:
    loader = LoadDataFromDicts(data=data, batch_size=1)

    llm = InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 256,
        },
    )

    for follower_type in ["supporter", "troll", "alarmist"]:
        follower = SocialAI(
            llm=llm,
            follower_type=follower_type,
            name=f"{follower_type}_user",
            output_mappings={"generation": f"interaction_{follower_type}"},
        )
        format_sft = FormatTextGenerationSFT(
            name=f"format_sft_{follower_type}",
            input_mappings={
                "instruction": "post",
                "generation": f"interaction_{follower_type}",
            },
        )
        loader >> follower >> format_sft


if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
    distiset.push_to_hub("plaguss/FinePersonas-SocialAI-test", include_script=True)

This is the final toy dataset we obtain: FinePersonas-SocialAI-test

You can see an example of how to load each of the subsets to fine-tune a model:

from datasets import load_dataset

ds = load_dataset("plaguss/FinePersonas-SocialAI-test", "format_sft_troll")
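
And, as mentioned earlier, such a subset can be fed to TRL for supervised fine-tuning. A minimal sketch (not the exact setup used here; the base model and hyperparameters are placeholders, and the messages column produced by FormatTextGenerationSFT is in the conversational format TRL expects):

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

ds = load_dataset("plaguss/FinePersonas-SocialAI-test", "format_sft_troll", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder base model
    train_dataset=ds,
    args=SFTConfig(output_dir="socialai-troll-lora"),
    peft_config=LoraConfig(task_type="CAUSAL_LM"),  # train a LoRA adapter, as discussed in the introduction
)
trainer.train()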

And here is an example of a generated field, together with the corresponding post and persona:

{
    "post": "Hmm, ok now I\u0027m torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?",
    "persona": "A high school or undergraduate physics or chemistry teacher, likely with a focus on experimental instruction.",
    "interaction_troll": "\"Late night cravings? More like late night brain drain. Either way, it\u0027s just a collision of molecules in your stomach. Choose the one with more calories, at least that\u0027s some decent kinetic energy.\"",
}

There is plenty of room for improvement, but this is a very promising start.