LLM

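This section contains the API reference for the distilabel LLMs, covering both the synchronous LLM implementation and the asynchronous AsyncLLM implementation.

For more information and examples on how to use existing LLMs or create custom ones, please refer to Tutorial - LLM.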

base

LLM

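Bases: RuntimeParametersModelMixin, BaseModel, _Serializable, ABC

Base class for LLMs to be used in the distilabel framework.

To implement an LLM subclass, you need to subclass this class and implement:

- the load method, to load the LLM if needed. Don't forget to call super().load() so that the _logger attribute is initialized.
- the model_name property, to return the model name used for the LLM.
- the generate method, to generate num_generations generations for each input in inputs.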
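The three requirements above can be illustrated with a minimal sketch. To keep it runnable without distilabel installed, `SketchLLM` is a hypothetical stand-in that mirrors only the contract described here (it is not the real `LLM` base class), and `EchoLLM` together with the shape of its outputs are assumptions made for illustration.

```python
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class SketchLLM(ABC):
    """Hypothetical stand-in that mirrors the contract of the `LLM` base class."""

    def load(self) -> None:
        # The base `load` initializes the logger, which is why subclasses
        # are expected to call `super().load()`.
        self._logger = logging.getLogger(f"sketch.llm.{self.model_name}")

    @property
    @abstractmethod
    def model_name(self) -> str:
        ...

    @abstractmethod
    def generate(
        self, inputs: List[List[Dict[str, str]]], num_generations: int = 1, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        ...


class EchoLLM(SketchLLM):
    """Toy subclass (assumption): echoes the last message of each conversation."""

    def load(self) -> None:
        super().load()  # initializes `_logger`
        self._prefix = "echo: "

    @property
    def model_name(self) -> str:
        return "echo-model"

    def generate(self, inputs, num_generations=1, **kwargs):
        outputs = []
        for conversation in inputs:
            last_message = conversation[-1]["content"]
            # One entry per input, holding `num_generations` generations.
            outputs.append(
                {"generations": [self._prefix + last_message] * num_generations}
            )
        return outputs


llm = EchoLLM()
llm.load()
outputs = llm.generate([[{"role": "user", "content": "Hello"}]], num_generations=2)
```

Here `outputs` holds one entry per input conversation, each with `num_generations` generations, which is the behaviour the `generate` requirement describes.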

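Attributes

Name Type Description
generation_kwargs Optional[RuntimeParameter[Dict[str, Any]]]

The kwargs to be propagated to either the generate or agenerate methods within each LLM.

use_offline_batch_generation Optional[RuntimeParameter[bool]]

Whether to use the offline_batch_generate method to generate the responses.

offline_batch_generation_block_until_done Optional[RuntimeParameter[int]]

If provided, polling is performed until the offline_batch_generate method is able to retrieve the results. The value indicates the time to wait, in seconds, between polls.

jobs_ids Union[Tuple[str, ...], None]

The job IDs generated by the offline_batch_generate method. They are stored so that they can later be used to retrieve the results. This attribute is not meant to be set by the user.

_logger Logger

The logger to be used for the LLM. It is initialized when the load method is called.

Source code in src/distilabel/models/llms/base.py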
class LLM(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC):
    """Base class for `LLM`s to be used in `distilabel` framework.

    To implement an `LLM` subclass, you need to subclass this class and implement:
        - `load` method to load the `LLM` if needed. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the LLM.
        - `generate` method to generate `num_generations` per input in `inputs`.

    Attributes:
        generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
            methods within each `LLM`.
        use_offline_batch_generation: whether to use the `offline_batch_generate` method to
            generate the responses.
        offline_batch_generation_block_until_done: if provided, then polling will be done until
            the `offline_batch_generate` method is able to retrieve the results. The value indicates
            the time to wait between each polling.
        jobs_ids: the job ids generated by the `offline_batch_generate` method. This attribute
            is used to store the job ids generated by the `offline_batch_generate` method
            so later they can be used to retrieve the results. It is not meant to be set by
            the user.
        _logger: the logger to be used for the `LLM`. It will be initialized when the `load`
            method is called.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )

    generation_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="The kwargs to be propagated to either `generate` or `agenerate`"
        " methods within each `LLM`.",
    )
    use_offline_batch_generation: Optional[RuntimeParameter[bool]] = Field(
        default=False,
        description="Whether to use the `offline_batch_generate` method to generate"
        " the responses.",
    )
    offline_batch_generation_block_until_done: Optional[RuntimeParameter[int]] = Field(
        default=None,
        description="If provided, then polling will be done until the `offline_batch_generate`"
        " method is able to retrieve the results. The value indicates the time to wait between"
        " each polling.",
    )

    jobs_ids: Union[Tuple[str, ...], None] = Field(default=None)
    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Method to be called to initialize the `LLM`, its logger and optionally the
        structured output generator."""
        self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")

    def unload(self) -> None:
        """Method to be called to unload the `LLM` and release any resources."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        pass

    def get_generation_kwargs(self) -> Dict[str, Any]:
        """Returns the generation kwargs to be used for the generation. This method can
        be overridden to provide a more complex logic for the generation kwargs.

        Returns:
            The kwargs to be used for the generation.
        """
        return self.generation_kwargs  # type: ignore

    @abstractmethod
    def generate(
        self,
        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Abstract method to be implemented by each LLM to generate `num_generations`
        per input in `inputs`.

        Args:
            inputs: the list of inputs to generate responses for which follows OpenAI's
                API format:

                ```python
                [
                    {"role": "system", "content": "You're a helpful assistant..."},
                    {"role": "user", "content": "Give a template email for B2B communications..."},
                    {"role": "assistant", "content": "Sure, here's a template you can use..."},
                    {"role": "user", "content": "Modify the second paragraph..."}
                ]
                ```
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.
        """
        pass

    def generate_outputs(
        self,
        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Generates outputs for the given inputs using either `generate` method or the
        `offline_batch_generate` method if `use_offline_batch_generation` is set.
        """
        if self.use_offline_batch_generation:
            if self.offline_batch_generation_block_until_done is not None:
                return self._offline_batch_generate_polling(
                    inputs=inputs,
                    num_generations=num_generations,
                    **kwargs,
                )

            # This will raise `DistilabelOfflineBatchGenerationNotFinishedException` right away
            # if the batch generation is not finished.
            return self.offline_batch_generate(
                inputs=inputs,
                num_generations=num_generations,
                **kwargs,
            )

        return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)

    def _offline_batch_generate_polling(
        self,
        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to poll the `offline_batch_generate` method until the batch generation
        is finished.

        Args:
            inputs: the list of inputs to generate responses for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        while True:
            try:
                return self.offline_batch_generate(
                    inputs=inputs,
                    num_generations=num_generations,
                    **kwargs,
                )
            except DistilabelOfflineBatchGenerationNotFinishedException as e:
                self._logger.info(
                    f"Waiting for the offline batch generation to finish: {e}. Sleeping"
                    f" for {self.offline_batch_generation_block_until_done} seconds before"
                    " trying to get the results again."
                )
                # When running a `Step` in a child process, SIGINT is overridden so the child
                # process doesn't stop when the parent process receives a SIGINT signal.
                # The new handler sets an environment variable that is checked here to stop
                # the polling.
                if os.getenv(SIGINT_HANDLER_CALLED_ENV_NAME) is not None:
                    self._logger.info(
                        "Received a KeyboardInterrupt. Stopping polling for checking if the"
                        " offline batch generation is finished..."
                    )
                    raise e
                time.sleep(self.offline_batch_generation_block_until_done)  # type: ignore
            except KeyboardInterrupt as e:
                # This is for the case the `LLM` is being executed outside a pipeline
                self._logger.info(
                    "Received a KeyboardInterrupt. Stopping polling for checking if the"
                    " offline batch generation is finished..."
                )
                raise DistilabelOfflineBatchGenerationNotFinishedException(
                    jobs_ids=self.jobs_ids  # type: ignore
                ) from e

    def get_last_hidden_states(
        self, inputs: List["StandardInput"]
    ) -> List["HiddenState"]:
        """Method to get the last hidden states of the model for a list of inputs.

        Args:
            inputs: the list of inputs to get the last hidden states from.

        Returns:
            A list containing the last hidden state for each sequence using a NumPy array
                with shape [num_tokens, hidden_size].
        """
        # TODO: update to use `DistilabelNotImplementedError`
        raise NotImplementedError(
            f"Method `get_last_hidden_states` is not implemented for `{self.__class__.__name__}`"
        )

    def _prepare_structured_output(
        self, structured_output: "StructuredOutputType"
    ) -> Union[Any, None]:
        """Method in charge of preparing the structured output generator.

        By default will raise a `NotImplementedError`, subclasses that allow it must override this
        method with the implementation.

        Args:
            structured_output: the config to prepare the guided generation.

        Returns:
            The structure to be used for the guided generation.
        """
        # TODO: update to use `DistilabelNotImplementedError`
        raise NotImplementedError(
            f"Guided generation is not implemented for `{type(self).__name__}`"
        )

    def offline_batch_generate(
        self,
        inputs: Union[List["FormattedInput"], None] = None,
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of outputs for the given inputs using an offline batch
        generation method to be implemented by each `LLM`.

        This method should create jobs the first time it is called and store the job ids, so
        the second and subsequent calls can retrieve the results of the batch generation.
        If subsequent calls are made before the batch generation is finished, then the method
        should raise a `DistilabelOfflineBatchGenerationNotFinishedException`. This exception
        will be handled automatically by the `Pipeline` which will store all the required
        information for recovering the pipeline execution when the batch generation is finished.

        Args:
            inputs: the list of inputs to generate responses for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        raise DistilabelNotImplementedError(
            f"`offline_batch_generate` is not implemented for `{self.__class__.__name__}`",
            page="sections/how_to_guides/advanced/offline-batch-generation/",
        )
model_name abstractmethod property

Returns the model name used for the LLM.

load()

Method to be called to initialize the LLM, its logger and, optionally, the structured output generator.

Source code in src/distilabel/models/llms/base.py
def load(self) -> None:
    """Method to be called to initialize the `LLM`, its logger and optionally the
    structured output generator."""
    self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")
unload()

Method to be called to unload the LLM and release any resources.

Source code in src/distilabel/models/llms/base.py
def unload(self) -> None:
    """Method to be called to unload the `LLM` and release any resources."""
    pass
get_generation_kwargs()

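Returns the generation kwargs to be used for the generation. This method can be overridden to provide more complex logic for the generation kwargs.

Returns

Type Description
Dict[str, Any]

The kwargs to be used for the generation.

Source code in src/distilabel/models/llms/base.py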
def get_generation_kwargs(self) -> Dict[str, Any]:
    """Returns the generation kwargs to be used for the generation. This method can
    be overridden to provide a more complex logic for the generation kwargs.

    Returns:
        The kwargs to be used for the generation.
    """
    return self.generation_kwargs  # type: ignore
generate(inputs, num_generations=1, **kwargs) abstractmethod

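Abstract method to be implemented by each LLM to generate num_generations generations for each input in inputs.

Parameters

Name Type Description Default
inputs List[FormattedInput]

The list of inputs to generate responses for, which follows OpenAI's API format: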

[
    {"role": "system", "content": "You're a helpful assistant..."},
    {"role": "user", "content": "Give a template email for B2B communications..."},
    {"role": "assistant", "content": "Sure, here's a template you can use..."},
    {"role": "user", "content": "Modify the second paragraph..."}
]
required
num_generations int

The number of generations to generate per input.

1
**kwargs Any

The additional kwargs to be used for the generation.

{}
Source code in src/distilabel/models/llms/base.py
@abstractmethod
def generate(
    self,
    inputs: List["FormattedInput"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Abstract method to be implemented by each LLM to generate `num_generations`
    per input in `inputs`.

    Args:
        inputs: the list of inputs to generate responses for which follows OpenAI's
            API format:

            ```python
            [
                {"role": "system", "content": "You're a helpful assistant..."},
                {"role": "user", "content": "Give a template email for B2B communications..."},
                {"role": "assistant", "content": "Sure, here's a template you can use..."},
                {"role": "user", "content": "Modify the second paragraph..."}
            ]
            ```
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.
    """
    pass
generate_outputs(inputs, num_generations=1, **kwargs)

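Generates outputs for the given inputs using either the generate method or the offline_batch_generate method if use_offline_batch_generation is set.

Source code in src/distilabel/models/llms/base.py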
def generate_outputs(
    self,
    inputs: List["FormattedInput"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Generates outputs for the given inputs using either `generate` method or the
    `offline_batch_generate` method if `use_offline_batch_generation` is set.
    """
    if self.use_offline_batch_generation:
        if self.offline_batch_generation_block_until_done is not None:
            return self._offline_batch_generate_polling(
                inputs=inputs,
                num_generations=num_generations,
                **kwargs,
            )

        # This will raise `DistilabelOfflineBatchGenerationNotFinishedException` right away
        # if the batch generation is not finished.
        return self.offline_batch_generate(
            inputs=inputs,
            num_generations=num_generations,
            **kwargs,
        )

    return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
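
The dispatch above has three outcomes: a plain `generate` call, a single `offline_batch_generate` attempt, or a polling loop. The following is a minimal stdlib-only sketch of that control flow. `NotFinishedError`, `generate_outputs_sketch` and `fake_batch` are hypothetical names standing in for the distilabel exception and methods, and the sketch deliberately leaves out the SIGINT and `KeyboardInterrupt` handling of the real polling method.

```python
import time
from typing import Callable, List, Optional


class NotFinishedError(Exception):
    """Hypothetical stand-in for `DistilabelOfflineBatchGenerationNotFinishedException`."""


def generate_outputs_sketch(
    generate: Callable[[], List[str]],
    offline_batch_generate: Callable[[], List[str]],
    use_offline_batch_generation: bool = False,
    block_until_done: Optional[float] = None,
) -> List[str]:
    if not use_offline_batch_generation:
        return generate()
    if block_until_done is None:
        # Single attempt: a "not finished" error propagates to the caller.
        return offline_batch_generate()
    while True:
        try:
            return offline_batch_generate()
        except NotFinishedError:
            # Wait `block_until_done` seconds before asking again.
            time.sleep(block_until_done)


attempts = {"count": 0}


def fake_batch() -> List[str]:
    """Pretends the batch is ready on the third call only."""
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise NotFinishedError("batch still running")
    return ["batch result"]


polled = generate_outputs_sketch(
    lambda: ["sync result"],
    fake_batch,
    use_offline_batch_generation=True,
    block_until_done=0.01,
)
direct = generate_outputs_sketch(lambda: ["sync result"], fake_batch)
```

With a polling interval set, the caller only sees the final results; without it, the "not finished" error surfaces immediately and the caller decides when to retry.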
get_last_hidden_states(inputs)

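Method to get the last hidden states of the model for a list of inputs.

Parameters

Name Type Description Default
inputs List[StandardInput]

The list of inputs to get the last hidden states from.

required

Returns

Type Description
List[HiddenState]

A list containing the last hidden state for each sequence, as a NumPy array with shape [num_tokens, hidden_size].

Source code in src/distilabel/models/llms/base.py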
def get_last_hidden_states(
    self, inputs: List["StandardInput"]
) -> List["HiddenState"]:
    """Method to get the last hidden states of the model for a list of inputs.

    Args:
        inputs: the list of inputs to get the last hidden states from.

    Returns:
        A list containing the last hidden state for each sequence using a NumPy array
            with shape [num_tokens, hidden_size].
    """
    # TODO: update to use `DistilabelNotImplementedError`
    raise NotImplementedError(
        f"Method `get_last_hidden_states` is not implemented for `{self.__class__.__name__}`"
    )
offline_batch_generate(inputs=None, num_generations=1, **kwargs)

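Method to generate a list of outputs for the given inputs using an offline batch generation method to be implemented by each LLM.

This method should create jobs the first time it is called and store the job IDs, so that the second and subsequent calls can retrieve the results of the batch generation. If subsequent calls are made before the batch generation is finished, the method should raise a DistilabelOfflineBatchGenerationNotFinishedException. This exception is handled automatically by the Pipeline, which stores all the information required to recover the pipeline execution once the batch generation is finished.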
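
The contract described above (submit once, remember the job IDs, then retrieve or signal "not finished") can be sketched without any external service. Everything in this block is hypothetical: `FakeBatchBackend` is an in-memory stand-in for a batch API, `NotFinishedError` stands in for the distilabel exception, and raising right after submission on the first call is an assumption of this sketch rather than something the text above prescribes.

```python
from typing import Dict, List, Optional, Tuple


class NotFinishedError(Exception):
    """Hypothetical stand-in for the distilabel "not finished" exception."""

    def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
        super().__init__(f"jobs not finished: {jobs_ids}")
        self.jobs_ids = jobs_ids


class FakeBatchBackend:
    """In-memory batch service; a job is ready the second time it is checked."""

    def __init__(self) -> None:
        self._results: Dict[str, List[str]] = {}
        self._checks: Dict[str, int] = {}

    def submit(self, prompts: List[str]) -> str:
        job_id = f"job-{len(self._results)}"
        self._results[job_id] = [prompt.upper() for prompt in prompts]
        self._checks[job_id] = 0
        return job_id

    def fetch(self, job_id: str) -> Optional[List[str]]:
        self._checks[job_id] += 1
        return self._results[job_id] if self._checks[job_id] >= 2 else None


class OfflineBatchSketch:
    def __init__(self, backend: FakeBatchBackend) -> None:
        self.backend = backend
        self.jobs_ids: Optional[Tuple[str, ...]] = None

    def offline_batch_generate(self, inputs: Optional[List[str]] = None) -> List[str]:
        if self.jobs_ids is None:
            # First call: create the job and store its ID for later calls.
            self.jobs_ids = (self.backend.submit(inputs or []),)
            raise NotFinishedError(self.jobs_ids)
        results: List[str] = []
        for job_id in self.jobs_ids:
            job_results = self.backend.fetch(job_id)
            if job_results is None:
                raise NotFinishedError(self.jobs_ids)
            results.extend(job_results)
        self.jobs_ids = None  # results retrieved, nothing left to track
        return results


llm = OfflineBatchSketch(FakeBatchBackend())
events = []
for _ in range(3):
    try:
        events.append(llm.offline_batch_generate(["hi"]))
    except NotFinishedError as error:
        events.append(error.jobs_ids)
```

The stored `jobs_ids` are what make the second and third calls cheap: they only query the backend and never resubmit the inputs.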

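Parameters

Name Type Description Default
inputs Union[List[FormattedInput], None]

The list of inputs to generate responses for.

None
num_generations int

The number of generations to generate per input.

1
**kwargs Any

The additional kwargs to be used for the generation.

{}

Returns

Type Description
List[GenerateOutput]

A list containing the generations for each input.

Source code in src/distilabel/models/llms/base.py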
def offline_batch_generate(
    self,
    inputs: Union[List["FormattedInput"], None] = None,
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of outputs for the given inputs using an offline batch
    generation method to be implemented by each `LLM`.

    This method should create jobs the first time it is called and store the job ids, so
    the second and subsequent calls can retrieve the results of the batch generation.
    If subsequent calls are made before the batch generation is finished, then the method
    should raise a `DistilabelOfflineBatchGenerationNotFinishedException`. This exception
    will be handled automatically by the `Pipeline` which will store all the required
    information for recovering the pipeline execution when the batch generation is finished.

    Args:
        inputs: the list of inputs to generate responses for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the generations for each input.
    """
    raise DistilabelNotImplementedError(
        f"`offline_batch_generate` is not implemented for `{self.__class__.__name__}`",
        page="sections/how_to_guides/advanced/offline-batch-generation/",
    )

AsyncLLM

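Bases: LLM

Abstract class for asynchronous LLMs, so as to benefit from the async capabilities of each LLM implementation. This class is meant to be subclassed by each LLM, and the agenerate method needs to be implemented to provide the asynchronous generation of responses.

Attributes

Name Type Description
_event_loop AbstractEventLoop

The event loop to be used for the asynchronous generation of responses.

Source code in src/distilabel/models/llms/base.py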
class AsyncLLM(LLM):
    """Abstract class for asynchronous LLMs, so as to benefit from the async capabilities
    of each LLM implementation. This class is meant to be subclassed by each LLM, and the
    method `agenerate` needs to be implemented to provide the asynchronous generation of
    responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    _num_generations_param_supported = True
    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> List[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
                    self._new_event_loop = True
            except RuntimeError:
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
    ) -> "GenerateOutput":
        """Method to generate `num_generations` responses for a given input asynchronously.
        Calls to this method are executed concurrently in the `generate` method.
        """
        pass

    async def _agenerate(
        self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
    ) -> List["GenerateOutput"]:
        """Internal function to concurrently generate responses for a list of inputs.

        Args:
            inputs: the list of inputs to generate responses for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        if self._num_generations_param_supported:
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            result = await asyncio.gather(*tasks)
            return result

        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        outputs = await asyncio.gather(*tasks)
        return merge_responses(outputs, n=num_generations)

    def generate(
        self,
        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously awaiting for the response of each input sent to `agenerate`.

        Args:
            inputs: the list of inputs to generate responses for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        return self.event_loop.run_until_complete(
            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
        )

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        if sys.meta_path is None:
            return

        if self._new_event_loop:
            if self._event_loop.is_running():
                self._event_loop.stop()
            self._event_loop.close()

    @staticmethod
    def _prepare_structured_output(  # type: ignore
        structured_output: "InstructorStructuredOutputType",
        client: Any = None,
        framework: Optional[str] = None,
    ) -> Dict[str, Union[str, Any]]:
        """Wraps the client and updates the schema to store it internally as a JSON schema.

        Args:
            structured_output: The configuration dict to prepare the structured output.
            client: The client to wrap to generate structured output. Implemented to work
                with `instructor`.
            framework: The name of the framework.

        Returns:
            A dictionary containing the wrapped client and the schema to update the structured_output
            variable in case it is a pydantic model.
        """
        from distilabel.steps.tasks.structured_outputs.instructor import (
            prepare_instructor,
        )

        result = {}
        client = prepare_instructor(
            client,
            mode=structured_output.get("mode"),
            framework=framework,  # type: ignore
        )
        result["client"] = client

        schema = structured_output.get("schema")
        if not schema:
            raise DistilabelUserError(
                f"The `structured_output` argument must contain a schema: {structured_output}",
                page="sections/how_to_guides/advanced/structured_generation/#instructor",
            )
        if inspect.isclass(schema) and issubclass(schema, BaseModel):
            # We want a json schema for the serialization, but instructor wants a pydantic BaseModel.
            structured_output["schema"] = schema.model_json_schema()  # type: ignore
            result["structured_output"] = structured_output

        return result

    @staticmethod
    def _prepare_kwargs(
        arguments: Dict[str, Any], structured_output: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Helper method to update the kwargs with the structured output configuration,
        used in case they are defined.

        Args:
            arguments: The arguments that would be passed to the LLM as **kwargs,
                to be updated with the structured output configuration.
            structured_output: The structured output configuration to update the arguments.

        Returns:
            kwargs updated with the special arguments used by `instructor`.
        """
        # We can deal with json schema or BaseModel, but we need to convert it to a BaseModel
        # for the Instructor client.
        schema = structured_output.get("schema", {})

        # If there's already a pydantic model, we don't need to do anything,
        # otherwise, try to obtain one.
        if not (inspect.isclass(schema) and issubclass(schema, BaseModel)):
            from distilabel.steps.tasks.structured_outputs.utils import (
                json_schema_to_model,
            )

            if isinstance(schema, str):
                # In case it was saved in the dataset as a string.
                schema = json.loads(schema)

            try:
                schema = json_schema_to_model(schema)
            except Exception as e:
                raise ValueError(
                    f"Failed to convert the schema to a pydantic model, the model is too complex currently: {e}"
                ) from e

        arguments.update(
            **{
                "response_model": schema,
                "max_retries": structured_output.get("max_retries", 1),
            },
        )
        return arguments
generate_parameters property

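Returns the parameters of the agenerate method.

Returns

Type Description
List[Parameter]

A list containing the parameters of the agenerate method.

generate_parsed_docstring cached property

Returns the parsed docstring of the agenerate method.

Returns

Type Description
Docstring

The parsed docstring of the agenerate method.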

agenerate(input, num_generations=1, **kwargs) abstractmethod async

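Method to generate num_generations responses for a given input asynchronously. Calls to this method are executed concurrently in the generate method.

Source code in src/distilabel/models/llms/base.py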
@abstractmethod
async def agenerate(
    self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
) -> "GenerateOutput":
    """Method to generate `num_generations` responses for a given input asynchronously.
    Calls to this method are executed concurrently in the `generate` method.
    """
    pass
generate(inputs, num_generations=1, **kwargs)

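Method to generate a list of responses asynchronously, returning the output synchronously after awaiting the response for each input sent to agenerate.

Parameters

Name Type Description Default
inputs List[FormattedInput]

The list of inputs to generate responses for.

required
num_generations int

The number of generations to generate per input.

1
**kwargs Any

The additional kwargs to be used for the generation.

{}

Returns

Type Description
List[GenerateOutput]

A list containing the generations for each input.

Source code in src/distilabel/models/llms/base.py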
def generate(
    self,
    inputs: List["FormattedInput"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of responses asynchronously, returning the output
    synchronously awaiting for the response of each input sent to `agenerate`.

    Args:
        inputs: the list of inputs to generate responses for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the generations for each input.
    """
    return self.event_loop.run_until_complete(
        self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
    )
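
The synchronous wrapper above relies on a common asyncio pattern: create one task per request, gather them, and block on the event loop until everything completes. A minimal stdlib-only sketch of that pattern follows. `agenerate_sketch`, `fan_out` and `generate_sketch` are hypothetical names, the coroutine only simulates a model call, and the sketch opens and closes a fresh loop per call instead of caching one as `AsyncLLM` does.

```python
import asyncio
from typing import Dict, List


async def agenerate_sketch(prompt: str) -> Dict[str, List[str]]:
    """Simulated asynchronous model call returning a single generation."""
    await asyncio.sleep(0.01)
    return {"generations": [f"reply to {prompt}"]}


async def fan_out(prompts: List[str], num_generations: int) -> List[Dict[str, List[str]]]:
    # One task per (input, generation) pair, mirroring the branch used when
    # the backend cannot produce several generations in a single call.
    tasks = [
        asyncio.create_task(agenerate_sketch(prompt))
        for prompt in prompts
        for _ in range(num_generations)
    ]
    # `gather` returns results in the order the tasks were passed in.
    return await asyncio.gather(*tasks)


def generate_sketch(prompts: List[str], num_generations: int = 1) -> List[Dict[str, List[str]]]:
    loop = asyncio.new_event_loop()
    try:
        # Block the synchronous caller until every task has finished.
        return loop.run_until_complete(fan_out(prompts, num_generations))
    finally:
        loop.close()


flat = generate_sketch(["a", "b"], num_generations=2)
```

The result is flat, with `num_generations` consecutive entries per input, which is why a grouping step is needed afterwards to obtain one output per input.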
__del__()

Closes the event loop when the object is deleted.

Source code in src/distilabel/models/llms/base.py
def __del__(self) -> None:
    """Closes the event loop when the object is deleted."""
    if sys.meta_path is None:
        return

    if self._new_event_loop:
        if self._event_loop.is_running():
            self._event_loop.stop()
        self._event_loop.close()

merge_responses(responses, n=1)

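Helper function to group the responses from the LLM.agenerate method according to the number of generations requested.

Parameters

Name Type Description Default
responses List[GenerateOutput]

The responses from the LLM.agenerate method.

required
n int

Number of responses to group together. Defaults to 1.

1

Returns

Type Description
List[GenerateOutput]

List of merged responses, where each merged response contains n generations and their corresponding statistics.

Source code in src/distilabel/models/llms/base.py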
def merge_responses(
    responses: List["GenerateOutput"], n: int = 1
) -> List["GenerateOutput"]:
    """Helper function to group the responses from `LLM.agenerate` method according
    to the number of generations requested.

    Args:
        responses: the responses from the `LLM.agenerate` method.
        n: number of responses to group together. Defaults to 1.

    Returns:
        List of merged responses, where each merged response contains n generations
        and their corresponding statistics.
    """
    if not responses:
        return []

    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield list(islice(lst, i, i + n))

    extra_keys = [
        key for key in responses[0].keys() if key not in ("generations", "statistics")
    ]

    result = []
    for group in chunks(responses, n):
        merged = {
            "generations": [],
            "statistics": {"input_tokens": [], "output_tokens": []},
        }
        for response in group:
            merged["generations"].append(response["generations"][0])
            # Merge statistics
            for key in response["statistics"]:
                if key not in merged["statistics"]:
                    merged["statistics"][key] = []
                merged["statistics"][key].append(response["statistics"][key][0])
            # Merge extra keys returned by the `LLM`
            for extra_key in extra_keys:
                if extra_key not in merged:
                    merged[extra_key] = []
                merged[extra_key].append(response[extra_key][0])
        result.append(merged)
    return result
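
A small worked example may make the grouping concrete. The block below uses `merge_sketch`, a condensed re-statement of the grouping logic written for illustration (it is not the library function and it omits the handling of extra keys), applied to four single-generation responses for two inputs. The token counts are made-up values.

```python
from typing import Any, Dict, List


def merge_sketch(responses: List[Dict[str, Any]], n: int = 1) -> List[Dict[str, Any]]:
    """Groups every `n` consecutive single-generation responses into one."""
    merged_all = []
    for start in range(0, len(responses), n):
        merged: Dict[str, Any] = {
            "generations": [],
            "statistics": {"input_tokens": [], "output_tokens": []},
        }
        for response in responses[start : start + n]:
            # Each response carries exactly one generation and one value per statistic.
            merged["generations"].append(response["generations"][0])
            for key, values in response["statistics"].items():
                merged["statistics"].setdefault(key, []).append(values[0])
        merged_all.append(merged)
    return merged_all


responses = [
    {"generations": ["A1"], "statistics": {"input_tokens": [5], "output_tokens": [7]}},
    {"generations": ["A2"], "statistics": {"input_tokens": [5], "output_tokens": [9]}},
    {"generations": ["B1"], "statistics": {"input_tokens": [3], "output_tokens": [4]}},
    {"generations": ["B2"], "statistics": {"input_tokens": [3], "output_tokens": [6]}},
]
merged = merge_sketch(responses, n=2)
```

With `n=2`, the four flat responses collapse into two merged outputs, one per input, each holding two generations and the matching per-generation statistics.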