跳到内容

添加运行管道的要求

当共享包含自定义 StepTaskPipeline 时,您可能希望添加运行它们所需的特定要求。 distilabel 将获取此需求列表,并在缺少任何需求时警告用户。

让我们通过一个示例来看看如何添加额外的要求。我们要做的第一件事是为我们的 CustomStep 添加要求。为此,我们使用 requirements 装饰器来指定步骤具有 nltk>=3.8 作为依赖项(我们可以使用 版本说明符)。此外,我们将在 Pipeline 级别指定我们需要 distilabel>=1.3.0 才能运行它。

from typing import List

from distilabel.steps import Step
from distilabel.steps.base import StepInput
from distilabel.typing import StepOutput
from distilabel.steps import LoadDataFromDicts
from distilabel.utils.requirements import requirements
from distilabel.pipeline import Pipeline


@requirements(["nltk"])
class CustomStep(Step):
    @property
    def inputs(self) -> List[str]:
        return ["instruction"]

    @property
    def outputs(self) -> List[str]:
        return ["response"]

    def process(self, inputs: StepInput) -> StepOutput:  # type: ignore
        for input in inputs:
            input["response"] = nltk.word_tokenize(input)
        yield inputs


with Pipeline(
    name="pipeline-with-requirements", requirements=["distilabel>=1.3.0"]
) as pipeline:
    loader = LoadDataFromDicts(data=[{"instruction": "sample sentence"}])
    step1 = CustomStep()
    loader >> step1

if __name__ == "__main__":
    pipeline.run()

一旦我们调用 pipeline.run(),如果 StepPipeline 级别告知的任何要求缺失,将引发 ValueError,告诉我们应该安装依赖项列表

>>> pipeline.run()
[06/27/24 11:07:33] ERROR    ['distilabel.pipeline'] Please install the following requirements to run the pipeline:                                                                                                                                     base.py:350
                             distilabel>=1.3.0
...
ValueError: Please install the following requirements to run the pipeline:
distilabel>=1.3.0