跳转到内容

ImageGeneration

使用图像到文本模型根据提示生成图像。

ImageGeneration 是一个预定义的任务,允许从提示生成图像。它适用于 `distilabel.models.image_generation` 下定义的任何 `image_generation`,这些模型是实现图像生成的模型。默认情况下,图像以 base64 字符串格式生成,在数据集生成后,可以使用 `Distiset.transform_columns_to_image` 将图像自动转换为 `PIL.Image.Image`。有关更多信息,请查看文档中的 `Image Generation with distilabel` 示例。使用 `save_artifacts` 属性,图像可以保存在 hugging face hub 仓库的工件文件夹中。

属性

  • save_artifacts:布尔值,用于将图像工件保存在其文件夹中。否则,图像的 base64 表示将保存为字符串。默认为 False。

  • image_format:PIL 支持的任何格式。默认为 `JPEG`。

输入和输出列

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[prompt]
        end
        subgraph New columns
            OCOL0[image]
            OCOL1[image_path]
            OCOL2[model_name]
        end
    end

    subgraph ImageGeneration
        StepInput[Input Columns: prompt]
        StepOutput[Output Columns: image, image_path, model_name]
    end

    ICOL0 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepInput --> StepOutput

输入

  • prompt (str):一个名为 prompt 的列,包含生成图像的提示。

输出

  • image (str):生成的图像。最初是一个 base64 字符串,为了在 pipeline 运行期间简单起见,但这可以在 pipeline 结束时返回 distiset 后通过调用 `distiset.transform_columns_to_image(<IMAGE_COLUMN>)` 转换为 Image 对象。

  • image_path (str):图像保存的路径。仅当 `save_artifacts` 为 True 时可用。

  • model_name (str):用于生成图像的模型的名称。

示例

从提示生成图像

from distilabel.steps.tasks import ImageGeneration
from distilabel.models.image_generation import InferenceEndpointsImageGeneration

igm = InferenceEndpointsImageGeneration(
    model_id="black-forest-labs/FLUX.1-schnell"
)

# save_artifacts=True by default in JPEG format, if set to False, the image will be saved as a string.
image_gen = ImageGeneration(image_generation_model=igm)

image_gen.load()

result = next(
    image_gen.process(
        [{"prompt": "a white siamese cat"}]
    )
)

生成图像并将其保存为 Hugging Face Hub 仓库中的工件

from distilabel.steps.tasks import ImageGeneration
# Select the Image Generation model to use
from distilabel.models.image_generation import OpenAIImageGeneration

igm = OpenAIImageGeneration(
    model="dall-e-3",
    api_key="api.key",
    generation_kwargs={
        "size": "1024x1024",
        "quality": "standard",
        "style": "natural"
    }
)

# save_artifacts=True by default in JPEG format, if set to False, the image will be saved as a string.
image_gen = ImageGeneration(
    image_generation_model=igm,
    save_artifacts=True,
    image_format="JPEG"  # By default will use JPEG, the options available can be seen in PIL documentation.
)

image_gen.load()

result = next(
    image_gen.process(
        [{"prompt": "a white siamese cat"}]
    )
)