
ImageTask

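This section contains the API reference for the distilabel image generation tasks.

For more information on how ImageTask works, along with some examples, see the Tutorial - Task - ImageTask page.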

ImageTask

Bases: _Task, Step

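ImageTask is a class that implements the _Task abstract class and adds the Step interface so that it can be used as a step in a pipeline. It differs from Task in that it is expected to work with an ImageGenerationModel instead of an LLM.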
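The relationship between the two model attributes can be sketched without distilabel installed. The classes below (`FakeImageModel`, `SketchImageTask`) are hypothetical stand-ins, not distilabel classes. They only mirror the check-and-alias step that ImageTask performs after initialization, and the sketch raises a `ValueError` where the library source uses an `assert`.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeImageModel:
    """Hypothetical stand-in for an image generation model."""

    name: str


@dataclass
class SketchImageTask:
    """Hypothetical stand-in mirroring how the task aliases `llm`."""

    image_generation_model: FakeImageModel
    llm: Optional[Any] = None

    def __post_init__(self) -> None:
        # The user must not pass `llm`; it is reserved for internal use.
        if self.llm is not None:
            raise ValueError("pass `image_generation_model` instead of `llm`")
        # Internally, `llm` points at the image generation model so that code
        # written against the shared task interface keeps working.
        self.llm = self.image_generation_model


task = SketchImageTask(image_generation_model=FakeImageModel(name="demo"))
```

After construction, `task.llm` and `task.image_generation_model` refer to the same object, which is why the `llm` attribute is described as internal only.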

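Attributes

Name | Type | Description
image_generation_model | ImageGenerationModel | The ImageGenerationModel used to generate the outputs.
llm | Union[LLM, ImageGenerationModel, None] | Present only to respect the _Task interface; it is used internally only.
group_generations | bool | Whether to group the num_generations generated per input in a list, or to create one row per generation. Defaults to False.
num_generations | RuntimeParameter[int] | The number of generations to produce per input.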
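The effect of group_generations can be illustrated with a small, self-contained sketch. `layout_generations` is a hypothetical helper and the `generation` column name is an assumption. Neither is part of distilabel; the sketch only shows the two row layouts that the attribute description distinguishes.

```python
from typing import Any


def layout_generations(
    row: dict[str, Any], generations: list[str], group_generations: bool
) -> list[dict[str, Any]]:
    """Hypothetical helper: lay out the generations produced for one input row."""
    if group_generations:
        # One output row whose "generation" field holds every generation.
        return [{**row, "generation": list(generations)}]
    # One output row per generation, each repeating the input columns.
    return [{**row, "generation": g} for g in generations]


row = {"prompt": "a red bicycle"}
generations = ["image-0", "image-1"]  # placeholders, as if num_generations=2

grouped = layout_generations(row, generations, group_generations=True)
ungrouped = layout_generations(row, generations, group_generations=False)
```

Here `grouped` holds a single row with a two-element list, while `ungrouped` holds two rows, one per generation.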

Source code in src/distilabel/steps/tasks/base.py
class ImageTask(_Task, Step):
    """`ImageTask` is a class that implements the `_Task` abstract class and adds the `Step`
    interface to be used as a step in the pipeline. It differs from the `Task` in that it's
    expected to work with `ImageGenerationModel`s instead of `LLM`s.

    Attributes:
        image_generation_model: the `ImageGenerationModel` to be used to generate the outputs.
        llm: This attribute is here to respect the `_Task` interface, but it's used internally only.
        group_generations: whether to group the `num_generations` generated per input in
            a list or create a row per generation. Defaults to `False`.
        num_generations: The number of generations to be produced per input.
    """

    llm: Union[LLM, ImageGenerationModel, None] = None
    image_generation_model: ImageGenerationModel

    def model_post_init(self, __context: Any) -> None:
        assert self.llm is None, (
            "`ImageTask` cannot use an `LLM` attribute given by the user, pass "
            "the `image_generation_model` attribute instead."
        )
        self.llm = self.image_generation_model
        # Call the post init from the Step, as we don't want to call specific behaviour
        # from the task, that may need to deal with specific attributes from the LLM
        # not in the ImageGenerationModel
        super(Step, self).model_post_init(__context)

    @abstractmethod
    def format_input(self, input: dict[str, any]) -> str:
        """Abstract method to format the inputs of the task. It needs to receive an input
        as a Python dictionary, and generates a string to be used as the prompt for the model."""
        pass

    def _format_inputs(self, inputs: list[dict[str, any]]) -> List["FormattedInput"]:
        """Formats the inputs of the task using the `format_input` method.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Returns:
            A list containing the formatted inputs, which are `ChatType`-like following
            the OpenAI formatting.
        """
        return [self.format_input(input) for input in inputs]

    def _format_outputs(
        self,
        outputs: list[Union[str, None]],
        input: Union[Dict[str, Any], None] = None,
    ) -> List[Dict[str, Any]]:
        """Formats the outputs of the task using the `format_output` method. If the output
        is `None` (i.e. the LLM failed to generate a response), then the outputs will be
        set to `None` as well.

        Args:
            outputs: The outputs (`n` generations) for the provided `input`.
            input: The input used to generate the output.

        Returns:
            A list containing a dictionary with the outputs of the task for each input.
        """
        inputs = [None] if input is None else [input]
        formatted_outputs = []

        for output, input in zip(outputs, inputs):  # type: ignore
            try:
                formatted_output = self.format_output(output, input)
                formatted_output = self._create_metadata(
                    formatted_output,
                    output,
                    input,
                    add_raw_output=self.add_raw_output,  # type: ignore
                    add_raw_input=self.add_raw_input,  # type: ignore
                    statistics=None,
                )
                formatted_outputs.append(formatted_output)
            except Exception as e:
                self._logger.warning(  # type: ignore
                    f"Task '{self.name}' failed to format output: {e}. Saving raw response."  # type: ignore
                )
                formatted_outputs.append(self._output_on_failure(output, input))
        return formatted_outputs

    @abstractmethod
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Yields:
            A list of Python dictionaries with the outputs of the task.
        """
        pass
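
The try/except in `_format_outputs` above follows a common pattern: attempt to format an output, and on any failure log a warning and keep the raw value. A minimal sketch of that pattern follows. `format_with_fallback`, `strict_formatter`, and the `image` / `raw_output` keys are hypothetical names chosen for illustration, not distilabel's actual helpers or output schema.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger("sketch")


def format_with_fallback(
    outputs: list[Optional[str]],
    format_output: Callable[[Optional[str]], dict[str, Any]],
) -> list[dict[str, Any]]:
    """Hypothetical helper: format each output, keeping the raw value on failure."""
    formatted = []
    for output in outputs:
        try:
            formatted.append(format_output(output))
        except Exception as exc:
            # Mirror the listing's behaviour: warn, then store a fallback row.
            logger.warning("failed to format output: %s. Saving raw response.", exc)
            formatted.append({"image": None, "raw_output": output})
    return formatted


def strict_formatter(output: Optional[str]) -> dict[str, Any]:
    """Hypothetical formatter that rejects missing outputs."""
    if output is None:
        raise ValueError("no output was generated")
    return {"image": output}


results = format_with_fallback(["img-data", None], strict_formatter)
```

The second element fails to format, so its row carries the fallback structure instead of aborting the whole batch.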

format_input(input) abstractmethod

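Abstract method to format the inputs of the task. It receives an input as a Python dictionary and returns a string to be used as the prompt for the model.

Source code in src/distilabel/steps/tasks/base.py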
@abstractmethod
def format_input(self, input: dict[str, any]) -> str:
    """Abstract method to format the inputs of the task. It needs to receive an input
    as a Python dictionary, and generates a string to be used as the prompt for the model."""
    pass
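
As a rough illustration of what a concrete `format_input` might do, the standalone function below turns an input dictionary into a prompt string. The `prompt` and `style` column names are assumptions made for this sketch. A real subclass would define the method on the task and use whatever input columns it declares.

```python
from typing import Any


def format_input(input: dict[str, Any]) -> str:
    """Hypothetical implementation: build the prompt from a `prompt` column
    and an optional `style` column (both column names are assumptions)."""
    prompt = str(input["prompt"]).strip()
    style = input.get("style")
    # Append the style hint only when one is provided.
    return f"{prompt}, in {style} style" if style else prompt


plain = format_input({"prompt": "  a lighthouse at dusk "})
styled = format_input({"prompt": "a lighthouse at dusk", "style": "watercolor"})
```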

process(inputs) abstractmethod

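Processes the inputs of the task and generates the outputs using the ImageGenerationModel.

Parameters

Name | Type | Description | Default
inputs | StepInput | A list of Python dictionaries with the inputs of the task. | required

Yields

Type | Description
StepOutput | A list of Python dictionaries with the outputs of the task.

Source code in src/distilabel/steps/tasks/base.py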
@abstractmethod
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Yields:
        A list of Python dictionaries with the outputs of the task.
    """
    pass
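
Because `process` is documented as yielding its results, a concrete implementation is typically a generator. The sketch below shows that shape only, under stated assumptions: `fake_generate` stands in for a call to an image generation model, and the `prompt` and `image` keys are hypothetical column names rather than distilabel's schema.

```python
from typing import Any, Callable, Iterator


def process(
    inputs: list[dict[str, Any]],
    generate: Callable[[str], str],
) -> Iterator[list[dict[str, Any]]]:
    """Hypothetical generator-style `process`: yields one batch of output rows."""
    outputs = []
    for row in inputs:
        # Generate one image per input row and attach it to a copy of the row.
        image = generate(row["prompt"])
        outputs.append({**row, "image": image})
    yield outputs


def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for an image generation call."""
    return f"<image for: {prompt}>"


batches = list(process([{"prompt": "a cat"}, {"prompt": "a dog"}], fake_generate))
```

Consuming the generator produces a single batch containing one output dictionary per input dictionary.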