跳到内容

ImageTask 用于处理图像生成模型

使用 ImageTasks

ImageTaskTask 的自定义实现,专门用于处理图像。 这些任务的行为与任何其他 Task 完全相同,但它们不依赖于 LLM,而是使用 ImageGenerationModel

1.5.0 版本中的新功能

此任务是新功能,预计与图像生成模型一起使用。

这些任务使用 image_generation_model 属性而不是像标准 Task 那样使用 llm,但其他一切保持不变。 让我们看一个关于 ImageGeneration 的示例

from distilabel.steps.tasks import ImageGeneration
from distilabel.models.image_generation import InferenceEndpointsImageGeneration

task = ImageGeneration(
    name="image-generation",
    image_generation_model=InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell"),
)
task.load()

next(task.process([{"prompt": "a white siamese cat"}]))
# [{'image": "iVBORw0KGgoAAAANSUhEUgA...", "model_name": "black-forest-labs/FLUX.1-schnell"}]

在 notebook 中可视化图像

如果您正在 notebook 中测试 ImageGeneration 任务,您可以执行以下操作来查看渲染的图像

from distilabel.models.image_generation.utils import image_from_str

result = next(task.process([{"prompt": "a white siamese cat"}]))
image_from_str(result[0]["image"])  # Returns a `PIL.Image.Image` that renders directly

在 Pipeline 中运行 ImageGeneration

如果运行 pipeline,则可以通过在最终 distiset 上调用方法 transform_columns_to_image 并传递列图像的名称(或名称列表),为整个数据集完成图像作为字符串和作为 PIL 对象之间的转换。

定义自定义 ImageTasks

我们可以通过创建 ImageTask 的新子类并定义以下内容来定义自定义生成器任务

  • process:是一种基于 ImageGenerationModel 和类实例中提供的 prompt 生成数据的方法,并返回一个字典,其中包含所需格式的输出数据,即包含 outputs 中列的值。

  • inputs:是一个属性,返回包含所需输入字段名称的字符串列表,或者一个字典,其中键是列名,值是布尔值,指示列是否为必需。

  • outputs:是一个属性,返回包含输出字段名称的字符串列表,或者一个字典,其中键是列名,值是布尔值,指示列是否为必需。 此属性应始终包含 model_name 作为输出之一,因为它是从 LLM 自动注入的。

  • format_input:是一种接收包含输入数据的字典并返回要传递给模型的prompt的方法。

  • format_output:是一种接收来自 ImageGenerationModel 的输出,并可选择接收输入数据(这在某些情况下可能有助于构建输出)的方法,并返回一个字典,其中包含所需格式的输出数据,即包含 outputs 中列的值。

from typing import TYPE_CHECKING

from distilabel.models.image_generation.utils import image_from_str, image_to_str
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import ImageTask

if TYPE_CHECKING:
    from distilabel.typing import StepColumns, StepOutput


class MyCustomImageTask(ImageTask):
    @override
    def process(self, offset: int = 0) -> GeneratorOutput:
        formatted_inputs = self._format_inputs(inputs)

        outputs = self.llm.generate_outputs(
            inputs=formatted_inputs,
            num_generations=self.num_generations,
            **self.llm.get_generation_kwargs(),
        )

        task_outputs = []
        for input, input_outputs in zip(inputs, outputs):
            formatted_outputs = self._format_outputs(input_outputs, input)
            for formatted_output in formatted_outputs:
                task_outputs.append(
                    {**input, **formatted_output, "model_name": self.llm.model_name}
                )
        yield task_outputs

    @property
    def inputs(self) -> "StepColumns":
        return ["prompt"]

    @property
    def outputs(self) -> "StepColumns":
        return ["image", "model_name"]

    def format_input(self, input: dict[str, any]) -> str:
        return input["prompt"]

    def format_output(
        self, output: Union[str, None], input: dict[str, any]
    ) -> Dict[str, Any]:
        # Extract/generate/modify the image from the output
        return {"image": ..., "model_name": self.llm.model_name}

警告

请注意,在 process 方法中,我们不是在处理 image_generation 属性,而是在处理 llm。 这不是错误,而是有意为之,因为我们在内部将 image_generation 重命名为 llm 以重用代码。