跳到内容

生成输出的 GeneratorTask

使用 GeneratorTasks

GeneratorTask 是基于 GeneratorStepTask 的自定义实现。与 Task 一样,它通常在 Pipeline 中使用,但也可以单独使用。

警告

此任务仍处于实验阶段,未来可能会发生更改。

from typing import Any, Dict, List, Union
from typing_extensions import override

from distilabel.steps.tasks.base import GeneratorTask
from distilabel.typing import ChatType, GeneratorOutput


class MyCustomTask(GeneratorTask):
    instruction: str

    @override
    def process(self, offset: int = 0) -> GeneratorStepOutput:
        output = self.llm.generate(
            inputs=[
                [
                    {"role": "user", "content": self.instruction},
                ],
            ],
        )
        output = {"model_name": self.llm.model_name}
        output.update(
            self.format_output(output=output, input=None)
        )
        yield output

    @property
    def outputs(self) -> List[str]:
        return ["output_field", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {"output_field": output}

然后我们可以按如下方式使用它

task = MyCustomTask(
    name="custom-generation",
    instruction="Tell me a joke.",
    llm=OpenAILLM(model="gpt-4"),
)
task.load()

next(task.process())
# [{'output_field": "Why did the scarecrow win an award? Because he was outstanding!", "model_name": "gpt-4"}]

注意

大多数情况下,您需要覆盖默认的 process 方法,因为它适用于标准的 Task,而不适用于 GeneratorTask。但是在 process 函数的上下文中,您可以自由地使用 llm 以任何方式生成数据。

注意

当作为独立组件使用时,始终需要执行 Step.load()。在 pipeline 中,这将在 pipeline 执行期间自动完成。

定义自定义 GeneratorTasks

我们可以通过创建 GeneratorTask 的新子类并定义以下内容来定义自定义生成器任务

  • process:是一种基于 LLM 和类实例中提供的 instruction 生成数据的方法,并返回一个字典,其中包含所需格式的输出数据,即 outputs 中列的值。请注意,inputs 参数在此函数中是不允许的,因为这是一个 GeneratorTask。签名仅期望 offset 参数,该参数用于跟踪生成器中的当前迭代。

  • outputs:是一个属性,返回一个字符串列表,其中包含输出字段的名称,此属性应始终包含 model_name 作为输出之一,因为它是从 LLM 自动注入的。

  • format_output:是一种方法,接收来自 LLM 的输出,并且可以选择接收输入数据(在某些情况下可能有助于构建输出),并返回一个字典,其中包含所需格式的输出数据,即 outputs 中列的值。请注意,无需在输出中包含 model_name

from typing import Any, Dict, List, Union

from distilabel.steps.tasks.base import GeneratorTask
from distilabel.typing import ChatType


class MyCustomTask(GeneratorTask):
    @override
    def process(self, offset: int = 0) -> GeneratorStepOutput:
        output = self.llm.generate(
            inputs=[
                [{"role": "user", "content": "Tell me a joke."}],
            ],
        )
        output = {"model_name": self.llm.model_name}
        output.update(
            self.format_output(output=output, input=None)
        )
        yield output

    @property
    def outputs(self) -> List[str]:
        return ["output_field", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {"output_field": output}