生成输出的 GeneratorTask¶
使用 GeneratorTasks¶
GeneratorTask
是基于 GeneratorStep
的 Task
的自定义实现。与 Task
一样,它通常在 Pipeline
中使用,但也可以单独使用。
警告
此任务仍处于实验阶段,未来可能会发生更改。
from typing import Any, Dict, List, Union
from typing_extensions import override
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.typing import ChatType, GeneratorOutput
class MyCustomTask(GeneratorTask):
instruction: str
@override
def process(self, offset: int = 0) -> GeneratorStepOutput:
output = self.llm.generate(
inputs=[
[
{"role": "user", "content": self.instruction},
],
],
)
output = {"model_name": self.llm.model_name}
output.update(
self.format_output(output=output, input=None)
)
yield output
@property
def outputs(self) -> List[str]:
return ["output_field", "model_name"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
return {"output_field": output}
然后我们可以按如下方式使用它
task = MyCustomTask(
name="custom-generation",
instruction="Tell me a joke.",
llm=OpenAILLM(model="gpt-4"),
)
task.load()
next(task.process())
# [{'output_field": "Why did the scarecrow win an award? Because he was outstanding!", "model_name": "gpt-4"}]
注意
大多数情况下,您需要覆盖默认的 process
方法,因为它适用于标准的 Task
,而不适用于 GeneratorTask
。但是在 process
函数的上下文中,您可以自由地使用 llm
以任何方式生成数据。
注意
当作为独立组件使用时,始终需要执行 Step.load()
。在 pipeline 中,这将在 pipeline 执行期间自动完成。
定义自定义 GeneratorTasks¶
我们可以通过创建 GeneratorTask
的新子类并定义以下内容来定义自定义生成器任务
-
process
:是一种基于LLM
和类实例中提供的instruction
生成数据的方法,并返回一个字典,其中包含所需格式的输出数据,即outputs
中列的值。请注意,inputs
参数在此函数中是不允许的,因为这是一个GeneratorTask
。签名仅期望offset
参数,该参数用于跟踪生成器中的当前迭代。 -
outputs
:是一个属性,返回一个字符串列表,其中包含输出字段的名称,此属性应始终包含model_name
作为输出之一,因为它是从 LLM 自动注入的。 -
format_output
:是一种方法,接收来自LLM
的输出,并且可以选择接收输入数据(在某些情况下可能有助于构建输出),并返回一个字典,其中包含所需格式的输出数据,即outputs
中列的值。请注意,无需在输出中包含model_name
。
from typing import Any, Dict, List, Union
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.typing import ChatType
class MyCustomTask(GeneratorTask):
@override
def process(self, offset: int = 0) -> GeneratorStepOutput:
output = self.llm.generate(
inputs=[
[{"role": "user", "content": "Tell me a joke."}],
],
)
output = {"model_name": self.llm.model_name}
output.update(
self.format_output(output=output, input=None)
)
yield output
@property
def outputs(self) -> List[str]:
return ["output_field", "model_name"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
return {"output_field": output}