Argilla

This section contains the existing steps integrated with Argilla, which make it easy to push the generated datasets to Argilla.

base

ArgillaBase

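Bases: Step, ABC

Abstract step that provides a class to subclass from. It contains the boilerplate code required to interact with Argilla, plus some extra validations on top of it. It also defines the abstract methods that must be implemented in order to add a new dataset type as a step.

Note

This class is not intended to be instantiated directly, but via a subclass.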

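Attributes

Name Type Description
dataset_name RuntimeParameter[str]

The name of the dataset in Argilla where the records will be added.

dataset_workspace Optional[RuntimeParameter[str]]

The workspace where the dataset will be created in Argilla. Defaults to None, which means it will be created in the default workspace.

api_url Optional[RuntimeParameter[str]]

The URL of the Argilla API. Defaults to None, which means it will be read from the ARGILLA_API_URL environment variable.

api_key Optional[RuntimeParameter[SecretStr]]

The API key to authenticate with Argilla. Defaults to None, which means it will be read from the ARGILLA_API_KEY environment variable.

Runtime parameters
  • dataset_name: The name of the dataset in Argilla where the records will be added.
  • dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to None, which means it will be created in the default workspace.
  • api_url: The base URL to use for the Argilla API requests.
  • api_key: The API key to authenticate the requests to the Argilla API.
Input columns
  • dynamic, based on the inputs value provided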
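The environment-variable fallback for `api_url` and `api_key` can be sketched without any Argilla dependency. This is a minimal illustration under stated assumptions, not the library's implementation: `resolve_argilla_credentials` is a hypothetical helper, and only the `ARGILLA_API_URL` and `ARGILLA_API_KEY` variable names come from this page.

```python
import os
from typing import Mapping, Optional, Tuple


def resolve_argilla_credentials(
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    env: Optional[Mapping[str, str]] = None,
) -> Tuple[str, str]:
    """Hypothetical helper: explicit values win, otherwise fall back to the environment."""
    env = os.environ if env is None else env
    url = api_url if api_url is not None else env.get("ARGILLA_API_URL")
    key = api_key if api_key is not None else env.get("ARGILLA_API_KEY")
    if url is None or key is None:
        # Mirrors the idea that both values must be available before connecting.
        raise ValueError(
            "Both `api_url` and `api_key` are required; pass them explicitly or "
            "set ARGILLA_API_URL and ARGILLA_API_KEY."
        )
    return url, key


# A fake environment keeps the example independent of the real process environment.
fake_env = {"ARGILLA_API_URL": "http://localhost:6900", "ARGILLA_API_KEY": "from-env"}
print(resolve_argilla_credentials(env=fake_env))
print(resolve_argilla_credentials(api_key="explicit", env=fake_env))
```

Passing the environment as a mapping is a choice made for this sketch so that the behaviour is easy to exercise in isolation.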
Source code in src/distilabel/steps/argilla/base.py
class ArgillaBase(Step, ABC):
    """Abstract step that provides a class to subclass from, that contains the boilerplate code
    required to interact with Argilla, as well as some extra validations on top of it. It also defines
    the abstract methods that need to be implemented in order to add a new dataset type as a step.

    Note:
        This class is not intended to be instanced directly, but via subclass.

    Attributes:
        dataset_name: The name of the dataset in Argilla where the records will be added.
        dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
            `None`, which means it will be created in the default workspace.
        api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
            the `ARGILLA_API_URL` environment variable.
        api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
            be read from the `ARGILLA_API_KEY` environment variable.

    Runtime parameters:
        - `dataset_name`: The name of the dataset in Argilla where the records will be
            added.
        - `dataset_workspace`: The workspace where the dataset will be created in Argilla.
            Defaults to `None`, which means it will be created in the default workspace.
        - `api_url`: The base URL to use for the Argilla API requests.
        - `api_key`: The API key to authenticate the requests to the Argilla API.

    Input columns:
        - dynamic, based on the `inputs` value provided
    """

    dataset_name: RuntimeParameter[str] = Field(
        default=None, description="The name of the dataset in Argilla."
    )
    dataset_workspace: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The workspace where the dataset will be created in Argilla. Defaults "
        "to `None` which means it will be created in the default workspace.",
    )

    api_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
        description="The base URL to use for the Argilla API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Argilla API.",
    )

    _client: Optional["Argilla"] = PrivateAttr(...)
    _dataset: Optional["Dataset"] = PrivateAttr(...)

    def model_post_init(self, __context: Any) -> None:
        """Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
        super().model_post_init(__context)

        if importlib.util.find_spec("argilla") is None:
            raise ImportError(
                "Argilla is not installed. Please install it using `pip install 'distilabel[argilla]'`."
            )

    def _client_init(self) -> None:
        """Initializes the Argilla API client with the provided `api_url` and `api_key`."""
        try:
            self._client = rg.Argilla(  # type: ignore
                api_url=self.api_url,
                api_key=self.api_key.get_secret_value(),  # type: ignore
                headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
                if isinstance(self.api_url, str)
                and "hf.space" in self.api_url
                and "HF_TOKEN" in os.environ
                else {},
            )
        except Exception as e:
            raise DistilabelUserError(
                f"Failed to initialize the Argilla API: {e}",
                page="sections/how_to_guides/advanced/argilla/",
            ) from e

    @property
    def _dataset_exists_in_workspace(self) -> bool:
        """Checks if the dataset already exists in Argilla in the provided workspace if any.

        Returns:
            `True` if the dataset exists, `False` otherwise.
        """
        return (
            self._client.datasets(  # type: ignore
                name=self.dataset_name,  # type: ignore
                workspace=self.dataset_workspace,
            )
            is not None
        )

    @property
    def outputs(self) -> "StepColumns":
        """The outputs of the step is an empty list, since the steps subclassing from this one, will
        always be leaf nodes and won't propagate the inputs neither generate any outputs.
        """
        return []

    def load(self) -> None:
        """Method to perform any initialization logic before the `process` method is
        called. For example, to load an LLM, establish a connection to a database, etc.
        """
        super().load()

        if self.api_url is None or self.api_key is None:
            raise DistilabelUserError(
                "`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
                " provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
                " and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`.",
                page="sections/how_to_guides/advanced/argilla/",
            )

        self._client_init()

    @property
    @abstractmethod
    def inputs(self) -> "StepColumns": ...

    @abstractmethod
    def process(self, *inputs: StepInput) -> "StepOutput": ...
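outputs property

The outputs of the step is an empty list, since the steps subclassing from this one will always be leaf nodes: they neither propagate the inputs nor generate any outputs.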

model_post_init(__context)

Checks that the Argilla Python SDK is installed, and raises an ImportError with installation instructions if it is not.

Source code in src/distilabel/steps/argilla/base.py
def model_post_init(self, __context: Any) -> None:
    """Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
    super().model_post_init(__context)

    if importlib.util.find_spec("argilla") is None:
        raise ImportError(
            "Argilla is not installed. Please install it using `pip install 'distilabel[argilla]'`."
        )
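load()

Method to perform any initialization logic before the process method is called. For example, to load an LLM, establish a connection to a database, etc. In this class, it checks that both api_url and api_key are available and then initializes the Argilla client.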
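The `_client_init` source listed earlier on this page only adds an `Authorization` header when the URL points at a Hugging Face Space and `HF_TOKEN` is set. The condition can be isolated as follows; `build_headers` is a hypothetical helper written for illustration and is not part of distilabel.

```python
from typing import Dict, Mapping, Optional


def build_headers(api_url: Optional[str], env: Mapping[str, str]) -> Dict[str, str]:
    """Hypothetical helper mirroring the header condition used in `_client_init`."""
    if isinstance(api_url, str) and "hf.space" in api_url and "HF_TOKEN" in env:
        # Spaces may need a Hugging Face token in addition to the Argilla API key.
        return {"Authorization": f"Bearer {env['HF_TOKEN']}"}
    return {}


print(build_headers("https://my-space.hf.space", {"HF_TOKEN": "hf_xxx"}))
print(build_headers("http://localhost:6900", {"HF_TOKEN": "hf_xxx"}))
```

The second call returns an empty dictionary because the URL does not contain `hf.space`.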

Source code in src/distilabel/steps/argilla/base.py
def load(self) -> None:
    """Method to perform any initialization logic before the `process` method is
    called. For example, to load an LLM, establish a connection to a database, etc.
    """
    super().load()

    if self.api_url is None or self.api_key is None:
        raise DistilabelUserError(
            "`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
            " provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
            " and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`.",
            page="sections/how_to_guides/advanced/argilla/",
        )

    self._client_init()

preference

PreferenceToArgilla

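Bases: ArgillaBase

Creates a preference dataset in Argilla.

Step that creates a dataset in Argilla during the load phase, and then pushes the input batches into it as records. This dataset is a preference dataset: each record has one field for the instruction and one extra field per generation, followed by a rating question for each generation field. The rating question asks the annotator to set a rating from 1 to 5 for each of the provided generations.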
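To make the dataset layout concrete, the sketch below lists the field and question names that follow the naming pattern used in the source on this page (`generations-0`, `generations-0-rating`, and so on). `preference_layout` is a hypothetical helper that only builds names; it does not call Argilla.

```python
from typing import Dict, List


def preference_layout(
    num_generations: int,
    instruction: str = "instruction",
    generations: str = "generations",
) -> Dict[str, List[str]]:
    """Hypothetical helper: names of the fields and questions for a preference dataset."""
    # One `id` field, one instruction field, and one field per generation.
    fields = ["id", instruction] + [
        f"{generations}-{idx}" for idx in range(num_generations)
    ]
    # One rating question and one rationale question per generation.
    questions: List[str] = []
    for idx in range(num_generations):
        questions.append(f"{generations}-{idx}-rating")
        questions.append(f"{generations}-{idx}-rationale")
    return {"fields": fields, "questions": questions}


layout = preference_layout(num_generations=2)
print(layout["fields"])
print(layout["questions"])
```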

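Note

This step is meant to be used in conjunction with the UltraFeedback step, or with any other step that generates both ratings and rationales for a given instruction and its generations. Alternatively, it can also be used with any other task or step that generates only the instruction and generations, as the ratings and rationales are optional.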

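Attributes

Name Type Description
num_generations int

The number of generations to include in the dataset.

dataset_name RuntimeParameter[str]

The name of the dataset in Argilla.

dataset_workspace Optional[RuntimeParameter[str]]

The workspace where the dataset will be created in Argilla. Defaults to None, which means it will be created in the default workspace.

api_url Optional[RuntimeParameter[str]]

The URL of the Argilla API. Defaults to None, which means it will be read from the ARGILLA_API_URL environment variable.

api_key Optional[RuntimeParameter[SecretStr]]

The API key to authenticate with Argilla. Defaults to None, which means it will be read from the ARGILLA_API_KEY environment variable.

Runtime parameters
  • api_url: The base URL to use for the Argilla API requests.
  • api_key: The API key to authenticate the requests to the Argilla API.
Input columns
  • instruction (str): The instruction that was used to generate the completion.
  • generations (List[str]): The completions that were generated based on the input instruction.
  • ratings (List[str], optional): The ratings for the generations. If not provided, the generated ratings won't be pushed to Argilla.
  • rationales (List[str], optional): The rationales for the ratings. If not provided, the generated rationales won't be pushed to Argilla.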
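Judging by the `_add_suggestions_if_any` source listed further down this page, a rating only becomes a suggestion when it is an integer from 1 to 5, and a rationale only when it is a string. A string rating such as "4" would therefore be skipped. The sketch below reproduces that filter with plain tuples; `select_suggestions` is a hypothetical helper, not a distilabel function.

```python
from typing import Any, Dict, List, Tuple

VALID_RATINGS = {1, 2, 3, 4, 5}


def select_suggestions(
    row: Dict[str, Any], generations: str = "generations"
) -> List[Tuple[str, Any]]:
    """Hypothetical helper: return (question_name, value) pairs that pass the filter."""
    suggestions: List[Tuple[str, Any]] = []
    # Ratings are kept only if they are integers within the 1-5 scale.
    for idx, rating in enumerate(row.get("ratings", [])):
        if isinstance(rating, int) and rating in VALID_RATINGS:
            suggestions.append((f"{generations}-{idx}-rating", rating))
    # Rationales are kept only if they are strings.
    for idx, rationale in enumerate(row.get("rationales", [])):
        if isinstance(rationale, str):
            suggestions.append((f"{generations}-{idx}-rationale", rationale))
    return suggestions


row = {
    "instruction": "instruction",
    "generations": ["a", "b", "c"],
    "ratings": [4, "5", 7],
    "rationales": ["clear answer", None, "off topic"],
}
print(select_suggestions(row))
```

In this example only the first rating survives: "5" is a string and 7 is outside the scale.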

Examples

Push a preference dataset to an Argilla instance:

from distilabel.steps import PreferenceToArgilla

to_argilla = PreferenceToArgilla(
    num_generations=2,
    api_url="https://dibt-demo-argilla-space.hf.space/",
    api_key="api.key",
    dataset_name="argilla_dataset",
    dataset_workspace="my_workspace",
)
to_argilla.load()

result = next(
    to_argilla.process(
        [
            {
                "instruction": "instruction",
                "generations": ["first_generation", "second_generation"],
            }
        ],
    )
)
# >>> result
# [{'instruction': 'instruction', 'generations': ['first_generation', 'second_generation']}]

It can also include ratings and rationales:

result = next(
    to_argilla.process(
        [
            {
                "instruction": "instruction",
                "generations": ["first_generation", "second_generation"],
                "ratings": ["4", "5"],
                "rationales": ["rationale for 4", "rationale for 5"],
            }
        ],
    )
)
# >>> result
# [
#     {
#         'instruction': 'instruction',
#         'generations': ['first_generation', 'second_generation'],
#         'ratings': ['4', '5'],
#         'rationales': ['rationale for 4', 'rationale for 5']
#     }
# ]
Source code in src/distilabel/steps/argilla/preference.py
class PreferenceToArgilla(ArgillaBase):
    """Creates a preference dataset in Argilla.

    Step that creates a dataset in Argilla during the load phase, and then pushes the input
    batches into it as records. This dataset is a preference dataset, where there's one field
    for the instruction and one extra field per each generation within the same record, and then
    a rating question per each of the generation fields. The rating question asks the annotator to
    set a rating from 1 to 5 for each of the provided generations.

    Note:
        This step is meant to be used in conjunction with the `UltraFeedback` step, or any other step
        generating both ratings and responses for a given set of instruction and generations for the
        given instruction. But alternatively, it can also be used with any other task or step generating
        only the `instruction` and `generations`, as the `ratings` and `rationales` are optional.

    Attributes:
        num_generations: The number of generations to include in the dataset.
        dataset_name: The name of the dataset in Argilla.
        dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
            `None`, which means it will be created in the default workspace.
        api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
            the `ARGILLA_API_URL` environment variable.
        api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
            be read from the `ARGILLA_API_KEY` environment variable.

    Runtime parameters:
        - `api_url`: The base URL to use for the Argilla API requests.
        - `api_key`: The API key to authenticate the requests to the Argilla API.

    Input columns:
        - instruction (`str`): The instruction that was used to generate the completion.
        - generations (`List[str]`): The completion that was generated based on the input instruction.
        - ratings (`List[str]`, optional): The ratings for the generations. If not provided, the
            generated ratings won't be pushed to Argilla.
        - rationales (`List[str]`, optional): The rationales for the ratings. If not provided, the
            generated rationales won't be pushed to Argilla.

    Examples:
        Push a preference dataset to an Argilla instance:

        ```python
        from distilabel.steps import PreferenceToArgilla

        to_argilla = PreferenceToArgilla(
            num_generations=2,
            api_url="https://dibt-demo-argilla-space.hf.space/",
            api_key="api.key",
            dataset_name="argilla_dataset",
            dataset_workspace="my_workspace",
        )
        to_argilla.load()

        result = next(
            to_argilla.process(
                [
                    {
                        "instruction": "instruction",
                        "generations": ["first_generation", "second_generation"],
                    }
                ],
            )
        )
        # >>> result
        # [{'instruction': 'instruction', 'generations': ['first_generation', 'second_generation']}]
        ```

        It can also include ratings and rationales:

        ```python
        result = next(
            to_argilla.process(
                [
                    {
                        "instruction": "instruction",
                        "generations": ["first_generation", "second_generation"],
                        "ratings": ["4", "5"],
                        "rationales": ["rationale for 4", "rationale for 5"],
                    }
                ],
            )
        )
        # >>> result
        # [
        #     {
        #         'instruction': 'instruction',
        #         'generations': ['first_generation', 'second_generation'],
        #         'ratings': ['4', '5'],
        #         'rationales': ['rationale for 4', 'rationale for 5']
        #     }
        # ]
        ```
    """

    num_generations: int

    _id: str = PrivateAttr(default="id")
    _instruction: str = PrivateAttr(...)
    _generations: str = PrivateAttr(...)
    _ratings: str = PrivateAttr(...)
    _rationales: str = PrivateAttr(...)

    def load(self) -> None:
        """Sets the `_instruction` and `_generations` attributes based on the `input_mappings`, otherwise
        uses the default values; and then uses those values to create a `FeedbackDataset` suited for
        the text-generation scenario. And then it pushes it to Argilla.
        """
        super().load()

        # Both `instruction` and `generations` will be used as the fields of the dataset
        self._instruction = self.input_mappings.get("instruction", "instruction")
        self._generations = self.input_mappings.get("generations", "generations")
        # Both `ratings` and `rationales` will be used as suggestions to the default questions of the dataset
        self._ratings = self.input_mappings.get("ratings", "ratings")
        self._rationales = self.input_mappings.get("rationales", "rationales")

        if self._dataset_exists_in_workspace:
            _dataset = self._client.datasets(  # type: ignore
                name=self.dataset_name,  # type: ignore
                workspace=self.dataset_workspace,  # type: ignore
            )

            for field in _dataset.fields:
                if not isinstance(field, rg.TextField):
                    continue
                if (
                    field.name
                    not in [self._id, self._instruction]  # type: ignore
                    + [
                        f"{self._generations}-{idx}"
                        for idx in range(self.num_generations)
                    ]
                    and field.required
                ):
                    raise DistilabelUserError(
                        f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
                        f" already exists, but contains at least a required field that is"
                        f" neither `{self._id}`, `{self._instruction}`, nor `{self._generations}`"
                        f" (one per generation starting from 0 up to {self.num_generations - 1}).",
                        page="components-gallery/steps/preferencetoargilla/",
                    )

            self._dataset = _dataset
        else:
            _settings = rg.Settings(  # type: ignore
                fields=[
                    rg.TextField(name=self._id, title=self._id),  # type: ignore
                    rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                    *self._generation_fields(),  # type: ignore
                ],
                questions=self._rating_rationale_pairs(),  # type: ignore
            )
            _dataset = rg.Dataset(  # type: ignore
                name=self.dataset_name,
                workspace=self.dataset_workspace,
                settings=_settings,
                client=self._client,
            )
            self._dataset = _dataset.create()

    def _generation_fields(self) -> List["TextField"]:
        """Method to generate the fields for each of the generations.

        Returns:
            A list containing `TextField`s for each text generation.
        """
        return [
            rg.TextField(  # type: ignore
                name=f"{self._generations}-{idx}",
                title=f"{self._generations}-{idx}",
                required=True if idx == 0 else False,
            )
            for idx in range(self.num_generations)
        ]

    def _rating_rationale_pairs(
        self,
    ) -> List[Union["RatingQuestion", "TextQuestion"]]:
        """Method to generate the rating and rationale questions for each of the generations.

        Returns:
            A list of questions containing a `RatingQuestion` and `TextQuestion` pair for
            each text generation.
        """
        questions = []
        for idx in range(self.num_generations):
            questions.extend(
                [
                    rg.RatingQuestion(  # type: ignore
                        name=f"{self._generations}-{idx}-rating",
                        title=f"Rate {self._generations}-{idx} given {self._instruction}.",
                        description=f"Ignore this question if the corresponding `{self._generations}-{idx}` field is not available."
                        if idx != 0
                        else None,
                        values=[1, 2, 3, 4, 5],
                        required=True if idx == 0 else False,
                    ),
                    rg.TextQuestion(  # type: ignore
                        name=f"{self._generations}-{idx}-rationale",
                        title=f"Specify the rationale for {self._generations}-{idx}'s rating.",
                        description=f"Ignore this question if the corresponding `{self._generations}-{idx}` field is not available."
                        if idx != 0
                        else None,
                        required=False,
                    ),
                ]
            )
        return questions

    @property
    def inputs(self) -> List[str]:
        """The inputs for the step are the `instruction` and the `generations`. Optionally, one could also
        provide the `ratings` and the `rationales` for the generations."""
        return ["instruction", "generations"]

    @property
    def optional_inputs(self) -> List[str]:
        """The optional inputs for the step are the `ratings` and the `rationales` for the generations."""
        return ["ratings", "rationales"]

    def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]:
        """Method to generate the suggestions for the `rg.Record` based on the input.

        Returns:
            A list of `Suggestion`s for the rating and rationales questions.
        """
        # Since the `suggestions` i.e. answers to the `questions` are optional, will default to an empty list
        suggestions = []
        # If `ratings` is in `input`, then add those as suggestions
        if self._ratings in input:
            suggestions.extend(
                [
                    rg.Suggestion(  # type: ignore
                        value=rating,
                        question_name=f"{self._generations}-{idx}-rating",
                    )
                    for idx, rating in enumerate(input[self._ratings])
                    if rating is not None
                    and isinstance(rating, int)
                    and rating in [1, 2, 3, 4, 5]
                ],
            )
        # If `rationales` is in `input`, then add those as suggestions
        if self._rationales in input:
            suggestions.extend(
                [
                    rg.Suggestion(  # type: ignore
                        value=rationale,
                        question_name=f"{self._generations}-{idx}-rationale",
                    )
                    for idx, rationale in enumerate(input[self._rationales])
                    if rationale is not None and isinstance(rationale, str)
                ],
            )
        return suggestions

    @override
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Creates and pushes the records as `rg.Record`s to the Argilla dataset.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Returns:
            A list of Python dictionaries with the outputs of the task.
        """
        records = []
        for input in inputs:
            # Generate the SHA-256 hash of the instruction to use it as the record `id` field
            instruction_id = hashlib.sha256(
                input["instruction"].encode("utf-8")  # type: ignore
            ).hexdigest()

            generations = {
                f"{self._generations}-{idx}": generation
                for idx, generation in enumerate(input["generations"])  # type: ignore
            }

            records.append(  # type: ignore
                rg.Record(  # type: ignore
                    fields={
                        "id": instruction_id,
                        "instruction": input["instruction"],  # type: ignore
                        **generations,
                    },
                    suggestions=self._add_suggestions_if_any(input),  # type: ignore
                )
            )
        self._dataset.records.log(records)  # type: ignore
        yield inputs
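inputs property

The inputs for the step are the instruction and the generations. Optionally, one could also provide the ratings and the rationales for the generations.

optional_inputs property

The optional inputs for the step are the ratings and the rationales for the generations.

load()

Sets the _instruction and _generations attributes based on the input_mappings, otherwise uses the default values. It then uses those values to create a dataset in Argilla suited for the preference scenario. If a dataset with the same name already exists in the workspace, it is reused, provided that it has no required text field other than the id, the instruction, and the per-generation fields.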
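The reuse check for an existing dataset can be sketched in isolation. The code below is an illustration under stated assumptions: `FieldSpec` is a hypothetical stand-in for an Argilla text field (only a name and a required flag), and `unexpected_required_fields` is a hypothetical helper that returns the offending field names instead of raising an error.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class FieldSpec:
    """Hypothetical stand-in for an Argilla text field."""

    name: str
    required: bool = True


def unexpected_required_fields(
    existing: Iterable[FieldSpec],
    num_generations: int,
    id_name: str = "id",
    instruction: str = "instruction",
    generations: str = "generations",
) -> List[str]:
    """Return required fields that the step would not be able to fill."""
    allowed = {id_name, instruction} | {
        f"{generations}-{idx}" for idx in range(num_generations)
    }
    return [f.name for f in existing if f.required and f.name not in allowed]


existing_fields = [
    FieldSpec("id"),
    FieldSpec("instruction"),
    FieldSpec("generations-0"),
    FieldSpec("generations-1", required=False),
    FieldSpec("context"),  # required, but unknown to the step
    FieldSpec("notes", required=False),  # optional, so it is tolerated
]
print(unexpected_required_fields(existing_fields, num_generations=2))
```

An empty result means the existing dataset can be reused; any name in the result corresponds to the situation in which the source shown on this page raises a `DistilabelUserError`.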

Source code in src/distilabel/steps/argilla/preference.py
def load(self) -> None:
    """Sets the `_instruction` and `_generations` attributes based on the `input_mappings`, otherwise
    uses the default values; and then uses those values to create a `FeedbackDataset` suited for
    the text-generation scenario. And then it pushes it to Argilla.
    """
    super().load()

    # Both `instruction` and `generations` will be used as the fields of the dataset
    self._instruction = self.input_mappings.get("instruction", "instruction")
    self._generations = self.input_mappings.get("generations", "generations")
    # Both `ratings` and `rationales` will be used as suggestions to the default questions of the dataset
    self._ratings = self.input_mappings.get("ratings", "ratings")
    self._rationales = self.input_mappings.get("rationales", "rationales")

    if self._dataset_exists_in_workspace:
        _dataset = self._client.datasets(  # type: ignore
            name=self.dataset_name,  # type: ignore
            workspace=self.dataset_workspace,  # type: ignore
        )

        for field in _dataset.fields:
            if not isinstance(field, rg.TextField):
                continue
            if (
                field.name
                not in [self._id, self._instruction]  # type: ignore
                + [
                    f"{self._generations}-{idx}"
                    for idx in range(self.num_generations)
                ]
                and field.required
            ):
                raise DistilabelUserError(
                    f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
                    f" already exists, but contains at least a required field that is"
                    f" neither `{self._id}`, `{self._instruction}`, nor `{self._generations}`"
                    f" (one per generation starting from 0 up to {self.num_generations - 1}).",
                    page="components-gallery/steps/preferencetoargilla/",
                )

        self._dataset = _dataset
    else:
        _settings = rg.Settings(  # type: ignore
            fields=[
                rg.TextField(name=self._id, title=self._id),  # type: ignore
                rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                *self._generation_fields(),  # type: ignore
            ],
            questions=self._rating_rationale_pairs(),  # type: ignore
        )
        _dataset = rg.Dataset(  # type: ignore
            name=self.dataset_name,
            workspace=self.dataset_workspace,
            settings=_settings,
            client=self._client,
        )
        self._dataset = _dataset.create()
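process(inputs)

Creates and pushes the records as rg.Record objects to the Argilla dataset.

Parameters

Name Type Description Default
inputs StepInput

A list of Python dictionaries with the inputs of the task.

required

Returns

Type Description
StepOutput

A list of Python dictionaries with the outputs of the task.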
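The fields of each record can be previewed without an Argilla connection. The sketch below follows the pattern visible in the source on this page: the `id` is the SHA-256 hex digest of the instruction, and each generation gets its own indexed field. `build_record_fields` is a hypothetical helper; in the step itself the resulting dictionary is passed to `rg.Record`.

```python
import hashlib
from typing import Any, Dict


def build_record_fields(
    row: Dict[str, Any], generations: str = "generations"
) -> Dict[str, Any]:
    """Hypothetical helper: build the `fields` dictionary for one input row."""
    # Hashing the instruction gives a stable identifier for identical instructions.
    instruction_id = hashlib.sha256(row["instruction"].encode("utf-8")).hexdigest()
    indexed = {
        f"{generations}-{idx}": generation
        for idx, generation in enumerate(row["generations"])
    }
    return {"id": instruction_id, "instruction": row["instruction"], **indexed}


fields = build_record_fields(
    {"instruction": "instruction", "generations": ["first_generation", "second_generation"]}
)
print(sorted(fields))
```

Because the identifier depends only on the instruction text, two rows with the same instruction share the same `id` value.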

Source code in src/distilabel/steps/argilla/preference.py
@override
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Creates and pushes the records as `rg.Record`s to the Argilla dataset.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Returns:
        A list of Python dictionaries with the outputs of the task.
    """
    records = []
    for input in inputs:
        # Generate the SHA-256 hash of the instruction to use it as the record `id` field
        instruction_id = hashlib.sha256(
            input["instruction"].encode("utf-8")  # type: ignore
        ).hexdigest()

        generations = {
            f"{self._generations}-{idx}": generation
            for idx, generation in enumerate(input["generations"])  # type: ignore
        }

        records.append(  # type: ignore
            rg.Record(  # type: ignore
                fields={
                    "id": instruction_id,
                    "instruction": input["instruction"],  # type: ignore
                    **generations,
                },
                suggestions=self._add_suggestions_if_any(input),  # type: ignore
            )
        )
    self._dataset.records.log(records)  # type: ignore
    yield inputs

text_generation

TextGenerationToArgilla

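Bases: ArgillaBase

Creates a text generation dataset in Argilla.

Step that creates a dataset in Argilla during the load phase, and then pushes the input batches into it as records. This dataset is a text-generation dataset, where there is one field per input, followed by a label question to rate the quality of the completion as either bad (represented with 👎) or good (represented with 👍).

Note

This step is meant to be used in conjunction with a TextGeneration step, and no column mapping is needed, as it will use the default values for the instruction and generation columns.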

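Attributes

Name Type Description
dataset_name RuntimeParameter[str]

The name of the dataset in Argilla.

dataset_workspace Optional[RuntimeParameter[str]]

The workspace where the dataset will be created in Argilla. Defaults to None, which means it will be created in the default workspace.

api_url Optional[RuntimeParameter[str]]

The URL of the Argilla API. Defaults to None, which means it will be read from the ARGILLA_API_URL environment variable.

api_key Optional[RuntimeParameter[SecretStr]]

The API key to authenticate with Argilla. Defaults to None, which means it will be read from the ARGILLA_API_KEY environment variable.

Runtime parameters
  • api_url: The base URL to use for the Argilla API requests.
  • api_key: The API key to authenticate the requests to the Argilla API.
Input columns
  • instruction (str): The instruction that was used to generate the completion.
  • generation (str or List[str]): The completions that were generated based on the input instruction.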

Examples

Push a text generation dataset to an Argilla instance:

from distilabel.steps import TextGenerationToArgilla

to_argilla = TextGenerationToArgilla(
    api_url="https://dibt-demo-argilla-space.hf.space/",
    api_key="api.key",
    dataset_name="argilla_dataset",
    dataset_workspace="my_workspace",
)
to_argilla.load()

result = next(
    to_argilla.process(
        [
            {
                "instruction": "instruction",
                "generation": "generation",
            }
        ],
    )
)
# >>> result
# [{'instruction': 'instruction', 'generation': 'generation'}]
Source code in src/distilabel/steps/argilla/text_generation.py
class TextGenerationToArgilla(ArgillaBase):
    """Creates a text generation dataset in Argilla.

    `Step` that creates a dataset in Argilla during the load phase, and then pushes the input
    batches into it as records. This dataset is a text-generation dataset, where there's one field
    per each input, and then a label question to rate the quality of the completion in either bad
    (represented with 👎) or good (represented with 👍).

    Note:
        This step is meant to be used in conjunction with a `TextGeneration` step and no column mapping
        is needed, as it will use the default values for the `instruction` and `generation` columns.

    Attributes:
        dataset_name: The name of the dataset in Argilla.
        dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
            `None`, which means it will be created in the default workspace.
        api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
            the `ARGILLA_API_URL` environment variable.
        api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
            be read from the `ARGILLA_API_KEY` environment variable.

    Runtime parameters:
        - `api_url`: The base URL to use for the Argilla API requests.
        - `api_key`: The API key to authenticate the requests to the Argilla API.

    Input columns:
        - instruction (`str`): The instruction that was used to generate the completion.
        - generation (`str` or `List[str]`): The completions that were generated based on the input instruction.

    Examples:
        Push a text generation dataset to an Argilla instance:

        ```python
        from distilabel.steps import TextGenerationToArgilla

        to_argilla = TextGenerationToArgilla(
            api_url="https://dibt-demo-argilla-space.hf.space/",
            api_key="api.key",
            dataset_name="argilla_dataset",
            dataset_workspace="my_workspace",
        )
        to_argilla.load()

        result = next(
            to_argilla.process(
                [
                    {
                        "instruction": "instruction",
                        "generation": "generation",
                    }
                ],
            )
        )
        # >>> result
        # [{'instruction': 'instruction', 'generation': 'generation'}]
        ```
    """

    _id: str = PrivateAttr(default="id")
    _instruction: str = PrivateAttr(...)
    _generation: str = PrivateAttr(...)

    def load(self) -> None:
        """Sets the `_instruction` and `_generation` attributes based on the `input_mappings`, otherwise
        uses the default values; and then uses those values to create a `FeedbackDataset` suited for
        the text-generation scenario. And then it pushes it to Argilla.
        """
        super().load()

        self._instruction = self.input_mappings.get("instruction", "instruction")
        self._generation = self.input_mappings.get("generation", "generation")

        if self._dataset_exists_in_workspace:
            _dataset = self._client.datasets(  # type: ignore
                name=self.dataset_name,  # type: ignore
                workspace=self.dataset_workspace,  # type: ignore
            )

            for field in _dataset.fields:
                if not isinstance(field, rg.TextField):  # type: ignore
                    continue
                if (
                    field.name not in [self._id, self._instruction, self._generation]
                    and field.required
                ):
                    raise DistilabelUserError(
                        f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
                        f" already exists, but contains at least a required field that is"
                        f" neither `{self._id}`, `{self._instruction}`, nor `{self._generation}`,"
                        " so it cannot be reused for this dataset.",
                        page="components-gallery/steps/textgenerationtoargilla/",
                    )

            self._dataset = _dataset
        else:
            _settings = rg.Settings(  # type: ignore
                fields=[
                    rg.TextField(name=self._id, title=self._id),  # type: ignore
                    rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                    rg.TextField(name=self._generation, title=self._generation),  # type: ignore
                ],
                questions=[
                    rg.LabelQuestion(  # type: ignore
                        name="quality",
                        title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
                        labels={"bad": "👎", "good": "👍"},  # type: ignore
                    )
                ],
            )
            _dataset = rg.Dataset(  # type: ignore
                name=self.dataset_name,
                workspace=self.dataset_workspace,
                settings=_settings,
                client=self._client,
            )
            self._dataset = _dataset.create()

    @property
    def inputs(self) -> List[str]:
        """The inputs for the step are the `instruction` and the `generation`."""
        return ["instruction", "generation"]

    @override
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Creates the records and logs them to the Argilla dataset.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Returns:
            A list of Python dictionaries with the outputs of the task.
        """
        records = []
        for input in inputs:
            # Generate the SHA-256 hash of the instruction to use it as the value of the `id` field
            instruction_id = hashlib.sha256(
                input["instruction"].encode("utf-8")
            ).hexdigest()

            generations = input["generation"]

            # If the `generation` is not a list, then convert it into a list
            if not isinstance(generations, list):
                generations = [generations]

            # Create a `generations_set` to avoid adding duplicates
            generations_set = set()

            for generation in generations:
                # If the generation is already in the set, then skip it
                if generation in generations_set:
                    continue
                # Otherwise, add it to the set
                generations_set.add(generation)

                records.append(
                    rg.Record(  # type: ignore
                        fields={
                            self._id: instruction_id,
                            self._instruction: input["instruction"],
                            self._generation: generation,
                        },
                    ),
                )
        self._dataset.records.log(records)  # type: ignore
        yield inputs
inputs property

The inputs for the step are `instruction` and `generation`.

load()

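Sets the `_instruction` and `_generation` attributes from `input_mappings`, falling back to the default names. It then either reuses the existing Argilla dataset, after checking its required fields, or creates a new one suited for the text-generation scenario.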
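
The field-name resolution described above is a dictionary lookup with a fallback. The following is a minimal sketch, assuming `input_mappings` behaves like a plain `dict`; `resolve_field_names` is a hypothetical helper used only for illustration and is not part of distilabel.

```python
from typing import Dict, Tuple


def resolve_field_names(input_mappings: Dict[str, str]) -> Tuple[str, str]:
    """Return the (instruction, generation) field names, falling back to the defaults."""
    instruction = input_mappings.get("instruction", "instruction")
    generation = input_mappings.get("generation", "generation")
    return instruction, generation


# No mapping provided: both defaults are used.
default_names = resolve_field_names({})
# Only `instruction` is remapped: `generation` keeps its default name.
custom_names = resolve_field_names({"instruction": "prompt"})

print(default_names)  # ('instruction', 'generation')
print(custom_names)   # ('prompt', 'generation')
```

The resolved names are the ones later used for the Argilla text fields, so a remapped column name also becomes the field name in the dataset.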

Source code in src/distilabel/steps/argilla/text_generation.py
def load(self) -> None:
    """Sets the `_instruction` and `_generation` attributes from `input_mappings`, falling back
    to the default names. Then it either reuses the existing Argilla dataset, after checking
    its required fields, or creates a new one suited for the text-generation scenario.
    """
    super().load()

    self._instruction = self.input_mappings.get("instruction", "instruction")
    self._generation = self.input_mappings.get("generation", "generation")

    if self._dataset_exists_in_workspace:
        _dataset = self._client.datasets(  # type: ignore
            name=self.dataset_name,  # type: ignore
            workspace=self.dataset_workspace,  # type: ignore
        )

        for field in _dataset.fields:
            if not isinstance(field, rg.TextField):  # type: ignore
                continue
            if (
                field.name not in [self._id, self._instruction, self._generation]
                and field.required
            ):
                raise DistilabelUserError(
                    f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
                    f" already exists, but contains at least a required field that is"
                    f" neither `{self._id}`, `{self._instruction}`, nor `{self._generation}`,"
                    " so it cannot be reused for this dataset.",
                    page="components-gallery/steps/textgenerationtoargilla/",
                )

        self._dataset = _dataset
    else:
        _settings = rg.Settings(  # type: ignore
            fields=[
                rg.TextField(name=self._id, title=self._id),  # type: ignore
                rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                rg.TextField(name=self._generation, title=self._generation),  # type: ignore
            ],
            questions=[
                rg.LabelQuestion(  # type: ignore
                    name="quality",
                    title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
                    labels={"bad": "👎", "good": "👍"},  # type: ignore
                )
            ],
        )
        _dataset = rg.Dataset(  # type: ignore
            name=self.dataset_name,
            workspace=self.dataset_workspace,
            settings=_settings,
            client=self._client,
        )
        self._dataset = _dataset.create()
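
The reuse check performed by `load()` on an existing dataset can be explored without an Argilla server. The sketch below is only an illustration under stated assumptions: `FakeTextField` and `unexpected_required_fields` are hypothetical names, and the stand-in models just the `name` and `required` attributes that the loop above reads from each text field.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class FakeTextField:
    """Hypothetical stand-in for a text field; only `name` and `required` are modeled."""

    name: str
    required: bool = True


def unexpected_required_fields(
    fields: Iterable[FakeTextField], expected: List[str]
) -> List[str]:
    """Return the names of required fields that the step would not populate."""
    return [f.name for f in fields if f.required and f.name not in expected]


expected = ["id", "instruction", "generation"]

# An optional extra field does not block reuse of the dataset.
compatible = [
    FakeTextField("id"),
    FakeTextField("instruction"),
    FakeTextField("generation"),
    FakeTextField("notes", required=False),
]
# A required extra field would leave every new record incomplete.
incompatible = compatible + [FakeTextField("context")]

print(unexpected_required_fields(compatible, expected))    # []
print(unexpected_required_fields(incompatible, expected))  # ['context']
```

In `load()` a non-empty result corresponds to the case where `DistilabelUserError` is raised, because the step has no value to supply for the extra required field.
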
process(inputs)

Creates the records and logs them to the Argilla dataset.

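Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | StepInput | A list of Python dictionaries with the inputs of the task. | required |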

Returns

| Type | Description |
| --- | --- |
| StepOutput | A list of Python dictionaries with the outputs of the task. |
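
Because `process` is written with `yield`, calling it returns a generator and its body only runs once that generator is consumed. The sketch below illustrates this with a hypothetical `passthrough_process` function; the `sink` list is an assumption standing in for the Argilla dataset, and nothing here calls distilabel or Argilla.

```python
from typing import Any, Dict, Iterator, List

Batch = List[Dict[str, Any]]


def passthrough_process(inputs: Batch, sink: List[Dict[str, Any]]) -> Iterator[Batch]:
    """Hypothetical stand-in for `process`: perform a side effect, then yield the batch unchanged."""
    sink.extend(inputs)  # stands in for logging the records to Argilla
    yield inputs


pushed: List[Dict[str, Any]] = []
batch = [{"instruction": "instruction", "generation": "generation"}]

gen = passthrough_process(batch, pushed)
pushed_before = list(pushed)  # still empty: the generator body has not started yet

result = next(gen)  # runs the body up to the `yield`
print(pushed_before)    # []
print(result == batch)  # True
print(len(pushed))      # 1
```

This is why the usage example in the class docstring wraps the call in `next(...)`.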

Source code in src/distilabel/steps/argilla/text_generation.py
@override
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Creates the records and logs them to the Argilla dataset.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Returns:
        A list of Python dictionaries with the outputs of the task.
    """
    records = []
    for input in inputs:
        # Generate the SHA-256 hash of the instruction to use it as the value of the `id` field
        instruction_id = hashlib.sha256(
            input["instruction"].encode("utf-8")
        ).hexdigest()

        generations = input["generation"]

        # If the `generation` is not a list, then convert it into a list
        if not isinstance(generations, list):
            generations = [generations]

        # Create a `generations_set` to avoid adding duplicates
        generations_set = set()

        for generation in generations:
            # If the generation is already in the set, then skip it
            if generation in generations_set:
                continue
            # Otherwise, add it to the set
            generations_set.add(generation)

            records.append(
                rg.Record(  # type: ignore
                    fields={
                        self._id: instruction_id,
                        self._instruction: input["instruction"],
                        self._generation: generation,
                    },
                ),
            )
    self._dataset.records.log(records)  # type: ignore
    yield inputs
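
The record-building loop in `process` can be tried in isolation. This is a minimal sketch, assuming plain dictionaries in place of `rg.Record` objects, the default field names, and no Argilla connection; `build_record_fields` is a hypothetical helper name.

```python
import hashlib
from typing import Any, Dict, List


def build_record_fields(rows: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Build one field dictionary per unique generation of each row."""
    records = []
    for row in rows:
        # The SHA-256 hex digest of the instruction is shared by all of its generations.
        instruction_id = hashlib.sha256(row["instruction"].encode("utf-8")).hexdigest()

        generations = row["generation"]
        # A single string is treated as a one-element list.
        if not isinstance(generations, list):
            generations = [generations]

        seen = set()
        for generation in generations:
            # Skip generations already emitted for this row.
            if generation in seen:
                continue
            seen.add(generation)
            records.append(
                {
                    "id": instruction_id,
                    "instruction": row["instruction"],
                    "generation": generation,
                }
            )
    return records


rows = [{"instruction": "Say hi", "generation": ["Hi!", "Hi!", "Hello!"]}]
records = build_record_fields(rows)

print(len(records))                          # 2
print([r["generation"] for r in records])    # ['Hi!', 'Hello!']
print(records[0]["id"] == records[1]["id"])  # True
```

Deduplication is scoped to a single input row: the same generation text appearing under two different rows still produces two records.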