生成输出的 GeneratorTask¶
使用 GeneratorTasks¶
GeneratorTask 是基于 GeneratorStep 的 Task 的自定义实现。与 Task 一样,它通常在 Pipeline 中使用,但也可以单独使用。
警告
此任务仍处于实验阶段,未来可能会发生更改。
from typing import Any, Dict, List, Union
from typing_extensions import override
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.typing import ChatType, GeneratorOutput
class MyCustomTask(GeneratorTask):
instruction: str
@override
def process(self, offset: int = 0) -> GeneratorStepOutput:
output = self.llm.generate(
inputs=[
[
{"role": "user", "content": self.instruction},
],
],
)
output = {"model_name": self.llm.model_name}
output.update(
self.format_output(output=output, input=None)
)
yield output
@property
def outputs(self) -> List[str]:
return ["output_field", "model_name"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
return {"output_field": output}
然后我们可以按如下方式使用它
task = MyCustomTask(
name="custom-generation",
instruction="Tell me a joke.",
llm=OpenAILLM(model="gpt-4"),
)
task.load()
next(task.process())
# [{'output_field": "Why did the scarecrow win an award? Because he was outstanding!", "model_name": "gpt-4"}]
注意
大多数情况下,您需要覆盖默认的 process 方法,因为它适用于标准的 Task,而不适用于 GeneratorTask。但是在 process 函数的上下文中,您可以自由地使用 llm 以任何方式生成数据。
注意
当作为独立组件使用时,始终需要执行 Step.load()。在 pipeline 中,这将在 pipeline 执行期间自动完成。
定义自定义 GeneratorTasks¶
我们可以通过创建 GeneratorTask 的新子类并定义以下内容来定义自定义生成器任务
-
process:是一种基于LLM和类实例中提供的instruction生成数据的方法,并返回一个字典,其中包含所需格式的输出数据,即outputs中列的值。请注意,inputs参数在此函数中是不允许的,因为这是一个GeneratorTask。签名仅期望offset参数,该参数用于跟踪生成器中的当前迭代。 -
outputs:是一个属性,返回一个字符串列表,其中包含输出字段的名称,此属性应始终包含model_name作为输出之一,因为它是从 LLM 自动注入的。 -
format_output:是一种方法,接收来自LLM的输出,并且可以选择接收输入数据(在某些情况下可能有助于构建输出),并返回一个字典,其中包含所需格式的输出数据,即outputs中列的值。请注意,无需在输出中包含model_name。
from typing import Any, Dict, List, Union
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.typing import ChatType
class MyCustomTask(GeneratorTask):
@override
def process(self, offset: int = 0) -> GeneratorStepOutput:
output = self.llm.generate(
inputs=[
[{"role": "user", "content": "Tell me a joke."}],
],
)
output = {"model_name": self.llm.model_name}
output.update(
self.format_output(output=output, input=None)
)
yield output
@property
def outputs(self) -> List[str]:
return ["output_field", "model_name"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
return {"output_field": output}