保存步骤生成工件¶
一些 Step
可能需要生成辅助工件,该工件不是计算结果,但计算需要它。例如,FaissNearestNeighbour
需要创建一个 Faiss 索引来计算步骤的输出,即每个输入的 top k
个最近邻居。生成 Faiss 索引需要时间,并且它有可能在 distilabel
pipeline 之外重用,因此不保存它将是一种遗憾。
因此,Step
有一个名为 save_artifact
的方法,允许保存工件,这些工件将与 pipeline 的输出一起包含在生成的 Distiset
中。当分别使用 Distiset.push_to_hub
或 Distiset.save_to_disk
时,生成的工件将被上传和保存。让我们看一个简单的示例,了解如何使用它。
from typing import List, TYPE_CHECKING
from distilabel.steps import GlobalStep, StepInput, StepOutput
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from distilabel.steps import StepOutput
class CountTextCharacters(GlobalStep):
@property
def inputs(self) -> List[str]:
return ["text"]
@property
def outputs(self) -> List[str]:
return ["text_character_count"]
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
character_counts = []
for input in inputs:
text_character_count = len(input["text"])
input["text_character_count"] = text_character_count
character_counts.append(text_character_count)
# Generate plot with the distribution of text character counts
plt.figure(figsize=(10, 6))
plt.hist(character_counts, bins=30, edgecolor="black")
plt.title("Distribution of Text Character Counts")
plt.xlabel("Character Count")
plt.ylabel("Frequency")
# Save the plot as an artifact of the step
self.save_artifact(
name="text_character_count_distribution",
write_function=lambda path: plt.savefig(path / "figure.png"),
metadata={"type": "image", "library": "matplotlib"},
)
plt.close()
yield inputs
正如上面的示例所示,我们创建了一个简单的 step,用于计算每个输入文本中的字符数,并生成一个包含字符数分布的直方图。我们使用 save_artifact
方法将直方图保存为 step 的工件。该方法接受三个参数
name
:我们想要赋予工件的名称。write_function
:一个将工件写入所需路径的函数。该函数将接收一个path
参数,该参数是一个pathlib.Path
对象,指向应保存工件的目录。metadata
:包含有关工件元数据的字典。此元数据将与工件一起保存。
让我们使用一个简单的 pipeline 执行 step,并将生成的 Distiset
推送到 Hugging Face Hub
示例完整代码
from typing import TYPE_CHECKING, List
import matplotlib.pyplot as plt
from datasets import load_dataset
from distilabel.pipeline import Pipeline
from distilabel.steps import GlobalStep, StepInput, StepOutput
if TYPE_CHECKING:
from distilabel.steps import StepOutput
class CountTextCharacters(GlobalStep):
@property
def inputs(self) -> List[str]:
return ["text"]
@property
def outputs(self) -> List[str]:
return ["text_character_count"]
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
character_counts = []
for input in inputs:
text_character_count = len(input["text"])
input["text_character_count"] = text_character_count
character_counts.append(text_character_count)
# Generate plot with the distribution of text character counts
plt.figure(figsize=(10, 6))
plt.hist(character_counts, bins=30, edgecolor="black")
plt.title("Distribution of Text Character Counts")
plt.xlabel("Character Count")
plt.ylabel("Frequency")
# Save the plot as an artifact of the step
self.save_artifact(
name="text_character_count_distribution",
write_function=lambda path: plt.savefig(path / "figure.png"),
metadata={"type": "image", "library": "matplotlib"},
)
plt.close()
yield inputs
with Pipeline() as pipeline:
count_text_characters = CountTextCharacters()
if __name__ == "__main__":
distiset = pipeline.run(
dataset=load_dataset(
"HuggingFaceH4/instruction-dataset", split="test"
).rename_column("prompt", "text"),
)
distiset.push_to_hub("distilabel-internal-testing/distilabel-artifacts-example")
生成的 distilabel-internal-testing/distilabel-artifacts-example 数据集存储库在其卡片中有一个部分 描述了 pipeline 生成的工件,并且生成的图表可以在 这里 看到。