跳到内容

ImageGenerationModel 图库

本节包含在 distilabel 中实现的现有 ImageGenerationModel 子类。

image_generation

AsyncImageGenerationModel

基类: ImageGenerationModel

异步 ImageGenerationModel 的抽象类,用于利用各个模型实现的异步能力。每个异步 ImageGenerationModel 都应继承此类,并实现 agenerate 方法,以异步方式生成响应。

属性

名称 类型 描述
_event_loop AbstractEventLoop

用于异步生成响应的事件循环。

源代码位于 src/distilabel/models/image_generation/base.py
class AsyncImageGenerationModel(ImageGenerationModel):
    """Abstract class for asynchronous `ImageGenerationModels`, to benefit from the async
    capabilities of each implementation. This class is meant to be subclassed by each
    asynchronous `ImageGenerationModel`, and the method `agenerate` needs to be implemented
    to provide the asynchronous generation of responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    # If `True`, `agenerate` receives `num_generations` and is called once per input;
    # if `False`, `_agenerate` calls it `num_generations` times per input instead and
    # keeps only the first entry of each result.
    _num_generations_param_supported = True
    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
    # `True` only when `event_loop` created `_event_loop` itself, so `__del__` must close it.
    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> list[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        """Returns the event loop used to drive `agenerate` from synchronous code.

        On first access it reuses the currently running loop if there is one; otherwise
        it creates a new loop and records (via `_new_event_loop`) that this object owns
        it. On every access the stored loop is set as the current event loop.

        Returns:
            The event loop stored in `_event_loop`.
        """
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
                    self._new_event_loop = True
            except RuntimeError:
                # `get_running_loop` raises when no loop is running in this thread.
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        """Generates images from the provided input.

        Args:
            input: the input text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Internal function to concurrently generate images for a list of inputs.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        if self._num_generations_param_supported:
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            return await asyncio.gather(*tasks)

        # One call per requested generation; results are regrouped per input below.
        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        # Each call returns a list; only its first entry is kept.
        first_results = [result[0] for result in await asyncio.gather(*tasks)]
        return [
            list(group)
            for group in grouper(first_results, n=num_generations, incomplete="ignore")
        ]

    def generate(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """Method to generate a list of images asynchronously, returning the output
        synchronously awaiting for the image of each input sent to `agenerate`.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the images for each input.
        """
        return self.event_loop.run_until_complete(
            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
        )

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted, if this object created it."""
        # Presumably guards interpreter shutdown, when `sys.meta_path` is set to `None`.
        if sys.meta_path is None:
            return

        # `getattr` with defaults: if validation failed in `__init__`, the private
        # attributes were never initialised and plain attribute access would raise here.
        if not getattr(self, "_new_event_loop", False):
            return
        event_loop = getattr(self, "_event_loop", None)
        if event_loop is None:
            return
        if event_loop.is_running():
            event_loop.stop()
        event_loop.close()
generate_parameters property

返回 agenerate 方法的参数。

返回

类型 描述
list[Parameter]

包含 agenerate 方法参数的列表。

generate_parsed_docstring cached property

返回 agenerate 方法的解析后的文档字符串。

返回

类型 描述
Docstring

agenerate 方法的解析后的文档字符串。

agenerate(input, num_generations=1, **kwargs) abstractmethod async

从提供的输入生成图像。

参数

名称 类型 描述 默认值
input str

从中生成图像的输入文本。

必需
num_generations int

要生成的图像数量。默认为 1

1

返回

类型 描述
list[dict[str, Any]]

包含一个字典的列表,该字典包含生成的图像列表。

源代码位于 src/distilabel/models/image_generation/base.py
@abstractmethod
async def agenerate(
    self, input: str, num_generations: int = 1, **kwargs: Any
) -> list[dict[str, Any]]:
    """Asynchronously generate images for a single text prompt.

    Concrete image generation models must override this coroutine.

    Args:
        input: the text prompt the images are generated from.
        num_generations: how many images to produce. Defaults to `1`.
        **kwargs: extra, implementation-specific generation arguments.

    Returns:
        A list holding a dictionary with the list of generated images.
    """
    ...
_agenerate(inputs, num_generations=1, **kwargs) async

用于并发地为输入列表生成图像的内部函数。

参数

名称 类型 描述 默认值
inputs list[str]

要为其生成图像的输入列表。

必需
num_generations int

每个输入要生成的数量。

1
**kwargs Any

用于生成的其他 kwargs。

{}

返回

类型 描述
list[list[dict[str, Any]]]

包含每个输入的生成的列表。

源代码位于 src/distilabel/models/image_generation/base.py
async def _agenerate(
    self, inputs: list[str], num_generations: int = 1, **kwargs: Any
) -> list[list[dict[str, Any]]]:
    """Run `agenerate` concurrently for every input.

    Args:
        inputs: the prompts to generate images for.
        num_generations: how many generations to produce per prompt.
        **kwargs: extra arguments forwarded to every `agenerate` call.

    Returns:
        One entry per input, each holding the generations for that input.
    """
    if not self._num_generations_param_supported:
        # `agenerate` is called once per requested generation; only the first
        # entry of each call's result is kept, then regrouped per input.
        single_shot_tasks = [
            asyncio.create_task(self.agenerate(input=prompt, **kwargs))
            for prompt in inputs
            for _ in range(num_generations)
        ]
        gathered = await asyncio.gather(*single_shot_tasks)
        first_items = [result[0] for result in gathered]
        return [
            list(chunk)
            for chunk in grouper(first_items, n=num_generations, incomplete="ignore")
        ]

    batched_tasks = [
        asyncio.create_task(
            self.agenerate(input=prompt, num_generations=num_generations, **kwargs)
        )
        for prompt in inputs
    ]
    return await asyncio.gather(*batched_tasks)
generate(inputs, num_generations=1, **kwargs)

以异步方式生成图像列表的方法:同步等待发送给 agenerate 的每个输入完成图像生成,然后返回结果。

参数

名称 类型 描述 默认值
inputs list[str]

要为其生成图像的输入列表。

必需
num_generations int

每个输入要生成的数量。

1
**kwargs Any

用于生成的其他 kwargs。

{}

返回

类型 描述
list[list[dict[str, Any]]]

包含每个输入的图像的列表。

源代码位于 src/distilabel/models/image_generation/base.py
def generate(
    self,
    inputs: list[str],
    num_generations: int = 1,
    **kwargs: Any,
) -> list[list[dict[str, Any]]]:
    """Synchronous entry point that drives `_agenerate` on the model's event loop.

    Args:
        inputs: the prompts to generate images for.
        num_generations: how many generations to produce per prompt.
        **kwargs: extra arguments forwarded to the generation calls.

    Returns:
        The images produced for each input, in input order.
    """
    loop = self.event_loop
    pending = self._agenerate(
        inputs=inputs, num_generations=num_generations, **kwargs
    )
    return loop.run_until_complete(pending)
__del__()

当对象被删除时关闭事件循环。

源代码位于 src/distilabel/models/image_generation/base.py
def __del__(self) -> None:
    """Closes the event loop when the object is deleted, if this object created it."""
    # Presumably guards interpreter shutdown, when `sys.meta_path` is set to `None`.
    if sys.meta_path is None:
        return

    # `getattr` with defaults: if validation failed in `__init__`, the private
    # attributes were never initialised and plain attribute access would raise here.
    if not getattr(self, "_new_event_loop", False):
        return
    event_loop = getattr(self, "_event_loop", None)
    if event_loop is None:
        return
    if event_loop.is_running():
        event_loop.stop()
    event_loop.close()

InferenceEndpointsImageGeneration

基类: InferenceEndpointsBaseClient, AsyncImageGenerationModel

Inference Endpoint 图像生成实现,运行异步 API 客户端。

属性

名称 类型 描述
model_id Optional[str]

用于 ImageGenerationModel 的模型 ID,可在 Hugging Face Hub 中找到,它将用于解析 Serverless Inference Endpoints API 请求的基本 URL。默认为 None

endpoint_name Optional[RuntimeParameter[str]]

用于 LLM 的 Inference Endpoint 的名称。默认为 None

endpoint_namespace Optional[RuntimeParameter[str]]

用于 LLM 的 Inference Endpoint 的命名空间。默认为 None

base_url Optional[RuntimeParameter[str]]

用于 Inference Endpoints API 请求的基本 URL。

api_key Optional[RuntimeParameter[SecretStr]]

用于验证对 Inference Endpoints API 请求的 API 密钥。

图标

:hugging:

示例

从文本提示生成图像

from distilabel.models.image_generation import InferenceEndpointsImageGeneration

# Build the client for a serverless Inference Endpoint and load it before use.
image_model = InferenceEndpointsImageGeneration(
    model_id="black-forest-labs/FLUX.1-schnell",
    api_key="api.key",
)
image_model.load()

prompts = ["a white siamese cat"]
output = image_model.generate_outputs(inputs=prompts)
# [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
源代码位于 src/distilabel/models/image_generation/huggingface/inference_endpoints.py
class InferenceEndpointsImageGeneration(  # type: ignore
    InferenceEndpointsBaseClient, AsyncImageGenerationModel
):
    """Image generation through Hugging Face Inference Endpoints, using the async API client.

    Attributes:
        model_id: the model ID to use for the ImageGenerationModel as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.

    Icon:
        `:hugging:`

    Examples:
        Generate images from text prompts:

        ```python
        from distilabel.models.image_generation import InferenceEndpointsImageGeneration

        igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell", api_key="api.key")
        igm.load()

        output = igm.generate_outputs(
            inputs=["a white siamese cat"],
        )
        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
        ```
    """

    def load(self) -> None:
        """Loads the model: sets the logger, loads the Inference Endpoints client and
        keeps a reference to the helper used to convert images to strings."""
        from distilabel.models.image_generation.utils import image_to_str

        # The image-generation base `load` (logger) runs first, then the client's own `load`.
        for base_load in (
            AsyncImageGenerationModel.load,
            InferenceEndpointsBaseClient.load,
        ):
            base_load(self)

        self._image_to_str = image_to_str

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: str,
        negative_prompt: Optional[str] = None,
        height: Optional[float] = None,
        width: Optional[float] = None,
        num_inference_steps: Optional[float] = None,
        guidance_scale: Optional[float] = None,
        num_generations: int = 1,
    ) -> list[dict[str, Any]]:
        """Generate one image for a text prompt via
        `huggingface_hub.AsyncInferenceClient.text_to_image`.

        Args:
            input: Prompt to generate an image from.
            negative_prompt: An optional negative prompt for the image generation. Defaults to None.
            height: The height in pixels of the image to generate.
            width: The width in pixels of the image to generate.
            num_inference_steps: The number of denoising steps. More denoising steps usually lead
                to a higher quality image at the expense of slower inference.
            guidance_scale: Higher guidance scale encourages to generate images that are closely
                linked to the text `prompt`, usually at the expense of lower image quality.
            num_generations: Accepted so that validation succeeds, but not used: exactly one
                image is requested and returned per call. Defaults to `1`.

        Returns:
            A list with a single dictionary whose `"images"` key holds a one-element list
            with the image as a base64 string (converted with `image_format="JPEG"`).
        """
        request_options = {
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        image: "Image" = await self._aclient.text_to_image(  # type: ignore
            input, **request_options
        )
        return [{"images": [self._image_to_str(image, image_format="JPEG")]}]
agenerate(input, negative_prompt=None, height=None, width=None, num_inference_steps=None, guidance_scale=None, num_generations=1) async

使用 huggingface_hub.AsyncInferenceClient.text_to_image 从文本提示生成图像。

参数

名称 类型 描述 默认值
input str

从中生成图像的提示。

必需
negative_prompt Optional[str]

图像生成的可选负面提示。默认为 None。

None
height Optional[float]

要生成的图像的高度(以像素为单位)。

None
width Optional[float]

要生成的图像的宽度(以像素为单位)。

None
num_inference_steps Optional[float]

去噪步骤的数量。更多的去噪步骤通常会以较慢的推理速度为代价带来更高质量的图像。

None
guidance_scale Optional[float]

更高的 guidance scale 鼓励生成与文本 prompt 紧密相关的图像,通常以较低的图像质量为代价。

None
num_generations int

要生成的图像数量。默认为 1。这里是为了确保验证成功,但它不会有效果。

1

返回

类型 描述
list[dict[str, Any]]

包含一个字典的列表,该字典包含一个列表,其中包含作为 base64 字符串的图像。

源代码位于 src/distilabel/models/image_generation/huggingface/inference_endpoints.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: str,
    negative_prompt: Optional[str] = None,
    height: Optional[float] = None,
    width: Optional[float] = None,
    num_inference_steps: Optional[float] = None,
    guidance_scale: Optional[float] = None,
    num_generations: int = 1,
) -> list[dict[str, Any]]:
    """Generate one image for a text prompt via
    `huggingface_hub.AsyncInferenceClient.text_to_image`.

    Args:
        input: Prompt to generate an image from.
        negative_prompt: An optional negative prompt for the image generation. Defaults to None.
        height: The height in pixels of the image to generate.
        width: The width in pixels of the image to generate.
        num_inference_steps: The number of denoising steps. More denoising steps usually lead
            to a higher quality image at the expense of slower inference.
        guidance_scale: Higher guidance scale encourages to generate images that are closely
            linked to the text `prompt`, usually at the expense of lower image quality.
        num_generations: Accepted so that validation succeeds, but not used: exactly one
            image is requested and returned per call. Defaults to `1`.

    Returns:
        A list with a single dictionary whose `"images"` key holds a one-element list
        with the image as a base64 string (converted with `image_format="JPEG"`).
    """
    request_options = {
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    image: "Image" = await self._aclient.text_to_image(  # type: ignore
        input, **request_options
    )
    return [{"images": [self._image_to_str(image, image_format="JPEG")]}]

OpenAIImageGeneration

基类: OpenAIBaseClient, AsyncImageGenerationModel

OpenAI 图像生成实现,运行异步 API 客户端。

属性

名称 类型 描述
model str

用于 ImageGenerationModel 的模型名称,例如 "dall-e-3" 等。支持的模型可以在[这里](https://platform.openai.com/docs/guides/images)找到。

base_url Optional[RuntimeParameter[str]]

用于 OpenAI API 请求的基本 URL。默认为 None,这意味着将使用为环境变量 OPENAI_BASE_URL 设置的值,如果未设置,则使用 "https://api.openai.com/v1"。

api_key Optional[RuntimeParameter[SecretStr]]

用于验证对 OpenAI API 请求的 API 密钥。默认为 None,这意味着将使用为环境变量 OPENAI_API_KEY 设置的值,如果未设置,则为 None

max_retries RuntimeParameter[int]

在失败之前,重试 API 请求的最大次数。默认为 6

timeout RuntimeParameter[int]

等待 API 响应的最大时间(秒)。默认为 120

图标

:simple-openai:

示例

从文本提示生成图像

from distilabel.models.image_generation import OpenAIImageGeneration

# Create the DALL-E 3 client and load it before generating.
image_model = OpenAIImageGeneration(model="dall-e-3", api_key="api.key")
image_model.load()

generation_kwargs = {"size": "1024x1024", "quality": "standard", "style": "natural"}
output = image_model.generate_outputs(
    inputs=["a white siamese cat"],
    **generation_kwargs,
)
# [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
源代码位于 src/distilabel/models/image_generation/openai.py
class OpenAIImageGeneration(OpenAIBaseClient, AsyncImageGenerationModel):
    """OpenAI image generation implementation running the async API client.

    Attributes:
        model: the model name to use for the ImageGenerationModel e.g. "dall-e-3", etc.
            Supported models can be found [here](https://platform.openai.com/docs/guides/images).
        base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `OPENAI_BASE_URL` will
            be used, or "https://api.openai.com/v1" if not set.
        api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
            `None` which means that the value set for the environment variable `OPENAI_API_KEY`
            will be used, or `None` if not set.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`. Also used as the timeout when downloading images returned as URLs.

    Icon:
        `:simple-openai:`

    Examples:
        Generate images from text prompts:

        ```python
        from distilabel.models.image_generation import OpenAIImageGeneration

        igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key")

        igm.load()

        output = igm.generate_outputs(
            inputs=["a white siamese cat"],
            size="1024x1024",
            quality="standard",
            style="natural",
        )
        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
        ```
    """

    def load(self) -> None:
        """Sets the logger and loads the OpenAI clients."""
        # Sets the logger and calls the load method of the BaseClient
        AsyncImageGenerationModel.load(self)
        OpenAIBaseClient.load(self)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: str,
        num_generations: int = 1,
        quality: Optional[Literal["standard", "hd"]] = "standard",
        response_format: Optional[Literal["url", "b64_json"]] = "url",
        size: Optional[
            Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
        ] = None,
        style: Optional[Literal["vivid", "natural"]] = None,
    ) -> list[dict[str, Any]]:
        """Generates `num_generations` images for the given input using the OpenAI async
        client. The images are returned as base64 string representations.

        Args:
            input: A text description of the desired image(s). The maximum length is 1000
                characters for `dall-e-2` and 4000 characters for `dall-e-3`.
            num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
                `n=1` is supported.
            quality: The quality of the image that will be generated. `hd` creates images with finer
                details and greater consistency across the image. This param is only supported
                for `dall-e-3`.
            response_format: The format in which the generated images are returned. Must be one of `url` or
                `b64_json`. URLs are only valid for 60 minutes after the image has been
                generated; URL images are downloaded and base64-encoded here.
            size: The size of the generated images. Must be one of `256x256`, `512x512`, or
                `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
                `1024x1792` for `dall-e-3` models.
            style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid
                causes the model to lean towards generating hyper-real and dramatic images.
                Natural causes the model to produce more natural, less hyper-real looking
                images. This param is only supported for `dall-e-3`.

        Returns:
            A list with a dictionary with the list of images generated, each one as a
            base64-encoded string.

        Raises:
            requests.HTTPError: if downloading an image returned as a URL fails.
        """
        import asyncio

        images_response: "ImagesResponse" = await self._aclient.images.generate(
            model=self.model_name,
            prompt=input,
            n=num_generations,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
        )

        def _download_as_b64(url: str) -> str:
            # TODO: Keep a requests/httpx session instead
            response = requests.get(url, timeout=self.timeout)
            # Otherwise an HTTP error body would be base64-encoded as if it were an image.
            response.raise_for_status()
            return base64.b64encode(response.content).decode()

        async def _to_b64(image: Any) -> Optional[str]:
            # Inspect the payload rather than `response_format`, so that
            # `response_format=None` is handled as well.
            if image.b64_json is not None:
                return image.b64_json
            if image.url is not None:
                # `requests` is blocking: run it in a worker thread so the event loop
                # (and every other concurrent `agenerate` call) is not stalled.
                return await asyncio.to_thread(_download_as_b64, image.url)
            return None

        encoded = await asyncio.gather(
            *(_to_b64(image) for image in images_response.data)
        )
        return [{"images": [image for image in encoded if image is not None]}]
agenerate(input, num_generations=1, quality='standard', response_format='url', size=None, style=None) async

使用 OpenAI 异步客户端为给定输入生成 num_generations 图像。图像是 base64 字符串表示形式。

参数

名称 类型 描述 默认值
input str

所需图像的文本描述。dall-e-2 的最大长度为 1000 个字符,dall-e-3 的最大长度为 4000 个字符。

必需
num_generations int

要生成的图像数量。必须介于 1 和 10 之间。对于 dall-e-3,仅支持 n=1

1
quality Optional[Literal['standard', 'hd']]

将生成的图像的质量。hd 会生成细节更精细、整幅图像一致性更高的图像。此参数仅支持 dall-e-3。

'standard'
response_format Optional[Literal['url', 'b64_json']]

返回生成的图像的格式。必须是 url 或 b64_json 之一。URL 仅在图像生成后 60 分钟内有效。

'url'
size Optional[Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']]

生成的图像的尺寸。对于 dall-e-2,必须是 256x256、512x512 或 1024x1024 之一;对于 dall-e-3 模型,必须是 1024x1024、1792x1024 或 1024x1792 之一。

None
style Optional[Literal['vivid', 'natural']]

生成的图像的风格。必须是 vivid 或 natural 之一。vivid 使模型倾向于生成超写实、富有戏剧性的图像;natural 使模型生成更自然、不那么超写实的图像。此参数仅支持 dall-e-3。

None

返回

类型 描述
list[dict[str, Any]]

包含一个字典的列表,该字典包含生成的图像列表。

源代码位于 src/distilabel/models/image_generation/openai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: str,
    num_generations: int = 1,
    quality: Optional[Literal["standard", "hd"]] = "standard",
    response_format: Optional[Literal["url", "b64_json"]] = "url",
    size: Optional[
        Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    ] = None,
    style: Optional[Literal["vivid", "natural"]] = None,
) -> list[dict[str, Any]]:
    """Generates `num_generations` images for the given input using the OpenAI async
    client. The images are returned as base64 string representations.

    Args:
        input: A text description of the desired image(s). The maximum length is 1000
            characters for `dall-e-2` and 4000 characters for `dall-e-3`.
        num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
            `n=1` is supported.
        quality: The quality of the image that will be generated. `hd` creates images with finer
            details and greater consistency across the image. This param is only supported
            for `dall-e-3`.
        response_format: The format in which the generated images are returned. Must be one of `url` or
            `b64_json`. URLs are only valid for 60 minutes after the image has been
            generated; URL images are downloaded and base64-encoded here.
        size: The size of the generated images. Must be one of `256x256`, `512x512`, or
            `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
            `1024x1792` for `dall-e-3` models.
        style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid
            causes the model to lean towards generating hyper-real and dramatic images.
            Natural causes the model to produce more natural, less hyper-real looking
            images. This param is only supported for `dall-e-3`.

    Returns:
        A list with a dictionary with the list of images generated, each one as a
        base64-encoded string.

    Raises:
        requests.HTTPError: if downloading an image returned as a URL fails.
    """
    import asyncio

    images_response: "ImagesResponse" = await self._aclient.images.generate(
        model=self.model_name,
        prompt=input,
        n=num_generations,
        quality=quality,
        response_format=response_format,
        size=size,
        style=style,
    )

    def _download_as_b64(url: str) -> str:
        # TODO: Keep a requests/httpx session instead
        response = requests.get(url, timeout=self.timeout)
        # Otherwise an HTTP error body would be base64-encoded as if it were an image.
        response.raise_for_status()
        return base64.b64encode(response.content).decode()

    async def _to_b64(image: Any) -> Optional[str]:
        # Inspect the payload rather than `response_format`, so that
        # `response_format=None` is handled as well.
        if image.b64_json is not None:
            return image.b64_json
        if image.url is not None:
            # `requests` is blocking: run it in a worker thread so the event loop
            # (and every other concurrent `agenerate` call) is not stalled.
            return await asyncio.to_thread(_download_as_b64, image.url)
        return None

    encoded = await asyncio.gather(*(_to_b64(image) for image in images_response.data))
    return [{"images": [image for image in encoded if image is not None]}]