
ImageGenerationModel

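This section contains the API reference for the distilabel image generation models, covering both the synchronous implementation, ImageGenerationModel, and the asynchronous one, AsyncImageGenerationModel.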

For more information and examples on how to use existing image generation models or create custom ones, please refer to Tutorial - ImageGenerationModel.

base

ImageGenerationModel

Bases: RuntimeParametersModelMixin, BaseModel, _Serializable, ABC

Base class for ImageGeneration models.

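To implement an ImageGeneration subclass, you need to subclass this class and implement:
- the load method, to load the ImageGeneration model if needed. Don't forget to call super().load(), so that the _logger attribute is initialized.
- the model_name property, to return the model name used for the LLM.
- the generate method, to generate num_generations outputs for each input in inputs.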
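The three requirements above can be sketched as follows. This is a minimal, stdlib-only illustration and not the library's code: `BaseSketch` is a hypothetical stand-in for `ImageGenerationModel` (it uses neither Pydantic nor distilabel), `PlaceholderImageModel` is an invented subclass, and the `"images"` dictionary key is an assumption about the output layout that this reference does not confirm.

```python
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseSketch(ABC):
    """Hypothetical stand-in for ImageGenerationModel (not the real class)."""

    def __init__(self) -> None:
        self._logger: Optional[logging.Logger] = None

    def load(self) -> None:
        # Mirrors the documented behavior: load() initializes the logger.
        self._logger = logging.getLogger(
            f"distilabel.models.image_generation.{self.model_name}"
        )

    @property
    @abstractmethod
    def model_name(self) -> str: ...

    @abstractmethod
    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]: ...


class PlaceholderImageModel(BaseSketch):
    """Invented subclass that returns placeholder strings instead of images."""

    def load(self) -> None:
        super().load()  # keep this call so that `_logger` is initialized
        # A real subclass would load its weights or create its client here.

    @property
    def model_name(self) -> str:
        return "placeholder-model"

    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # One inner list per input, one dict per generation.
        # The "images" key is an assumption made for this sketch.
        return [
            [{"images": [f"<image {i} for {prompt!r}>"]} for i in range(num_generations)]
            for prompt in inputs
        ]


model = PlaceholderImageModel()
model.load()
outputs = model.generate(["a red bicycle", "a blue door"], num_generations=2)
```

A subclass of the real base class would additionally declare its configuration as Pydantic fields, since the base class derives from `BaseModel`.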

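Attributes

Name Type Description
generation_kwargs Optional[RuntimeParameter[dict[str, Any]]]

The kwargs to be propagated to either the generate or the agenerate method within each ImageGenerationModel.

_logger Logger

The logger to be used for the ImageGenerationModel. It will be initialized when the load method is called.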

Source code in src/distilabel/models/image_generation/base.py
class ImageGenerationModel(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC):
    """Base class for `ImageGeneration` models.

    To implement an `ImageGeneration` subclass, you need to subclass this class and implement:
        - `load` method to load the `ImageGeneration` model if needed. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the LLM.
        - `generate` method to generate `num_generations` per input in `inputs`.

    Attributes:
        generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
            methods within each `ImageGenerationModel`.
        _logger: the logger to be used for the `ImageGenerationModel`. It will be initialized
            when the `load` method is called.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )

    generation_kwargs: Optional[RuntimeParameter[dict[str, Any]]] = Field(
        default_factory=dict,
        description="The kwargs to be propagated to either `generate` or `agenerate`"
        " methods within each `ImageGenerationModel`.",
    )
    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Method to be called to initialize the `ImageGenerationModel`, and its logger."""
        self._logger = logging.getLogger(
            f"distilabel.models.image_generation.{self.model_name}"
        )

    def unload(self) -> None:
        """Method to be called to unload the `ImageGenerationModel` and release any resources."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the `ImageGenerationModel`."""
        pass

    def get_generation_kwargs(self) -> dict[str, Any]:
        """Returns the generation kwargs to be used for the generation. This method can
        be overridden to provide a more complex logic for the generation kwargs.

        Returns:
            The kwargs to be used for the generation.
        """
        return self.generation_kwargs  # type: ignore

    @abstractmethod
    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Generates images from the provided input.

        Args:
            inputs: the prompt text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    def generate_outputs(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """This method is defined for compatibility with the `LLMs`. It calls the `generate`
        method.
        """
        return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
model_name abstractmethod property

Returns the model name used for the ImageGenerationModel.

load()

Method to be called to initialize the ImageGenerationModel and its logger.

Source code in src/distilabel/models/image_generation/base.py
def load(self) -> None:
    """Method to be called to initialize the `ImageGenerationModel`, and its logger."""
    self._logger = logging.getLogger(
        f"distilabel.models.image_generation.{self.model_name}"
    )
unload()

Method to be called to unload the ImageGenerationModel and release any resources.

Source code in src/distilabel/models/image_generation/base.py
def unload(self) -> None:
    """Method to be called to unload the `ImageGenerationModel` and release any resources."""
    pass
get_generation_kwargs()

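Returns the generation kwargs to be used for the generation. This method can be overridden to provide more complex logic for the generation kwargs.

Returns

Type Description
dict[str, Any]

The kwargs to be used for the generation.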

Source code in src/distilabel/models/image_generation/base.py
def get_generation_kwargs(self) -> dict[str, Any]:
    """Returns the generation kwargs to be used for the generation. This method can
    be overridden to provide a more complex logic for the generation kwargs.

    Returns:
        The kwargs to be used for the generation.
    """
    return self.generation_kwargs  # type: ignore
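As a hedged illustration of the override mentioned in the description, the sketch below merges a set of defaults with the user-supplied kwargs. `KwargsSketch` is a hypothetical stand-in that holds only the `generation_kwargs` attribute, and `DEFAULTS`, `width`, and `height` are invented names used purely for the example.

```python
from typing import Any


class KwargsSketch:
    """Hypothetical stand-in holding only the `generation_kwargs` attribute."""

    def __init__(self, generation_kwargs: dict[str, Any]) -> None:
        self.generation_kwargs = generation_kwargs

    def get_generation_kwargs(self) -> dict[str, Any]:
        # Documented default behavior: return the attribute unchanged.
        return self.generation_kwargs


class WithDefaults(KwargsSketch):
    # Invented defaults; "width" and "height" are illustrative names only.
    DEFAULTS: dict[str, Any] = {"width": 512, "height": 512}

    def get_generation_kwargs(self) -> dict[str, Any]:
        # User-supplied values take precedence over the defaults.
        return {**self.DEFAULTS, **self.generation_kwargs}


merged = WithDefaults({"height": 768}).get_generation_kwargs()
```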
generate(inputs, num_generations=1, **kwargs) abstractmethod

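Generates images from the provided inputs.

Parameters

Name Type Description Default
inputs list[str]

The prompt texts to generate the images from.

required
num_generations int

The number of images to generate per input. Defaults to 1.

1

Returns

Type Description
list[list[dict[str, Any]]]

A list with one entry per input; each entry is a list of dictionaries containing the generated images.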

Source code in src/distilabel/models/image_generation/base.py
@abstractmethod
def generate(
    self, inputs: list[str], num_generations: int = 1, **kwargs: Any
) -> list[list[dict[str, Any]]]:
    """Generates images from the provided input.

    Args:
        inputs: the prompt text to generate the image from.
        num_generations: the number of images to generate. Defaults to `1`.

    Returns:
        A list with a dictionary with the list of images generated.
    """
    pass
generate_outputs(inputs, num_generations=1, **kwargs)

This method is defined for compatibility with the LLMs. It calls the generate method.

Source code in src/distilabel/models/image_generation/base.py
def generate_outputs(
    self,
    inputs: list[str],
    num_generations: int = 1,
    **kwargs: Any,
) -> list[list[dict[str, Any]]]:
    """This method is defined for compatibility with the `LLMs`. It calls the `generate`
    method.
    """
    return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)

AsyncImageGenerationModel

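Bases: ImageGenerationModel

Abstract class for asynchronous ImageGenerationModels, to benefit from the async capabilities of each LLM implementation. This class is meant to be subclassed by each ImageGenerationModel, and the agenerate method needs to be implemented to provide the asynchronous generation of responses.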

Attributes

Name Type Description
_event_loop AbstractEventLoop

The event loop to be used for the asynchronous generation of responses.
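The relationship between `agenerate` and the synchronous `generate` can be sketched with the standard library alone. `AsyncSketch` is a hypothetical stand-in, not the distilabel class; the `"images"` key is an assumption; and, as a simplification, the sketch creates a fresh event loop on every call instead of caching one as the class does.

```python
import asyncio
from typing import Any


class AsyncSketch:
    """Hypothetical stand-in for an AsyncImageGenerationModel subclass."""

    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        await asyncio.sleep(0)  # stands in for a network call
        # The "images" key is an assumption made for this sketch.
        return [{"images": [f"{input} #{i}"]} for i in range(num_generations)]

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # One coroutine per input, awaited concurrently; gather keeps input order.
        coros = [
            self.agenerate(input=text, num_generations=num_generations, **kwargs)
            for text in inputs
        ]
        return list(await asyncio.gather(*coros))

    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # Synchronous entry point: block until every coroutine has finished.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(
                self._agenerate(inputs, num_generations=num_generations, **kwargs)
            )
        finally:
            loop.close()


results = AsyncSketch().generate(["cat", "dog"], num_generations=2)
```

Callers only ever see the synchronous `generate`; the concurrency stays an implementation detail of the subclass.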

Source code in src/distilabel/models/image_generation/base.py
class AsyncImageGenerationModel(ImageGenerationModel):
    """Abstract class for asynchronous `ImageGenerationModels`, to benefit from the async capabilities
    of each LLM implementation. This class is meant to be subclassed by each `ImageGenerationModel`, and the
    method `agenerate` needs to be implemented to provide the asynchronous generation of
    responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    _num_generations_param_supported = True
    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> list[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
                    self._new_event_loop = True
            except RuntimeError:
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        """Generates images from the provided input.

        Args:
            input: the input text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Internal function to concurrently generate images for a list of inputs.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        if self._num_generations_param_supported:
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            return await asyncio.gather(*tasks)

        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
        return [
            list(group)
            for group in grouper(outputs, n=num_generations, incomplete="ignore")
        ]

    def generate(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """Method to generate a list of images asynchronously, returning the output
        synchronously awaiting for the image of each input sent to `agenerate`.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the images for each input.
        """
        return self.event_loop.run_until_complete(
            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
        )

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        if sys.meta_path is None:
            return

        if self._new_event_loop:
            if self._event_loop.is_running():
                self._event_loop.stop()
            self._event_loop.close()
generate_parameters property

Returns the parameters of the agenerate method.

Returns

Type Description
list[Parameter]

A list containing the parameters of the agenerate method.
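To show what such a parameter list looks like, the sketch below applies `inspect.signature` to a bound `agenerate` method, which is the same standard-library call the property relies on. `SignatureSketch` is a hypothetical class defined only for this example.

```python
import inspect
from typing import Any


class SignatureSketch:
    """Hypothetical class exposing an `agenerate` method to inspect."""

    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        return []


# Inspecting the bound method omits `self` from the result.
params = list(inspect.signature(SignatureSketch().agenerate).parameters.values())
names = [p.name for p in params]
# Keep only the parameters that declare a default value.
defaults = {
    p.name: p.default for p in params if p.default is not inspect.Parameter.empty
}
```

Each element is an `inspect.Parameter`, so its name, default, annotation, and kind are all available for introspection.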

generate_parsed_docstring cached property

Returns the parsed docstring of the agenerate method.

Returns

Type Description
Docstring

The parsed docstring of the agenerate method.

agenerate(input, num_generations=1, **kwargs) abstractmethod async

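Generates images from the provided input.

Parameters

Name Type Description Default
input str

The input text to generate the image from.

required
num_generations int

The number of images to generate. Defaults to 1.

1

Returns

Type Description
list[dict[str, Any]]

A list of dictionaries containing the generated images.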

Source code in src/distilabel/models/image_generation/base.py
@abstractmethod
async def agenerate(
    self, input: str, num_generations: int = 1, **kwargs: Any
) -> list[dict[str, Any]]:
    """Generates images from the provided input.

    Args:
        input: the input text to generate the image from.
        num_generations: the number of images to generate. Defaults to `1`.

    Returns:
        A list with a dictionary with the list of images generated.
    """
    pass
generate(inputs, num_generations=1, **kwargs)

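Method to generate a list of images asynchronously, returning the output synchronously by awaiting the images of each input sent to agenerate.

Parameters

Name Type Description Default
inputs list[str]

The list of inputs to generate images for.

required
num_generations int

The number of generations to generate per input.

1
**kwargs Any

The additional kwargs to be used for the generation.

{}

Returns

Type Description
list[list[dict[str, Any]]]

A list containing the images for each input.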

Source code in src/distilabel/models/image_generation/base.py
def generate(
    self,
    inputs: list[str],
    num_generations: int = 1,
    **kwargs: Any,
) -> list[list[dict[str, Any]]]:
    """Method to generate a list of images asynchronously, returning the output
    synchronously awaiting for the image of each input sent to `agenerate`.

    Args:
        inputs: the list of inputs to generate images for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the images for each input.
    """
    return self.event_loop.run_until_complete(
        self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
    )
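In the class source, `_agenerate` takes a second path when `_num_generations_param_supported` is false: it calls `agenerate` once per input and per generation, keeps the first element of each result, and regroups the flat list by input. The sketch below reproduces that idea under stated assumptions: `agenerate_single` and `agenerate_many` are hypothetical names, the `"images"` key is assumed, and plain list slicing replaces the `grouper` helper used in the source.

```python
import asyncio
from typing import Any


async def agenerate_single(input: str, **kwargs: Any) -> list[dict[str, Any]]:
    """Hypothetical backend call that can only return one generation."""
    await asyncio.sleep(0)
    return [{"images": [f"image for {input}"]}]  # "images" key is assumed


async def agenerate_many(
    inputs: list[str], num_generations: int
) -> list[list[dict[str, Any]]]:
    # Schedule num_generations calls per input, keeping the calls for the
    # same input adjacent so that they can be regrouped afterwards.
    coros = [agenerate_single(text) for text in inputs for _ in range(num_generations)]
    flat = [result[0] for result in await asyncio.gather(*coros)]
    # Regroup the flat list into chunks of num_generations, one per input.
    return [
        flat[start : start + num_generations]
        for start in range(0, len(flat), num_generations)
    ]


grouped = asyncio.run(agenerate_many(["cat", "dog"], num_generations=3))
```

Because `asyncio.gather` returns results in the order the awaitables were passed, the regrouping step can rely on position alone.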
__del__()

Closes the event loop when the object is deleted.

Source code in src/distilabel/models/image_generation/base.py
def __del__(self) -> None:
    """Closes the event loop when the object is deleted."""
    if sys.meta_path is None:
        return

    if self._new_event_loop:
        if self._event_loop.is_running():
            self._event_loop.stop()
        self._event_loop.close()
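The loop handling performed by the event_loop property and `__del__` follows a reuse-or-create pattern: an already running loop is reused, a new loop is created otherwise, and only a loop the object created itself is closed. A minimal sketch of that pattern follows. `LoopOwner` and its attribute names are hypothetical, and it uses an explicit `close()` method rather than `__del__` so that the cleanup point is predictable in the example.

```python
import asyncio
from typing import Optional


class LoopOwner:
    """Hypothetical holder that lazily obtains an event loop and tracks ownership."""

    def __init__(self) -> None:
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._owns_loop = False

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        if self._loop is None:
            try:
                # Reuse the loop of the current thread if one is running.
                self._loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: create one and remember that we own it.
                self._loop = asyncio.new_event_loop()
                self._owns_loop = True
        return self._loop

    def close(self) -> None:
        # Only close a loop that this object created itself.
        if self._owns_loop and self._loop is not None and not self._loop.is_closed():
            self._loop.close()


async def answer() -> int:
    return 42


owner = LoopOwner()
value = owner.loop.run_until_complete(answer())
owner.close()
```

Tracking ownership avoids closing a loop that belongs to the surrounding application, which is the purpose the `_new_event_loop` flag appears to serve in the class source.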