ImageTask 用于处理图像生成模型¶
使用 ImageTasks¶
ImageTask
是 Task
的自定义实现,专门用于处理图像。 这些任务的行为与任何其他 Task
完全相同,但它们不依赖于 LLM
,而是使用 ImageGenerationModel
。
1.5.0 版本中的新功能
此任务是新功能,预计与图像生成模型一起使用。
这些任务使用 image_generation_model
属性而不是像标准 Task
那样使用 llm
,但其他一切保持不变。 让我们看一个关于 ImageGeneration
的示例
from distilabel.steps.tasks import ImageGeneration
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
task = ImageGeneration(
name="image-generation",
image_generation_model=InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell"),
)
task.load()
next(task.process([{"prompt": "a white siamese cat"}]))
# [{'image": "iVBORw0KGgoAAAANSUhEUgA...", "model_name": "black-forest-labs/FLUX.1-schnell"}]
在 notebook 中可视化图像
如果您正在 notebook 中测试 ImageGeneration
任务,您可以执行以下操作来查看渲染的图像
在 Pipeline 中运行 ImageGeneration
如果运行 pipeline,则可以通过在最终 distiset 上调用方法 transform_columns_to_image
并传递列图像的名称(或名称列表),为整个数据集完成图像作为字符串和作为 PIL 对象之间的转换。
定义自定义 ImageTasks¶
我们可以通过创建 ImageTask
的新子类并定义以下内容来定义自定义生成器任务
-
process
:是一种基于ImageGenerationModel
和类实例中提供的prompt
生成数据的方法,并返回一个字典,其中包含所需格式的输出数据,即包含outputs
中列的值。 -
inputs
:是一个属性,返回包含所需输入字段名称的字符串列表,或者一个字典,其中键是列名,值是布尔值,指示列是否为必需。 -
outputs
:是一个属性,返回包含输出字段名称的字符串列表,或者一个字典,其中键是列名,值是布尔值,指示列是否为必需。 此属性应始终包含model_name
作为输出之一,因为它是从 LLM 自动注入的。 -
format_input
:是一种接收包含输入数据的字典并返回要传递给模型的prompt的方法。 -
format_output
:是一种接收来自ImageGenerationModel
的输出,并可选择接收输入数据(这在某些情况下可能有助于构建输出)的方法,并返回一个字典,其中包含所需格式的输出数据,即包含outputs
中列的值。
from typing import TYPE_CHECKING
from distilabel.models.image_generation.utils import image_from_str, image_to_str
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import ImageTask
if TYPE_CHECKING:
from distilabel.typing import StepColumns, StepOutput
class MyCustomImageTask(ImageTask):
@override
def process(self, offset: int = 0) -> GeneratorOutput:
formatted_inputs = self._format_inputs(inputs)
outputs = self.llm.generate_outputs(
inputs=formatted_inputs,
num_generations=self.num_generations,
**self.llm.get_generation_kwargs(),
)
task_outputs = []
for input, input_outputs in zip(inputs, outputs):
formatted_outputs = self._format_outputs(input_outputs, input)
for formatted_output in formatted_outputs:
task_outputs.append(
{**input, **formatted_output, "model_name": self.llm.model_name}
)
yield task_outputs
@property
def inputs(self) -> "StepColumns":
return ["prompt"]
@property
def outputs(self) -> "StepColumns":
return ["image", "model_name"]
def format_input(self, input: dict[str, any]) -> str:
return input["prompt"]
def format_output(
self, output: Union[str, None], input: dict[str, any]
) -> Dict[str, Any]:
# Extract/generate/modify the image from the output
return {"image": ..., "model_name": self.llm.model_name}
警告
请注意,在 process
方法中,我们不是在处理 image_generation
属性,而是在处理 llm
。 这不是错误,而是有意为之,因为我们在内部将 image_generation
重命名为 llm
以重用代码。