跳到内容

保存步骤生成工件

一些 Step 可能需要生成辅助工件,该工件不是计算结果,但计算需要它。例如,FaissNearestNeighbour 需要创建一个 Faiss 索引来计算步骤的输出,即每个输入的 top k 个最近邻居。生成 Faiss 索引需要时间,并且它有可能在 distilabel pipeline 之外重用,因此不保存它将是一种遗憾。

因此,Step 有一个名为 save_artifact 的方法,允许保存工件,这些工件将与 pipeline 的输出一起包含在生成的 Distiset 中。当分别使用 Distiset.push_to_hubDistiset.save_to_disk 时,生成的工件将被上传和保存。让我们看一个简单的示例,了解如何使用它。

from typing import List, TYPE_CHECKING
from distilabel.steps import GlobalStep, StepInput, StepOutput
import matplotlib.pyplot as plt

if TYPE_CHECKING:
    from distilabel.steps import StepOutput


class CountTextCharacters(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        return ["text_character_count"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        character_counts = []

        for input in inputs:
            text_character_count = len(input["text"])
            input["text_character_count"] = text_character_count
            character_counts.append(text_character_count)

        # Generate plot with the distribution of text character counts
        plt.figure(figsize=(10, 6))
        plt.hist(character_counts, bins=30, edgecolor="black")
        plt.title("Distribution of Text Character Counts")
        plt.xlabel("Character Count")
        plt.ylabel("Frequency")

        # Save the plot as an artifact of the step
        self.save_artifact(
            name="text_character_count_distribution",
            write_function=lambda path: plt.savefig(path / "figure.png"),
            metadata={"type": "image", "library": "matplotlib"},
        )

        plt.close()

        yield inputs

正如上面的示例所示,我们创建了一个简单的 step,用于计算每个输入文本中的字符数,并生成一个包含字符数分布的直方图。我们使用 save_artifact 方法将直方图保存为 step 的工件。该方法接受三个参数

  • name:我们想要赋予工件的名称。
  • write_function:一个将工件写入所需路径的函数。该函数将接收一个 path 参数,该参数是一个 pathlib.Path 对象,指向应保存工件的目录。
  • metadata:包含有关工件元数据的字典。此元数据将与工件一起保存。

让我们使用一个简单的 pipeline 执行 step,并将生成的 Distiset 推送到 Hugging Face Hub

示例完整代码
from typing import TYPE_CHECKING, List

import matplotlib.pyplot as plt
from datasets import load_dataset
from distilabel.pipeline import Pipeline
from distilabel.steps import GlobalStep, StepInput, StepOutput

if TYPE_CHECKING:
    from distilabel.steps import StepOutput


class CountTextCharacters(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        return ["text_character_count"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        character_counts = []

        for input in inputs:
            text_character_count = len(input["text"])
            input["text_character_count"] = text_character_count
            character_counts.append(text_character_count)

        # Generate plot with the distribution of text character counts
        plt.figure(figsize=(10, 6))
        plt.hist(character_counts, bins=30, edgecolor="black")
        plt.title("Distribution of Text Character Counts")
        plt.xlabel("Character Count")
        plt.ylabel("Frequency")

        # Save the plot as an artifact of the step
        self.save_artifact(
            name="text_character_count_distribution",
            write_function=lambda path: plt.savefig(path / "figure.png"),
            metadata={"type": "image", "library": "matplotlib"},
        )

        plt.close()

        yield inputs


with Pipeline() as pipeline:
    count_text_characters = CountTextCharacters()

if __name__ == "__main__":
    distiset = pipeline.run(
        dataset=load_dataset(
            "HuggingFaceH4/instruction-dataset", split="test"
        ).rename_column("prompt", "text"),
    )

    distiset.push_to_hub("distilabel-internal-testing/distilabel-artifacts-example")

生成的 distilabel-internal-testing/distilabel-artifacts-example 数据集存储库在其卡片中有一个部分 描述了 pipeline 生成的工件,并且生成的图表可以在 这里 看到。