
LLM Gallery

本节包含 distilabel 中实现的现有 LLM 子类。

llms

AnthropicLLM

基类: AsyncLLM

Anthropic LLM 实现,运行 Async API 客户端。

属性

名称 类型 描述
model str

用于 LLM 的模型名称,例如 "claude-3-opus-20240229"、"claude-3-sonnet-20240229" 等。可用模型可在此处查看:Anthropic: 模型概览

api_key 可选[RuntimeParameter[SecretStr]]

用于验证对 Anthropic API 请求的 API 密钥。如果未提供,将从 ANTHROPIC_API_KEY 环境变量中读取。

base_url 可选[RuntimeParameter[str]]

用于 Anthropic API 的基本 URL。默认为 None,表示内部将使用 https://api.anthropic.com

timeout RuntimeParameter[float]

等待响应的最长秒数。默认为 600.0

max_retries RuntimeParameter[int]

在失败之前重试请求的最大次数。默认为 6

http_client 可选[AsyncClient]

如果提供,则使用备用 HTTP 客户端来调用 Anthropic API。默认为 None

structured_output 可选[RuntimeParameter[InstructorStructuredOutputType]]

一个字典,包含使用 instructor 的结构化输出配置。您可以在 distilabel.steps.tasks.structured_outputs.instructor 中的 InstructorStructuredOutputType 查看字典结构。

_api_key_env_var str

用于 API 密钥的环境变量名称。它旨在内部使用。

_aclient 可选[AsyncAnthropic]

用于 Anthropic API 的 AsyncAnthropic 客户端。它旨在内部使用。在 load 方法中设置。

运行时参数
  • api_key:用于验证对 Anthropic API 请求的 API 密钥。如果未提供,将从 ANTHROPIC_API_KEY 环境变量中读取。
  • base_url:用于 Anthropic API 的基本 URL。默认为 "https://api.anthropic.com"
  • timeout:等待响应的最长秒数。默认为 600.0
  • max_retries:在失败之前重试请求的最大次数。默认为 6

示例

生成文本

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import AnthropicLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = AnthropicLLM(
    model="claude-3-opus-20240229",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
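
作为补充,下面是一个简要示意(非官方示例,api.key 仅为占位符),展示如何在构造时显式设置上文列出的运行时参数:

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(
    model="claude-3-opus-20240229",
    api_key="api.key",                     # 也可改由 ANTHROPIC_API_KEY 环境变量提供
    base_url="https://api.anthropic.com",  # 即默认值,此处仅作演示
    timeout=600.0,
    max_retries=6,
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])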
源代码位于 src/distilabel/models/llms/anthropic.py
class AnthropicLLM(AsyncLLM):
    """Anthropic LLM implementation running the Async API client.

    Attributes:
        model: the name of the model to use for the LLM e.g. "claude-3-opus-20240229",
            "claude-3-sonnet-20240229", etc. Available models can be checked here:
            [Anthropic: Models overview](https://docs.anthropic.com/claude/docs/models-overview).
        api_key: the API key to authenticate the requests to the Anthropic API. If not provided,
            it will be read from `ANTHROPIC_API_KEY` environment variable.
        base_url: the base URL to use for the Anthropic API. Defaults to `None` which means
            that `https://api.anthropic.com` will be used internally.
        timeout: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        max_retries: The maximum number of times to retry the request before failing. Defaults
            to `6`.
        http_client: if provided, an alternative HTTP client to use for calling Anthropic
            API. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.
        _aclient: the `AsyncAnthropic` client to use for the Anthropic API. It is meant
            to be used internally. Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Anthropic API. If not
            provided, it will be read from `ANTHROPIC_API_KEY` environment variable.
        - `base_url`: the base URL to use for the Anthropic API. Defaults to `"https://api.anthropic.com"`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        - `max_retries`: the maximum number of times to retry the request before failing.
            Defaults to `6`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AnthropicLLM

        llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import AnthropicLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = AnthropicLLM(
            model="claude-3-opus-20240229",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
        ),
        description="The base URL to use for the Anthropic API.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANTHROPIC_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anthropic API.",
    )
    timeout: RuntimeParameter[float] = Field(
        default=600.0,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    http_client: Optional[AsyncClient] = Field(default=None, exclude=True)
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)

    def _check_model_exists(self) -> None:
        """Checks if the specified model exists in the available models."""
        from anthropic import AsyncAnthropic

        annotation = get_type_hints(AsyncAnthropic().messages.create).get("model", None)
        models = [
            value
            for type_ in get_args(annotation)
            if get_origin(type_) is Literal
            for value in get_args(type_)
        ]

        if self.model not in models:
            raise ValueError(
                f"Model {self.model} does not exist among available models. "
                f"The available models are {', '.join(models)}"
            )

    def load(self) -> None:
        """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
        super().load()

        try:
            from anthropic import AsyncAnthropic
        except ImportError as ie:
            raise ImportError(
                "Anthropic Python client is not installed. Please install it using"
                " `pip install 'distilabel[anthropic]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._check_model_exists()

        self._aclient = AsyncAnthropic(
            api_key=self.api_key.get_secret_value(),
            base_url=self.base_url,
            timeout=self.timeout,
            http_client=self.http_client,
            max_retries=self.max_retries,
        )
        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="anthropic",
            )
            self._aclient = result.get("client")
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_tokens: int = 128,
        stop_sequences: Union[List[str], None] = None,
        temperature: float = 1.0,
        top_p: Union[float, None] = None,
        top_k: Union[int, None] = None,
    ) -> GenerateOutput:
        """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

        Args:
            input: a single input in chat format to generate responses for.
            max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
            stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `NOT_GIVEN`.
            temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `NOT_GIVEN`.
            top_k: the top-k value to use for the generation. Defaults to `NOT_GIVEN`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from anthropic._types import NOT_GIVEN

        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="anthropic",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "system": (
                input.pop(0)["content"]
                if input and input[0]["role"] == "system"
                else NOT_GIVEN
            ),
            "max_tokens": max_tokens,
            "stream": False,
            "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences,
            "temperature": temperature,
            "top_p": NOT_GIVEN if top_p is None else top_p,
            "top_k": NOT_GIVEN if top_k is None else top_k,
        }

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
            **kwargs
        )  # type: ignore
        if structured_output:
            # raw_response = completion._raw_response
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        if (content := completion.content[0].text) is None:
            self._logger.warning(
                f"Received no response using Anthropic client (model: '{self.model}')."
                f" Finish reason was: {completion.stop_reason}"
            )
        return prepare_output([content], **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "Message") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.input_tokens],
            "output_tokens": [completion.usage.output_tokens],
        }
model_name property

返回用于 LLM 的模型名称。

_check_model_exists()

检查指定的模型是否存在于可用模型中。
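
例如(示意,非官方示例),当传入不在可用列表中的模型名时,该检查会在 load() 过程中抛出 ValueError:

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(model="claude-unknown-model", api_key="api.key")  # 假设性的无效模型名

try:
    llm.load()  # load() 内部会调用 _check_model_exists()
except ValueError as e:
    print(e)  # 例如:Model claude-unknown-model does not exist among available models. ...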

源代码位于 src/distilabel/models/llms/anthropic.py
def _check_model_exists(self) -> None:
    """Checks if the specified model exists in the available models."""
    from anthropic import AsyncAnthropic

    annotation = get_type_hints(AsyncAnthropic().messages.create).get("model", None)
    models = [
        value
        for type_ in get_args(annotation)
        if get_origin(type_) is Literal
        for value in get_args(type_)
    ]

    if self.model not in models:
        raise ValueError(
            f"Model {self.model} does not exist among available models. "
            f"The available models are {', '.join(models)}"
        )
load()

加载 AsyncAnthropic 客户端以使用 Anthropic 异步 API。

源代码位于 src/distilabel/models/llms/anthropic.py
def load(self) -> None:
    """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
    super().load()

    try:
        from anthropic import AsyncAnthropic
    except ImportError as ie:
        raise ImportError(
            "Anthropic Python client is not installed. Please install it using"
            " `pip install 'distilabel[anthropic]'`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._check_model_exists()

    self._aclient = AsyncAnthropic(
        api_key=self.api_key.get_secret_value(),
        base_url=self.base_url,
        timeout=self.timeout,
        http_client=self.http_client,
        max_retries=self.max_retries,
    )
    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="anthropic",
        )
        self._aclient = result.get("client")
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
agenerate(input, max_tokens=128, stop_sequences=None, temperature=1.0, top_p=None, top_k=None) async

使用 Anthropic Async API 定义异步生成响应。

参数

名称 类型 描述 默认值
input FormattedInput

以聊天格式的单个输入,用于生成响应。

必需
max_tokens int

模型将生成的最大新 token 数。默认为 128

128
stop_sequences Union[List[str], None]

将导致模型停止生成的自定义文本序列。默认为 NOT_GIVEN

None
temperature float

用于生成的温度。仅当 top_p 为 None 时设置。默认为 1.0

1.0
top_p Union[float, None]

用于生成的 top-p 值。默认为 NOT_GIVEN

None
top_k Union[int, None]

用于生成的 top-k 值。默认为 NOT_GIVEN

None

返回

类型 描述
GenerateOutput

包含每个输入的生成响应的字符串列表的列表。
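
下面是一个直接调用 agenerate 的简要示意(非官方示例,假设在同步脚本中通过 asyncio 驱动):

import asyncio

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")
llm.load()

# agenerate 为异步方法,针对单个聊天格式输入返回生成结果
output = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        max_tokens=128,
        temperature=1.0,
    )
)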

源代码位于 src/distilabel/models/llms/anthropic.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_tokens: int = 128,
    stop_sequences: Union[List[str], None] = None,
    temperature: float = 1.0,
    top_p: Union[float, None] = None,
    top_k: Union[int, None] = None,
) -> GenerateOutput:
    """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

    Args:
        input: a single input in chat format to generate responses for.
        max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
        stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `NOT_GIVEN`.
        temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `NOT_GIVEN`.
        top_k: the top-k value to use for the generation. Defaults to `NOT_GIVEN`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from anthropic._types import NOT_GIVEN

    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="anthropic",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "system": (
            input.pop(0)["content"]
            if input and input[0]["role"] == "system"
            else NOT_GIVEN
        ),
        "max_tokens": max_tokens,
        "stream": False,
        "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences,
        "temperature": temperature,
        "top_p": NOT_GIVEN if top_p is None else top_p,
        "top_k": NOT_GIVEN if top_k is None else top_k,
    }

    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
        **kwargs
    )  # type: ignore
    if structured_output:
        # raw_response = completion._raw_response
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    if (content := completion.content[0].text) is None:
        self._logger.warning(
            f"Received no response using Anthropic client (model: '{self.model}')."
            f" Finish reason was: {completion.stop_reason}"
        )
    return prepare_output([content], **self._get_llm_statistics(completion))

AnyscaleLLM

基类: OpenAILLM

Anyscale LLM 实现,运行 OpenAI 的异步 API 客户端。

属性

名称 类型 描述
model str

用于 LLM 的模型名称,例如 google/gemma-7b-it。支持的模型请参阅此处"文本生成 -> 支持的模型"部分。

base_url 可选[RuntimeParameter[str]]

用于 Anyscale API 请求的基本 URL。默认为 None,表示将使用为环境变量 ANYSCALE_BASE_URL 设置的值,如果未设置,则使用 "https://api.endpoints.anyscale.com/v1"。

api_key 可选[RuntimeParameter[SecretStr]]

用于验证对 Anyscale API 请求的 API 密钥。默认为 None,表示将使用为环境变量 ANYSCALE_API_KEY 设置的值,如果未设置,则为 None

_api_key_env_var str

用于 API 密钥的环境变量名称。它旨在内部使用。

示例

生成文本

from distilabel.models.llms import AnyscaleLLM

llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
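
也可以不在构造函数中传入凭据,而是通过上文属性说明中提到的环境变量进行配置。下面是一个简要示意(非官方示例,取值均为占位符):

import os

from distilabel.models.llms import AnyscaleLLM

# 与上文 base_url / api_key 的默认读取行为对应
os.environ["ANYSCALE_API_KEY"] = "api.key"
os.environ["ANYSCALE_BASE_URL"] = "https://api.endpoints.anyscale.com/v1"

llm = AnyscaleLLM(model="google/gemma-7b-it")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])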
源代码位于 src/distilabel/models/llms/anyscale.py
class AnyscaleLLM(OpenAILLM):
    """Anyscale LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM, e.g., `google/gemma-7b-it`. See the
            supported models under the "Text Generation -> Supported Models" section
            [here](https://docs.endpoints.anyscale.com/).
        base_url: the base URL to use for the Anyscale API requests. Defaults to `None`, which
            means that the value set for the environment variable `ANYSCALE_BASE_URL` will be used, or
            "https://api.endpoints.anyscale.com/v1" if not set.
        api_key: the API key to authenticate the requests to the Anyscale API. Defaults to `None` which
            means that the value set for the environment variable `ANYSCALE_API_KEY` will be used, or
            `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key.
            It is meant to be used internally.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AnyscaleLLM

        llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANYSCALE_BASE_URL", "https://api.endpoints.anyscale.com/v1"
        ),
        description="The base URL to use for the Anyscale API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANYSCALE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anyscale API.",
    )

    _api_key_env_var: str = PrivateAttr(_ANYSCALE_API_KEY_ENV_VAR_NAME)

AzureOpenAILLM

基类: OpenAILLM

Azure OpenAI LLM 实现,运行异步 API 客户端。

属性

名称 类型 描述
model str

用于 LLM 的模型名称,即 Azure 部署的名称。

base_url 可选[RuntimeParameter[str]]

用于 Azure OpenAI API 的基本 URL,可通过环境变量 AZURE_OPENAI_ENDPOINT 设置。默认为 None,表示将使用为环境变量 AZURE_OPENAI_ENDPOINT 设置的值;如果未设置,则为 None。

api_key 可选[RuntimeParameter[SecretStr]]

用于验证对 Azure OpenAI API 请求的 API 密钥。默认为 None,表示将使用为环境变量 AZURE_OPENAI_API_KEY 设置的值,如果未设置,则为 None

api_version 可选[RuntimeParameter[str]]

用于 Azure OpenAI API 的 API 版本。默认为 None,表示将使用为环境变量 OPENAI_API_VERSION 设置的值,如果未设置,则为 None


示例

生成文本

from distilabel.models.llms import AzureOpenAILLM

llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

从遵循 OpenAI API 的自定义端点生成文本

from distilabel.models.llms import AzureOpenAILLM

llm = AzureOpenAILLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    base_url=r"http://localhost:8080/v1"
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import AzureOpenAILLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = AzureOpenAILLM(
    model="gpt-4-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
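
也可以通过上文提到的环境变量来提供端点、密钥与 API 版本。下面是一个简要示意(非官方示例,取值均为占位符):

import os

from distilabel.models.llms import AzureOpenAILLM

os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"  # 占位符
os.environ["AZURE_OPENAI_API_KEY"] = "api.key"                                # 占位符
os.environ["OPENAI_API_VERSION"] = "2024-02-01"                               # 占位符

llm = AzureOpenAILLM(model="gpt-4-turbo")  # model 即 Azure 部署的名称

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])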
源代码位于 src/distilabel/models/llms/azure.py
class AzureOpenAILLM(OpenAILLM):
    """Azure OpenAI LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM i.e. the name of the Azure deployment.
        base_url: the base URL to use for the Azure OpenAI API can be set with `AZURE_OPENAI_ENDPOINT`.
            Defaults to `None` which means that the value set for the environment variable
            `AZURE_OPENAI_ENDPOINT` will be used, or `None` if not set.
        api_key: the API key to authenticate the requests to the Azure OpenAI API. Defaults to `None`
            which means that the value set for the environment variable `AZURE_OPENAI_API_KEY` will be
            used, or `None` if not set.
        api_version: the API version to use for the Azure OpenAI API. Defaults to `None` which means
            that the value set for the environment variable `OPENAI_API_VERSION` will be used, or
            `None` if not set.

    Icon:
        `:material-microsoft-azure:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AzureOpenAILLM

        llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate text from a custom endpoint following the OpenAI API:

        ```python
        from distilabel.models.llms import AzureOpenAILLM

        llm = AzureOpenAILLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            base_url=r"http://localhost:8080/v1"
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import AzureOpenAILLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = AzureOpenAILLM(
            model="gpt-4-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME),
        description="The base URL to use for the Azure OpenAI API requests i.e. the Azure OpenAI endpoint.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Azure OpenAI API.",
    )

    api_version: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv("OPENAI_API_VERSION"),
        description="The API version to use for the Azure OpenAI API.",
    )

    _base_url_env_var: str = PrivateAttr(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME)
    _api_key_env_var: str = PrivateAttr(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAzureOpenAI"] = PrivateAttr(...)  # type: ignore

    @override
    def load(self) -> None:
        """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
        # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output
        # in the load method before we have the proper client.
        with patch(
            "distilabel.models.openai.OpenAILLM._prepare_structured_output", lambda x: x
        ):
            super().load()

        try:
            from openai import AsyncAzureOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[openai]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        # TODO: May be worth adding the AD auth too? Also the `organization`?
        self._aclient = AsyncAzureOpenAI(  # type: ignore
            azure_endpoint=self.base_url,  # type: ignore
            azure_deployment=self.model,
            api_version=self.api_version,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            self._prepare_structured_output(self.structured_output)
load()

加载 AsyncAzureOpenAI 客户端以从异步请求中受益。

源代码位于 src/distilabel/models/llms/azure.py
@override
def load(self) -> None:
    """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
    # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output
    # in the load method before we have the proper client.
    with patch(
        "distilabel.models.openai.OpenAILLM._prepare_structured_output", lambda x: x
    ):
        super().load()

    try:
        from openai import AsyncAzureOpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install 'distilabel[openai]'`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    # TODO: May be worth adding the AD auth too? Also the `organization`?
    self._aclient = AsyncAzureOpenAI(  # type: ignore
        azure_endpoint=self.base_url,  # type: ignore
        azure_deployment=self.model,
        api_version=self.api_version,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    if self.structured_output:
        self._prepare_structured_output(self.structured_output)

CohereLLM

基类: AsyncLLM

Cohere API 实现,使用异步客户端进行并发文本生成。

属性

名称 类型 描述
model str

来自 Cohere API 的模型名称,用于生成。

base_url 可选[RuntimeParameter[str]]

用于 Cohere API 请求的基本 URL。默认为 "https://api.cohere.ai/v1"

api_key 可选[RuntimeParameter[SecretStr]]

用于验证对 Cohere API 请求的 API 密钥。默认为 COHERE_API_KEY 环境变量的值。

timeout RuntimeParameter[int]

等待 API 响应的最长秒数。默认为 120。

client_name RuntimeParameter[str]

用于 API 请求的客户端名称。默认为 "distilabel"

structured_output 可选[RuntimeParameter[InstructorStructuredOutputType]]

一个字典,包含使用 instructor 的结构化输出配置。您可以在 distilabel.steps.tasks.structured_outputs.instructor 中的 InstructorStructuredOutputType 查看字典结构。

_ChatMessage Type[ChatMessage]

来自 cohere 包的 ChatMessage 类。

_aclient AsyncClient

来自 cohere 包的 AsyncClient 客户端。

运行时参数
  • base_url:用于 Cohere API 请求的基本 URL。默认为 "https://api.cohere.ai/v1"
  • api_key:用于验证对 Cohere API 请求的 API 密钥。默认为 COHERE_API_KEY 环境变量的值。
  • timeout:等待 API 响应的最长秒数。默认为 120。
  • client_name:用于 API 请求的客户端名称。默认为 "distilabel"

示例

生成文本

from distilabel.models.llms import CohereLLM

llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import CohereLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = CohereLLM(
    model="CohereForAI/c4ai-command-r-plus",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
源代码位于 src/distilabel/models/llms/cohere.py
class CohereLLM(AsyncLLM):
    """Cohere API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Cohere API to use for the generation.
        base_url: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        api_key: the API key to authenticate the requests to the Cohere API. Defaults to
            the value of the `COHERE_API_KEY` environment variable.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        client_name: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _ChatMessage: the `ChatMessage` class from the `cohere` package.
        _aclient: the `AsyncClient` client from the `cohere` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        - `api_key`: the API key to authenticate the requests to the Cohere API. Defaults
            to the value of the `COHERE_API_KEY` environment variable.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `client_name`: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import CohereLLM

        llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import CohereLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = CohereLLM(
            model="CohereForAI/c4ai-command-r-plus",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "COHERE_BASE_URL", "https://api.cohere.ai/v1"
        ),
        description="The base URL to use for the Cohere API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_COHERE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Cohere API.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    client_name: RuntimeParameter[str] = Field(
        default="distilabel",
        description="The name of the client to use for the API requests.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
    _aclient: "AsyncClient" = PrivateAttr(...)
    _tokenizer: "Tokenizer" = PrivateAttr(...)

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def load(self) -> None:
        """Loads the `AsyncClient` client from the `cohere` package."""

        super().load()

        try:
            from cohere import AsyncClient, ChatMessage
        except ImportError as ie:
            raise ImportError(
                "The `cohere` package is required to use the `CohereLLM` class."
            ) from ie

        self._ChatMessage = ChatMessage

        self._aclient = AsyncClient(
            api_key=self.api_key.get_secret_value(),  # type: ignore
            client_name=self.client_name,
            base_url=self.base_url,
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

        from cohere.manually_maintained.tokenizers import get_hf_tokenizer

        self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)

    def _format_chat_to_cohere(
        self, input: "FormattedInput"
    ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
        """Formats the chat input to the Cohere Chat API conversational format.

        Args:
            input: The chat input to format.

        Returns:
            A tuple containing the system, chat history, and message.
        """
        system = None
        message = None
        chat_history = []
        for item in input:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system = content
            elif role == "user":
                message = content
            elif role == "assistant":
                if message is None:
                    raise ValueError(
                        "An assistant message but be preceded by a user message."
                    )
                chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
                chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
                message = None

        if message is None:
            raise ValueError("The chat input must end with a user message.")

        return system, chat_history, message

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        seed: Optional[float] = None,
        stop_sequences: Optional[Sequence[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        raw_prompting: Optional[bool] = None,
    ) -> GenerateOutput:
        """Generates a response from the LLM given an input.

        Args:
            input: a single input in chat format to generate responses for.
            temperature: the temperature to use for the generation. Defaults to `None`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            k: the number of highest probability vocabulary tokens to keep for the generation.
                Defaults to `None`.
            p: the nucleus sampling probability to use for the generation. Defaults to
                `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: a list of sequences to use as stopping criteria for the generation.
                Defaults to `None`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `None`.
            raw_prompting: a flag to use raw prompting for the generation. Defaults to
                `None`.

        Returns:
            The generated response from the Cohere API model.
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,  # type: ignore
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        system, chat_history, message = self._format_chat_to_cohere(input)

        kwargs = {
            "message": message,
            "model": self.model,
            "preamble": system,
            "chat_history": chat_history,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "k": k,
            "p": p,
            "seed": seed,
            "stop_sequences": stop_sequences,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "raw_prompting": raw_prompting,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

        response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs)  # type: ignore

        if structured_output:
            return prepare_output(
                [response.model_dump_json()],
                **self._get_llm_statistics(
                    input, orjson.dumps(response.model_dump_json()).decode("utf-8")
                ),  # type: ignore
            )

        if (text := response.text) == "":
            self._logger.warning(  # type: ignore
                f"Received no response using Cohere client (model: '{self.model}')."
                f" Finish reason was: {response.finish_reason}"
            )
            return prepare_output(
                [None],
                **self._get_llm_statistics(input, ""),
            )

        return prepare_output(
            [text],
            **self._get_llm_statistics(input, text),
        )

    def _get_llm_statistics(
        self, input: FormattedInput, output: str
    ) -> "LLMStatistics":
        return {
            "input_tokens": [compute_tokens(input, self._tokenizer.encode)],
            "output_tokens": [compute_tokens(output, self._tokenizer.encode)],
        }
model_name property

返回用于 LLM 的模型名称。

load()

加载来自 cohere 包的 AsyncClient 客户端。

源代码位于 src/distilabel/models/llms/cohere.py
def load(self) -> None:
    """Loads the `AsyncClient` client from the `cohere` package."""

    super().load()

    try:
        from cohere import AsyncClient, ChatMessage
    except ImportError as ie:
        raise ImportError(
            "The `cohere` package is required to use the `CohereLLM` class."
        ) from ie

    self._ChatMessage = ChatMessage

    self._aclient = AsyncClient(
        api_key=self.api_key.get_secret_value(),  # type: ignore
        client_name=self.client_name,
        base_url=self.base_url,
        timeout=self.timeout,
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output

    from cohere.manually_maintained.tokenizers import get_hf_tokenizer

    self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)
_format_chat_to_cohere(input)

将聊天输入格式化为 Cohere Chat API 会话格式。

参数

名称 类型 描述 默认值
input FormattedInput

要格式化的聊天输入。

必需

返回

类型 描述
Tuple[Union[str, None], List[ChatMessage], str]

包含系统、聊天记录和消息的元组。

源代码位于 src/distilabel/models/llms/cohere.py
def _format_chat_to_cohere(
    self, input: "FormattedInput"
) -> Tuple[Union[str, None], List["ChatMessage"], str]:
    """Formats the chat input to the Cohere Chat API conversational format.

    Args:
        input: The chat input to format.

    Returns:
        A tuple containing the system, chat history, and message.
    """
    system = None
    message = None
    chat_history = []
    for item in input:
        role = item["role"]
        content = item["content"]
        if role == "system":
            system = content
        elif role == "user":
            message = content
        elif role == "assistant":
            if message is None:
                raise ValueError(
                    "An assistant message but be preceded by a user message."
                )
            chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
            chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
            message = None

    if message is None:
        raise ValueError("The chat input must end with a user message.")

    return system, chat_history, message
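
按照上述逻辑,一段多轮对话会被拆分为 system、chat_history 与最后一条用户 message。下面是一个简要示意(非官方示例,假设 llm 是已调用过 load() 的 CohereLLM 实例):

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

system, chat_history, message = llm._format_chat_to_cohere(chat)
# system == "You are a helpful assistant."
# chat_history 为两条 ChatMessage:USER("Hi!") 与 CHATBOT("Hello! How can I help?")
# message == "Tell me a joke."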
agenerate(input, temperature=None, max_tokens=None, k=None, p=None, seed=None, stop_sequences=None, frequency_penalty=None, presence_penalty=None, raw_prompting=None) async

从 LLM 生成给定输入的响应。

参数

名称 类型 描述 默认值
input FormattedInput

以聊天格式的单个输入,用于生成响应。

必需
temperature Optional[float]

用于生成的温度。默认为 None

None
max_tokens Optional[int]

模型将生成的最大新 token 数。默认为 None

None
k Optional[int]

为生成保留的最高概率词汇 token 数。默认为 None

None
p Optional[float]

用于生成的 nucleus 采样概率。默认为 None

None
seed Optional[float]

用于生成的种子。默认为 None

None
stop_sequences Optional[Sequence[str]]

用作生成停止标准的序列列表。默认为 None

None
frequency_penalty Optional[float]

用于生成的频率惩罚。默认为 None

None
presence_penalty Optional[float]

用于生成的存在惩罚。默认为 None。

None
raw_prompting Optional[bool]

用于原始提示生成的标志。默认为 None

None

返回

类型 描述
GenerateOutput

来自 Cohere API 模型的生成响应。

源代码位于 src/distilabel/models/llms/cohere.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    k: Optional[int] = None,
    p: Optional[float] = None,
    seed: Optional[float] = None,
    stop_sequences: Optional[Sequence[str]] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    raw_prompting: Optional[bool] = None,
) -> GenerateOutput:
    """Generates a response from the LLM given an input.

    Args:
        input: a single input in chat format to generate responses for.
        temperature: the temperature to use for the generation. Defaults to `None`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        k: the number of highest probability vocabulary tokens to keep for the generation.
            Defaults to `None`.
        p: the nucleus sampling probability to use for the generation. Defaults to
            `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: a list of sequences to use as stopping criteria for the generation.
            Defaults to `None`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `None`.
        raw_prompting: a flag to use raw prompting for the generation. Defaults to
            `None`.

    Returns:
        The generated response from the Cohere API model.
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,  # type: ignore
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    system, chat_history, message = self._format_chat_to_cohere(input)

    kwargs = {
        "message": message,
        "model": self.model,
        "preamble": system,
        "chat_history": chat_history,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "k": k,
        "p": p,
        "seed": seed,
        "stop_sequences": stop_sequences,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "raw_prompting": raw_prompting,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

    response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs)  # type: ignore

    if structured_output:
        return prepare_output(
            [response.model_dump_json()],
            **self._get_llm_statistics(
                input, orjson.dumps(response.model_dump_json()).decode("utf-8")
            ),  # type: ignore
        )

    if (text := response.text) == "":
        self._logger.warning(  # type: ignore
            f"Received no response using Cohere client (model: '{self.model}')."
            f" Finish reason was: {response.finish_reason}"
        )
        return prepare_output(
            [None],
            **self._get_llm_statistics(input, ""),
        )

    return prepare_output(
        [text],
        **self._get_llm_statistics(input, text),
    )

GroqLLM

基类: AsyncLLM

Groq API 实现,使用异步客户端进行并发文本生成。

属性

名称 类型 描述
model str

来自 Groq API 的模型名称,用于生成。

base_url Optional[RuntimeParameter[str]]

用于 Groq API 请求的基本 URL。默认为 "https://api.groq.com"

api_key Optional[RuntimeParameter[SecretStr]]

用于验证对 Groq API 请求的 API 密钥。默认为 GROQ_API_KEY 环境变量的值。

max_retries RuntimeParameter[int]

在失败之前重试 API 请求的最大次数。默认为 2

timeout RuntimeParameter[int]

等待 API 响应的最长秒数。默认为 120。

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

一个字典,包含使用 instructor 的结构化输出配置。您可以在 distilabel.steps.tasks.structured_outputs.instructor 中的 InstructorStructuredOutputType 查看字典结构。

_api_key_env_var str

用于 API 密钥的环境变量名称。

_aclient Optional[AsyncGroq]

来自 groq 包的 AsyncGroq 客户端。

运行时参数
  • base_url:用于 Groq API 请求的基本 URL。默认为 "https://api.groq.com"
  • api_key:用于验证对 Groq API 请求的 API 密钥。默认为 GROQ_API_KEY 环境变量的值。
  • max_retries:在失败之前重试 API 请求的最大次数。默认为 2
  • timeout:等待 API 响应的最长秒数。默认为 120。

示例

生成文本

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import GroqLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = GroqLLM(
    model="llama3-70b-8192",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
源代码位于 src/distilabel/models/llms/groq.py
class GroqLLM(AsyncLLM):
    """Groq API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Groq API to use for the generation.
        base_url: the base URL to use for the Groq API requests. Defaults to
            `"https://api.groq.com"`.
        api_key: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the `GROQ_API_KEY` environment variable.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key.
        _aclient: the `AsyncGroq` client from the `groq` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Groq API requests. Defaults to
            `"https://api.groq.com"`.
        - `api_key`: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the `GROQ_API_KEY` environment variable.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import GroqLLM

        llm = GroqLLM(model="llama3-70b-8192")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import GroqLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = GroqLLM(
            model="llama3-70b-8192",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            _GROQ_API_BASE_URL_ENV_VAR_NAME, "https://api.groq.com"
        ),
        description="The base URL to use for the Groq API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_GROQ_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Groq API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=2,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(_GROQ_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncGroq"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncGroq` client to benefit from async requests."""
        super().load()

        try:
            from groq import AsyncGroq
        except ImportError as ie:
            raise ImportError(
                "Groq Python client is not installed. Please install it using"
                ' `pip install "distilabel[groq]"`.'
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncGroq(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="groq",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        seed: Optional[int] = None,
        max_new_tokens: int = 128,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[str] = None,
    ) -> "GenerateOutput":
        """Generates `num_generations` responses for the given input using the Groq async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            seed: the seed to use for the generation. Defaults to `None`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: the stop sequence to use for the generation. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.

        References:
            - https://console.groq.com/docs/text-chat
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="groq",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "seed": seed,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
            "top_p": top_p,
            "stream": False,
            "stop": stop,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
        if structured_output:
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        generations = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using the Groq client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return prepare_output(generations, **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "ChatCompletion") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.prompt_tokens if completion else 0],
            "output_tokens": [completion.usage.completion_tokens if completion else 0],
        }
model_name property

返回用于 LLM 的模型名称。

load()

加载 AsyncGroq 客户端以从异步请求中受益。

源代码位于 src/distilabel/models/llms/groq.py
def load(self) -> None:
    """Loads the `AsyncGroq` client to benefit from async requests."""
    super().load()

    try:
        from groq import AsyncGroq
    except ImportError as ie:
        raise ImportError(
            "Groq Python client is not installed. Please install it using"
            ' `pip install "distilabel[groq]"`.'
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = AsyncGroq(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="groq",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
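
下面是一个简要示意(非官方文档内容),演示安装 Groq 可选依赖并在调用 load() 前提供 API 密钥;其中环境变量名 GROQ_API_KEY 为假设值,模型名称仅作占位:

# 先安装可选依赖:pip install "distilabel[groq]"
import os

# 假设:若未显式传入 api_key,将从 GROQ_API_KEY 环境变量读取(名称以实际版本为准)
os.environ.setdefault("GROQ_API_KEY", "gsk_xxx")

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192")
llm.load()  # 此时才会创建 AsyncGroq 客户端
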
agenerate(input, seed=None, max_new_tokens=128, temperature=1.0, top_p=1.0, stop=None) async

使用 Groq 异步客户端为给定输入生成 num_generations 个响应。

参数

名称 类型 描述 默认值
input FormattedInput

以聊天格式的单个输入,用于生成响应。

必需
seed Optional[int]

用于生成的种子。默认为 None

None
max_new_tokens int

模型将生成的最大新 token 数。默认为 128

128
temperature float

用于生成的温度。默认为 1.0。

1.0
top_p float

用于生成的 top-p 值。默认为 1.0

1.0
stop Optional[str]

用于生成的停止序列。默认为 None

None

返回

类型 描述
GenerateOutput

包含每个输入的生成响应的字符串列表的列表。

参考
  • https://console.groq.com/docs/text-chat
源代码位于 src/distilabel/models/llms/groq.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    seed: Optional[int] = None,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[str] = None,
) -> "GenerateOutput":
    """Generates `num_generations` responses for the given input using the Groq async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        seed: the seed to use for the generation. Defaults to `None`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: the stop sequence to use for the generation. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.

    References:
        - https://console.groq.com/docs/text-chat
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="groq",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "seed": seed,
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "top_p": top_p,
        "stream": False,
        "stop": stop,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
    if structured_output:
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    generations = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using the Groq client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return prepare_output(generations, **self._get_llm_statistics(completion))
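
下面是一个简要草图(非官方示例),展示通过 generate_outputs 透传上述 agenerate 生成参数;模型名称与参数取值仅作占位:

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192", api_key="api.key")

llm.load()

# generate_outputs 在内部调用 agenerate,并透传生成参数
output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    temperature=0.7,
    max_new_tokens=256,
)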

InferenceEndpointsLLM

基类: InferenceEndpointsBaseClientAsyncLLMMagpieChatTemplateMixin

InferenceEndpoints LLM 实现,运行异步 API 客户端。

此 LLM 将在内部使用 huggingface_hub.AsyncInferenceClient

属性

名称 类型 描述
model_id Optional[str]

用于 LLM 的模型 ID,可在 Hugging Face Hub 中找到,这将用于解析无服务器 Inference Endpoints API 请求的基本 URL。默认为 None

endpoint_name 可选[RuntimeParameter[str]]

用于 LLM 的 Inference Endpoint 名称。默认为 None

endpoint_namespace 可选[RuntimeParameter[str]]

用于 LLM 的 Inference Endpoint 命名空间。默认为 None

base_url 可选[RuntimeParameter[str]]

用于 Inference Endpoints API 请求的基本 URL。

api_key 可选[RuntimeParameter[SecretStr]]

用于验证对 Inference Endpoints API 请求的 API 密钥。

tokenizer_id Optional[str]

用于 LLM 的 tokenizer ID,可在 Hugging Face Hub 中找到。默认为 None,但建议定义一个以正确格式化提示。

model_display_name Optional[str]

用于 LLM 的模型显示名称。默认为 None

use_magpie_template bool

用于启用/禁用应用 Magpie 预查询模板的标志。默认为 False

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

要应用于提示或发送到 LLM 以生成指令或后续用户消息的预查询模板。有效值为 "llama3"、"qwen2" 或提供的另一个预查询模板。默认为 None

structured_output Optional[RuntimeParameter[StructuredOutputType]]

一个字典,包含结构化输出配置;如果需要更细粒度的控制,则包含 OutlinesStructuredOutput 的实例。默认为 None。

图标

:hugging:

示例

免费的无服务器 Inference API:请为使用此 LLM 的 Task 设置 input_batch_size,以避免模型过载

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

专用推理端点

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    endpoint_name="<ENDPOINT_NAME>",
    api_key="<HF_API_KEY>",
    endpoint_namespace="<USER|ORG>",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

专用推理端点或 TGI

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    api_key="<HF_API_KEY>",
    base_url="<BASE_URL>",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import InferenceEndpointsLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    api_key="api.key",
    structured_output={"format": "json", "schema": User.model_json_schema()}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
源代码位于 src/distilabel/models/llms/huggingface/inference_endpoints.py
class InferenceEndpointsLLM(
    InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin
):
    """InferenceEndpoints LLM implementation running the async API client.

    This LLM will internally use `huggingface_hub.AsyncInferenceClient`.

    Attributes:
        model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.
        tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub.
            Defaults to `None`, but defining one is recommended to properly format the prompt.
        model_display_name: the model display name to use for the LLM. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.
        structured_output: a dictionary containing the structured output configuration or
            if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`.
            Defaults to None.

    Icon:
        `:hugging:`

    Examples:
        Free serverless Inference API, set the input_batch_size of the Task that uses this to avoid Model is overloaded:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Dedicated Inference Endpoints:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            endpoint_name="<ENDPOINT_NAME>",
            api_key="<HF_API_KEY>",
            endpoint_namespace="<USER|ORG>",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Dedicated Inference Endpoints or TGI:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            api_key="<HF_API_KEY>",
            base_url="<BASE_URL>",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import InferenceEndpointsLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            api_key="api.key",
            structured_output={"format": "json", "schema": User.model_json_schema()}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
        ```
    """

    def load(self) -> None:
        # Sets the logger and calls the load method of the BaseClient
        self._num_generations_param_supported = False
        AsyncLLM.load(self)
        InferenceEndpointsBaseClient.load(self)

    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
        provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
        favour of the dynamically calculated one."""

        if self.base_url and (self.model_id or self.endpoint_name):
            self._logger.warning(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        raise ValidationError(
            f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _get_structured_output(
        self, input: FormattedInput
    ) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
        """Gets the structured output (if any) for the given input.

        Args:
            input: a single input in chat format to generate responses for.

        Returns:
            The input and the structured output that will be passed as `grammar` to the
            inference endpoint or `None` if not required.
        """
        structured_output = None

        # Specific structured output per input
        if isinstance(input, tuple):
            input, structured_output = input
            structured_output = {
                "type": structured_output["format"],  # type: ignore
                "value": structured_output["schema"],  # type: ignore
            }

        # Same structured output for all the inputs
        if structured_output is None and self.structured_output is not None:
            try:
                structured_output = {
                    "type": self.structured_output["format"],  # type: ignore
                    "value": self.structured_output["schema"],  # type: ignore
                }
            except KeyError as e:
                raise ValueError(
                    "To use the structured output you have to inform the `format` and `schema` in "
                    "the `structured_output` attribute."
                ) from e

        if structured_output:
            if isinstance(structured_output["value"], ModelMetaclass):
                structured_output["value"] = structured_output[
                    "value"
                ].model_json_schema()

        return input, structured_output

    async def _generate_with_text_generation(
        self,
        input: str,
        max_new_tokens: int = 128,
        repetition_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        stop_sequences: Union[List[str], None] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        watermark: bool = False,
        structured_output: Union[Dict[str, Any], None] = None,
    ) -> GenerateOutput:
        generation: Union["TextGenerationOutput", None] = None
        try:
            generation = await self._aclient.text_generation(  # type: ignore
                prompt=input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_n_tokens=top_n_tokens,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, sys.maxsize),
                watermark=watermark,
                grammar=structured_output,  # type: ignore
                details=True,
            )
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
        return prepare_output(
            generations=[generation.generated_text] if generation else [None],
            input_tokens=[
                compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1
            ],
            output_tokens=[
                generation.details.generated_tokens
                if generation and generation.details
                else 0
            ],
            logprobs=self._get_logprobs_from_text_generation(generation)
            if generation
            else None,  # type: ignore
        )

    def _get_logprobs_from_text_generation(
        self, generation: "TextGenerationOutput"
    ) -> Union[List[List[List["Logprob"]]], None]:
        if generation.details is None or generation.details.top_tokens is None:
            return None

        return [
            [
                [
                    {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
                    for top_logprob in token_logprobs
                ]
                for token_logprobs in generation.details.top_tokens
            ]
        ]

    async def _generate_with_chat_completion(
        self,
        input: "StandardInput",
        max_new_tokens: int = 128,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: bool = False,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: float = 1.0,
        tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[PositiveInt] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        message = None
        completion: Union["ChatCompletionOutput", None] = None
        output_logprobs = None
        try:
            completion = await self._aclient.chat_completion(  # type: ignore
                messages=input,  # type: ignore
                max_tokens=max_new_tokens,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                logprobs=logprobs,
                presence_penalty=presence_penalty,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, sys.maxsize),
                stop=stop_sequences,
                temperature=temperature,
                tool_choice=tool_choice,  # type: ignore
                tool_prompt=tool_prompt,
                tools=tools,  # type: ignore
                top_logprobs=top_logprobs,
                top_p=top_p,
            )
            choice = completion.choices[0]  # type: ignore
            if (message := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            if choice_logprobs := self._get_logprobs_from_choice(choice):
                output_logprobs = [choice_logprobs]
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
        return prepare_output(
            generations=[message],
            input_tokens=[completion.usage.prompt_tokens] if completion else None,
            output_tokens=[completion.usage.completion_tokens] if completion else None,
            logprobs=output_logprobs,
        )

    def _get_logprobs_from_choice(
        self, choice: "ChatCompletionOutputComplete"
    ) -> Union[List[List["Logprob"]], None]:
        if choice.logprobs is None:
            return None

        return [
            [
                {"token": top_logprob.token, "logprob": top_logprob.logprob}
                for top_logprob in token_logprobs.top_logprobs
            ]
            for token_logprobs in choice.logprobs.content
        ]

    def _check_stop_sequences(
        self,
        stop_sequences: Optional[Union[str, List[str]]] = None,
    ) -> Union[List[str], None]:
        """Checks that no more than 4 stop sequences are provided.

        Args:
            stop_sequences: the stop sequences to be checked.

        Returns:
            The stop sequences.
        """
        if stop_sequences is not None:
            if isinstance(stop_sequences, str):
                stop_sequences = [stop_sequences]
            if len(stop_sequences) > 4:
                warnings.warn(
                    "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
                    UserWarning,
                    stacklevel=2,
                )
                stop_sequences = stop_sequences[:4]
        return stop_sequences

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: bool = False,
        presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: float = 1.0,
        tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[PositiveInt] = None,
        top_n_tokens: Optional[PositiveInt] = None,
        top_p: Optional[float] = None,
        do_sample: bool = False,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        num_generations: int = 1,
    ) -> GenerateOutput:
        """Generates completions for the given input using the async client. This method
        uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`.
        `chat_completion` method will be used only if no `tokenizer_id` has been specified.
        Some arguments of this function are specific to the `text_generation` method, while
        some others are specific to the `chat_completion` method.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
                new tokens based on their existing frequency in the text so far, decreasing
                model's likelihood to repeat the same line verbatim. Defaults to `None`.
            logit_bias: modify the likelihood of specified tokens appearing in the completion.
                This argument is exclusive to the `chat_completion` method and will be used
                only if `tokenizer_id` is `None`.
                Defaults to `None`.
            logprobs: whether to return the log probabilities or not. This argument is exclusive
                to the `chat_completion` method and will be used only if `tokenizer_id`
                is `None`. Defaults to `False`.
            presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
                new tokens based on whether they appear in the text so far, increasing the
                model likelihood to talk about new topics. This argument is exclusive to
                the `chat_completion` method and will be used only if `tokenizer_id` is
                `None`. Defaults to `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: either a single string or a list of strings containing the sequences
                to stop the generation at. Defaults to `None`, but will be set to the
                `tokenizer.eos_token` if available.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            tool_choice: the name of the tool the model should call. It can be a dictionary
                like `{"function_name": "my_tool"}` or "auto". If not provided, then the
                model won't use any tool. This argument is exclusive to the `chat_completion`
                method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
            tool_prompt: A prompt to be appended before the tools. This argument is exclusive
                to the `chat_completion` method and will be used only if `tokenizer_id`
                is `None`. Defaults to `None`.
            tools: a list of tools definitions that the LLM can use.
                This argument is exclusive to the `chat_completion` method and will be used
                only if `tokenizer_id` is `None`. Defaults to `None`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. This argument is exclusive to the `chat_completion` method and
                will be used only if `tokenizer_id` is `None`. Defaults to `None`.
            top_n_tokens: the number of top log probabilities to return per output token
                generated. This argument is exclusive of the `text_generation` method and
                will be only used if `tokenizer_id` is not `None`. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            do_sample: whether to use sampling for the generation. This argument is exclusive
                of the `text_generation` method and will be only used if `tokenizer_id` is not
                `None`. Defaults to `False`.
            repetition_penalty: the repetition penalty to use for the generation. This argument
                is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            return_full_text: whether to return the full text of the completion or just
                the generated text. Defaults to `False`, meaning that only the generated
                text will be returned. This argument is exclusive of the `text_generation`
                method and will be only used if `tokenizer_id` is not `None`.
            top_k: the top-k value to use for the generation. This argument is exclusive
                of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            typical_p: the typical-p value to use for the generation. This argument is exclusive
                of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            watermark: whether to add the watermark to the generated text. This argument
                is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `False`.
            num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure
                the validation succeeds.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        stop_sequences = self._check_stop_sequences(stop_sequences)

        if isinstance(input, str) or self.tokenizer_id is not None:
            structured_output = None
            if not isinstance(input, str):
                input, structured_output = self._get_structured_output(input)
                input = self.prepare_input(input)

            return await self._generate_with_text_generation(
                input=input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_n_tokens=top_n_tokens,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                seed=seed,
                watermark=watermark,
                structured_output=structured_output,
            )

        return await self._generate_with_chat_completion(
            input=input,  # type: ignore
            max_new_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            presence_penalty=presence_penalty,
            seed=seed,
            stop_sequences=stop_sequences,
            temperature=temperature,
            tool_choice=tool_choice,
            tool_prompt=tool_prompt,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
        )
only_one_of_model_id_endpoint_name_or_base_url_provided()

验证是否仅提供了 model_idendpoint_name 中的一个;并且如果也提供了 base_url,则会显示警告,告知用户提供的 base_url 将被忽略,而优先使用动态计算的 URL。

源代码位于 src/distilabel/models/llms/huggingface/inference_endpoints.py
@model_validator(mode="after")  # type: ignore
def only_one_of_model_id_endpoint_name_or_base_url_provided(
    self,
) -> "InferenceEndpointsLLM":
    """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
    provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
    favour of the dynamically calculated one."""

    if self.base_url and (self.model_id or self.endpoint_name):
        self._logger.warning(  # type: ignore
            f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
            " or `endpoint_name` is also provided, the `base_url` will either be ignored"
            " or overwritten with the one generated from either of those args, for serverless"
            " or dedicated inference endpoints, respectively."
        )

    if self.use_magpie_template and self.tokenizer_id is None:
        raise ValueError(
            "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
            " set a `tokenizer_id` and try again."
        )

    if (
        self.model_id
        and self.tokenizer_id is None
        and self.structured_output is not None
    ):
        self.tokenizer_id = self.model_id

    if self.base_url and not (self.model_id or self.endpoint_name):
        return self

    if self.model_id and not self.endpoint_name:
        return self

    if self.endpoint_name and not self.model_id:
        return self

    raise ValidationError(
        f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
        f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
        f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
    )
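
下面是一个简要草图(非官方示例),说明上述校验器接受或拒绝的几种参数组合;模型与端点名称仅作占位:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

# 可行:仅提供 model_id(无服务器 Inference API)
llm_serverless = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")

# 可行:仅提供 endpoint_name(专用 Inference Endpoint)
llm_dedicated = InferenceEndpointsLLM(
    endpoint_name="<ENDPOINT_NAME>",
    endpoint_namespace="<USER|ORG>",
)

# 不可行:同时提供 model_id 和 endpoint_name 会触发 ValidationError
# InferenceEndpointsLLM(model_id="...", endpoint_name="...")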
prepare_input(input)

为提供的输入准备输入(应用聊天模板和分词)。

参数

名称 类型 描述 默认值
input StandardInput

包含聊天项的输入列表。

必需

返回

类型 描述
str

要发送给 LLM 的提示。

源代码位于 src/distilabel/models/llms/huggingface/inference_endpoints.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            conversation=input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
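
为说明聊天模板的应用效果,下面用 transformers 的 tokenizer 单独演示(示意性代码,tokenizer 名称仅作占位;在实际实现中,tokenizer 会根据 tokenizer_id 在 load() 时加载):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")

prompt = tokenizer.apply_chat_template(
    conversation=[{"role": "user", "content": "Hello world!"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # 已套用聊天模板、以生成提示结尾的字符串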
_get_structured_output(input)

获取给定输入的结构化输出(如果有)。

参数

名称 类型 描述 默认值
input FormattedInput

以聊天格式的单个输入,用于生成响应。

必需

返回

类型 描述
Tuple[StandardInput, Union[Dict[str, Any], None]]

输入,以及将作为 grammar 传递给推理端点的结构化输出;如果不需要结构化输出,则为 None。

源代码位于 src/distilabel/models/llms/huggingface/inference_endpoints.py
def _get_structured_output(
    self, input: FormattedInput
) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
    """Gets the structured output (if any) for the given input.

    Args:
        input: a single input in chat format to generate responses for.

    Returns:
        The input and the structured output that will be passed as `grammar` to the
        inference endpoint or `None` if not required.
    """
    structured_output = None

    # Specific structured output per input
    if isinstance(input, tuple):
        input, structured_output = input
        structured_output = {
            "type": structured_output["format"],  # type: ignore
            "value": structured_output["schema"],  # type: ignore
        }

    # Same structured output for all the inputs
    if structured_output is None and self.structured_output is not None:
        try:
            structured_output = {
                "type": self.structured_output["format"],  # type: ignore
                "value": self.structured_output["schema"],  # type: ignore
            }
        except KeyError as e:
            raise ValueError(
                "To use the structured output you have to inform the `format` and `schema` in "
                "the `structured_output` attribute."
            ) from e

    if structured_output:
        if isinstance(structured_output["value"], ModelMetaclass):
            structured_output["value"] = structured_output[
                "value"
            ].model_json_schema()

    return input, structured_output
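
下面的草图(非官方示例)展示按输入传递结构化输出时的元组格式,以及它如何被转换为传给推理端点的 grammar 字典:

from pydantic import BaseModel

class User(BaseModel):
    name: str
    id: int

# 按输入传递:(聊天消息, structured_output 配置) 的元组
formatted_input = (
    [{"role": "user", "content": "Create a user"}],
    {"format": "json", "schema": User.model_json_schema()},
)

# _get_structured_output 会将其转换为:
# {"type": "json", "value": <User 的 JSON schema>}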
_check_stop_sequences(stop_sequences=None)

检查提供的停止序列是否不超过 4 个。

参数

名称 类型 描述 默认值
stop_sequences Optional[Union[str, List[str]]]

要检查的停止序列。

None

返回

类型 描述
Union[List[str], None]

停止序列。

源代码位于 src/distilabel/models/llms/huggingface/inference_endpoints.py
def _check_stop_sequences(
    self,
    stop_sequences: Optional[Union[str, List[str]]] = None,
) -> Union[List[str], None]:
    """Checks that no more than 4 stop sequences are provided.

    Args:
        stop_sequences: the stop sequences to be checked.

    Returns:
        The stop sequences.
    """
    if stop_sequences is not None:
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        if len(stop_sequences) > 4:
            warnings.warn(
                "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
                UserWarning,
                stacklevel=2,
            )
            stop_sequences = stop_sequences[:4]
    return stop_sequences
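
示意(非官方示例):传入超过 4 个停止序列时会发出 UserWarning,并仅保留前 4 个:

stop = ["\n", "User:", "Assistant:", "###", "<|end|>"]
# _check_stop_sequences(stop) 返回 ["\n", "User:", "Assistant:", "###"]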
agenerate(input, max_new_tokens=128, frequency_penalty=None, logit_bias=None, logprobs=False, presence_penalty=None, seed=None, stop_sequences=None, temperature=1.0, tool_choice=None, tool_prompt=None, tools=None, top_logprobs=None, top_n_tokens=None, top_p=None, do_sample=False, repetition_penalty=None, return_full_text=False, top_k=None, typical_p=None, watermark=False, num_generations=1) async

使用异步客户端为给定输入生成补全。此方法使用 huggingface_hub.AsyncClient 的两个方法:chat_completiontext_generation。 仅当未指定 tokenizer_id 时,才会使用 chat_completion 方法。此函数的某些参数特定于 text_generation 方法,而另一些参数特定于 chat_completion 方法。

参数

名称 类型 描述 默认值
input FormattedInput

以聊天格式的单个输入,用于生成响应。

必需
max_new_tokens int

模型将生成的最大新 token 数。默认为 128

128
frequency_penalty Optional[Annotated[float, Field(ge=-2.0, le=2.0)]]

介于 -2.02.0 之间的值。正值会根据新 token 在目前文本中已有的频率来惩罚新 token,从而降低模型逐字重复同一行的可能性。默认为 None

None
logit_bias Optional[List[float]]

修改补全中指定 token 出现的可能性。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 None

None
logprobs bool

是否返回对数概率。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 False

False
presence_penalty Optional[Annotated[float, Field(ge=-2.0, le=2.0)]]

介于 -2.02.0 之间的值。正值会根据新 token 是否在目前文本中出现来惩罚新 token,从而提高模型谈论新话题的可能性。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 None

None
seed Optional[int]

用于生成的种子。默认为 None

None
stop_sequences Optional[List[str]]

单个字符串或字符串列表,包含在其中停止生成的序列。默认为 None,但如果可用,将设置为 tokenizer.eos_token

None
temperature float

用于生成的温度。默认为 1.0

1.0
tool_choice Optional[Union[Dict[str, str], Literal['auto']]]

模型应调用的工具的名称。它可以是像 {"function_name": "my_tool"} 这样的字典或 "auto"。如果未提供,则模型不会使用任何工具。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 None

None
tool_prompt Optional[str]

在工具之前附加的提示。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 None

None
tools Optional[List[Dict[str, Any]]]

LLM 可以使用的工具定义列表。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 None

None
top_logprobs Optional[PositiveInt]

每个生成的输出 token 要返回的顶部对数概率的数量。此参数是 chat_completion 方法独有的,仅当 tokenizer_idNone 时才会使用。默认为 None

None
top_n_tokens Optional[PositiveInt]

每个生成的输出 token 要返回的顶部对数概率的数量。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。默认为 None

None
top_p Optional[float]

用于生成的 top-p 值。默认为 1.0

None
do_sample bool

是否使用采样进行生成。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。默认为 False

False
repetition_penalty Optional[float]

用于生成的重复惩罚。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。默认为 None

None
return_full_text bool

是否返回补全的完整文本,还是仅返回生成的文本。默认为 False,表示仅返回生成的文本。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。

False
top_k Optional[int]

用于生成的 top-k 值。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。默认为 None。

None
typical_p Optional[float]

用于生成的 typical-p 值。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。默认为 None

None
watermark bool

是否将水印添加到生成的文本。此参数是 text_generation 方法独有的,仅当 tokenizer_id 不为 None 时才会使用。默认为 False。

False
num_generations int

要生成的回复数量。默认为 1。此参数仅用于确保参数校验通过。

1

返回

类型 描述
GenerateOutput

包含每个输入的生成响应的字符串列表的列表。

源代码位于 src/distilabel/models/llms/huggingface/inference_endpoints.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_new_tokens: int = 128,
    frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
    logit_bias: Optional[List[float]] = None,
    logprobs: bool = False,
    presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
    seed: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    temperature: float = 1.0,
    tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
    tool_prompt: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    top_logprobs: Optional[PositiveInt] = None,
    top_n_tokens: Optional[PositiveInt] = None,
    top_p: Optional[float] = None,
    do_sample: bool = False,
    repetition_penalty: Optional[float] = None,
    return_full_text: bool = False,
    top_k: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: bool = False,
    num_generations: int = 1,
) -> GenerateOutput:
    """Generates completions for the given input using the async client. This method
    uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`.
    `chat_completion` method will be used only if no `tokenizer_id` has been specified.
    Some arguments of this function are specific to the `text_generation` method, while
    some others are specific to the `chat_completion` method.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
            new tokens based on their existing frequency in the text so far, decreasing
            model's likelihood to repeat the same line verbatim. Defaults to `None`.
        logit_bias: modify the likelihood of specified tokens appearing in the completion.
            This argument is exclusive to the `chat_completion` method and will be used
            only if `tokenizer_id` is `None`.
            Defaults to `None`.
        logprobs: whether to return the log probabilities or not. This argument is exclusive
            to the `chat_completion` method and will be used only if `tokenizer_id`
            is `None`. Defaults to `False`.
        presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
            new tokens based on whether they appear in the text so far, increasing the
            model likelihood to talk about new topics. This argument is exclusive to
            the `chat_completion` method and will be used only if `tokenizer_id` is
            `None`. Defaults to `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: either a single string or a list of strings containing the sequences
            to stop the generation at. Defaults to `None`, but will be set to the
            `tokenizer.eos_token` if available.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        tool_choice: the name of the tool the model should call. It can be a dictionary
            like `{"function_name": "my_tool"}` or "auto". If not provided, then the
            model won't use any tool. This argument is exclusive to the `chat_completion`
            method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
        tool_prompt: A prompt to be appended before the tools. This argument is exclusive
            to the `chat_completion` method and will be used only if `tokenizer_id`
            is `None`. Defaults to `None`.
        tools: a list of tools definitions that the LLM can use.
            This argument is exclusive to the `chat_completion` method and will be used
            only if `tokenizer_id` is `None`. Defaults to `None`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. This argument is exclusive to the `chat_completion` method and
            will be used only if `tokenizer_id` is `None`. Defaults to `None`.
        top_n_tokens: the number of top log probabilities to return per output token
            generated. This argument is exclusive of the `text_generation` method and
            will be only used if `tokenizer_id` is not `None`. Defaults to `None`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        do_sample: whether to use sampling for the generation. This argument is exclusive
            of the `text_generation` method and will be only used if `tokenizer_id` is not
            `None`. Defaults to `False`.
        repetition_penalty: the repetition penalty to use for the generation. This argument
            is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        return_full_text: whether to return the full text of the completion or just
            the generated text. Defaults to `False`, meaning that only the generated
            text will be returned. This argument is exclusive of the `text_generation`
            method and will be only used if `tokenizer_id` is not `None`.
        top_k: the top-k value to use for the generation. This argument is exclusive
            of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        typical_p: the typical-p value to use for the generation. This argument is exclusive
            of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        watermark: whether to add the watermark to the generated text. This argument
            is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `False`.
        num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure
            the validation succeeds.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    stop_sequences = self._check_stop_sequences(stop_sequences)

    if isinstance(input, str) or self.tokenizer_id is not None:
        structured_output = None
        if not isinstance(input, str):
            input, structured_output = self._get_structured_output(input)
            input = self.prepare_input(input)

        return await self._generate_with_text_generation(
            input=input,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            typical_p=typical_p,
            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            temperature=temperature,
            top_n_tokens=top_n_tokens,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=stop_sequences,
            return_full_text=return_full_text,
            seed=seed,
            watermark=watermark,
            structured_output=structured_output,
        )

    return await self._generate_with_chat_completion(
        input=input,  # type: ignore
        max_new_tokens=max_new_tokens,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        logprobs=logprobs,
        presence_penalty=presence_penalty,
        seed=seed,
        stop_sequences=stop_sequences,
        temperature=temperature,
        tool_choice=tool_choice,
        tool_prompt=tool_prompt,
        tools=tools,
        top_logprobs=top_logprobs,
        top_p=top_p,
    )
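
下面的草图(非官方示例)展示 tokenizer_id 如何决定走 chat_completion 还是 text_generation 路径;模型名称与参数取值仅作占位:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

# 未设置 tokenizer_id:走 chat_completion,可使用 logprobs、tools 等专属参数
llm_chat = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm_chat.load()
out_chat = llm_chat.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    logprobs=True,
    top_logprobs=2,
)

# 设置了 tokenizer_id:先在本地套用聊天模板,再走 text_generation,可使用 do_sample、top_k 等专属参数
llm_text = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
)
llm_text.load()
out_text = llm_text.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    do_sample=True,
    top_k=50,
)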

TransformersLLM

基类: LLMMagpieChatTemplateMixinCudaDevicePlacementMixin

使用文本生成 pipeline 的 Hugging Face transformers 库 LLM 实现。

属性

名称 类型 描述
model str

模型 Hugging Face Hub repo id 或包含模型权重和配置文件的目录路径。

revision str

如果 model 指的是 Hugging Face Hub 仓库,则要使用的修订版本(例如分支名称或提交 id)。默认为 "main"

torch_dtype str

模型要使用的 torch dtype,例如 "float16"、"float32" 等。默认为 "auto"

trust_remote_code bool

是否允许获取和执行从 Hub 仓库获取的远程代码。默认为 False

model_kwargs Optional[Dict[str, Any]]

将传递给模型的 from_pretrained 方法的其他关键字参数字典。

tokenizer Optional[str]

tokenizer Hugging Face Hub repo id 或包含 tokenizer 配置文件目录的路径。如果未提供,将使用与 model 关联的 tokenizer。默认为 None

use_fast bool

是否使用快速 tokenizer。默认为 True

chat_template Optional[str]

将用于构建提示的聊天模板,然后再将其发送到模型。 如果未提供,将使用 tokenizer 配置中定义的聊天模板。 如果未提供且 tokenizer 没有聊天模板,则将使用 ChatML 模板。默认为 None

device Optional[Union[str, int]]

模型将加载到的设备的名称或索引。默认为 None

device_map Optional[Union[str, Dict[str, Any]]]

将模型的每一层映射到设备的字典,或像 "sequential""auto" 这样的模式。默认为 None

token Optional[SecretStr]

将用于向 Hugging Face Hub 验证身份的 Hugging Face Hub token。 如果未提供,将使用 HF_TOKEN 环境变量或 huggingface_hub 包本地配置。 默认为 None

structured_output Optional[RuntimeParameter[OutlinesStructuredOutputType]]

一个字典,包含结构化输出配置;如果需要更细粒度的控制,则包含 OutlinesStructuredOutput 的实例。默认为 None。

use_magpie_template bool

用于启用/禁用应用 Magpie 预查询模板的标志。默认为 False

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

要应用于提示或发送到 LLM 以生成指令或后续用户消息的预查询模板。有效值为 "llama3"、"qwen2" 或提供的另一个预查询模板。默认为 None

图标

:hugging:

示例

生成文本

from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
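
作为补充的简要草图(非官方示例),展示上文属性中提到的 torch_dtype、device 与生成参数的用法;参数取值仅作占位:

from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(
    model="microsoft/Phi-3-mini-4k-instruct",
    torch_dtype="float16",
    device="cuda",  # 也可以传设备索引,例如 0
)

llm.load()

output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    temperature=0.7,
    max_new_tokens=256,
)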
源代码位于 src/distilabel/models/llms/huggingface/transformers.py
class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
    """Hugging Face `transformers` library LLM implementation using the text generation
    pipeline.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        revision: if `model` refers to a Hugging Face Hub repository, then the revision
            (e.g. a branch name or a commit id) to use. Defaults to `"main"`.
        torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc.
            Defaults to `"auto"`.
        trust_remote_code: whether to allow fetching and executing remote code fetched
            from the repository in the Hub. Defaults to `False`.
        model_kwargs: additional dictionary of keyword arguments that will be passed to
            the `from_pretrained` method of the model.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_fast: whether to use a fast tokenizer or not. Defaults to `True`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        device: the name or index of the device where the model will be loaded. Defaults
            to `None`.
        device_map: a dictionary mapping each layer of the model to a device, or a mode
            like `"sequential"` or `"auto"`. Defaults to `None`.
        token: the Hugging Face Hub token that will be used to authenticate to the Hugging
            Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package
            local configuration will be used. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.

    Icon:
        `:hugging:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import TransformersLLM

        llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str
    revision: str = "main"
    torch_dtype: str = "auto"
    trust_remote_code: bool = False
    model_kwargs: Optional[Dict[str, Any]] = None
    tokenizer: Optional[str] = None
    use_fast: bool = True
    chat_template: Optional[str] = None
    device: Optional[Union[str, int]] = None
    device_map: Optional[Union[str, Dict[str, Any]]] = None
    token: Optional[SecretStr] = Field(
        default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR)
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _pipeline: Optional["Pipeline"] = PrivateAttr(...)
    _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
    _logits_processor: Union[Callable, None] = PrivateAttr(default=None)

    def load(self) -> None:
        """Loads the model and tokenizer and creates the text generation pipeline. In addition,
        it will configure the tokenizer chat template."""
        if self.device == "cuda":
            CudaDevicePlacementMixin.load(self)

        try:
            from transformers import pipeline
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie

        token = self.token.get_secret_value() if self.token is not None else self.token

        self._pipeline = pipeline(
            "text-generation",
            model=self.model,
            revision=self.revision,
            torch_dtype=self.torch_dtype,
            trust_remote_code=self.trust_remote_code,
            model_kwargs=self.model_kwargs or {},
            tokenizer=self.tokenizer or self.model,
            use_fast=self.use_fast,
            device=self.device,
            device_map=self.device_map,
            token=token,
            return_full_text=False,
        )

        if self.chat_template is not None:
            self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore

        if self._pipeline.tokenizer.pad_token is None:  # type: ignore
            self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

        if self.structured_output:
            processor = self._prepare_structured_output(self.structured_output)
            if _is_outlines_version_below_0_1_0():
                self._prefix_allowed_tokens_fn = processor
            else:
                self._logits_processor = [processor]

        super().load()

    def unload(self) -> None:
        """Unloads the `vLLM` model."""
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if self._pipeline.tokenizer.chat_template is None:  # type: ignore
            return input[0]["content"]

        prompt: str = (
            self._pipeline.tokenizer.apply_chat_template(  # type: ignore
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[StandardInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        temperature: float = 0.1,
        repetition_penalty: float = 1.1,
        top_p: float = 1.0,
        top_k: int = 0,
        do_sample: bool = True,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults
                to `1.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `0`.
            do_sample: whether to use sampling or not. Defaults to `True`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        prepared_inputs = [self.prepare_input(input=input) for input in inputs]

        outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
            prepared_inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            num_return_sequences=num_generations,
            prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
            pad_token_id=self._pipeline.tokenizer.eos_token_id,
            logits_processor=self._logits_processor,
        )
        llm_output = [
            [generation["generated_text"] for generation in output]
            for output in outputs
        ]

        result = []
        for input, output in zip(inputs, llm_output):
            result.append(
                prepare_output(
                    output,
                    input_tokens=[
                        compute_tokens(input, self._pipeline.tokenizer.encode)
                    ],
                    output_tokens=[
                        compute_tokens(row, self._pipeline.tokenizer.encode)
                        for row in output
                    ],
                )
            )

        return result

    def get_last_hidden_states(
        self, inputs: List["StandardInput"]
    ) -> List["HiddenState"]:
        """Gets the last `hidden_states` of the model for the given inputs. It doesn't
        execute the task head.

        Args:
            inputs: a list of inputs in chat format to generate the embeddings for.

        Returns:
            A list containing the last hidden state for each sequence using a NumPy array
            with shape [num_tokens, hidden_size].
        """
        model: "PreTrainedModel" = (
            self._pipeline.model.model  # type: ignore
            if hasattr(self._pipeline.model, "model")  # type: ignore
            else next(self._pipeline.model.children())  # type: ignore
        )
        tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
        input_ids = tokenizer(
            [self.prepare_input(input) for input in inputs],  # type: ignore
            return_tensors="pt",
            padding=True,
        ).to(model.device)
        last_hidden_states = model(**input_ids)["last_hidden_state"]

        return [
            seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
            for seq_last_hidden_state, attention_mask in zip(
                last_hidden_states,
                input_ids["attention_mask"],  # type: ignore
            )
        ]

    def _prepare_structured_output(
        self, structured_output: Optional[OutlinesStructuredOutputType] = None
    ) -> Union[Callable, List[Callable]]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        result = prepare_guided_output(
            structured_output, "transformers", self._pipeline
        )
        if schema := result.get("schema"):
            self.structured_output["schema"] = schema
        return result["processor"]
model_name property

Returns the model name used for the LLM.

load()

Loads the model and tokenizer and creates the text generation pipeline. In addition, it configures the tokenizer chat template.

Source code in src/distilabel/models/llms/huggingface/transformers.py
def load(self) -> None:
    """Loads the model and tokenizer and creates the text generation pipeline. In addition,
    it will configure the tokenizer chat template."""
    if self.device == "cuda":
        CudaDevicePlacementMixin.load(self)

    try:
        from transformers import pipeline
    except ImportError as ie:
        raise ImportError(
            "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
        ) from ie

    token = self.token.get_secret_value() if self.token is not None else self.token

    self._pipeline = pipeline(
        "text-generation",
        model=self.model,
        revision=self.revision,
        torch_dtype=self.torch_dtype,
        trust_remote_code=self.trust_remote_code,
        model_kwargs=self.model_kwargs or {},
        tokenizer=self.tokenizer or self.model,
        use_fast=self.use_fast,
        device=self.device,
        device_map=self.device_map,
        token=token,
        return_full_text=False,
    )

    if self.chat_template is not None:
        self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore

    if self._pipeline.tokenizer.pad_token is None:  # type: ignore
        self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

    if self.structured_output:
        processor = self._prepare_structured_output(self.structured_output)
        if _is_outlines_version_below_0_1_0():
            self._prefix_allowed_tokens_fn = processor
        else:
            self._logits_processor = [processor]

    super().load()
unload()

Unloads the model and releases the CUDA device placement.

Source code in src/distilabel/models/llms/huggingface/transformers.py
def unload(self) -> None:
    """Unloads the `vLLM` model."""
    CudaDevicePlacementMixin.unload(self)
    super().unload()
prepare_input(input)

Prepares the input (applying the chat template and tokenization) for the provided input.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | StandardInput | The input list containing chat items. | required |

Returns

| Type | Description |
| --- | --- |
| str | The prompt to send to the LLM. |

Source code in src/distilabel/models/llms/huggingface/transformers.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if self._pipeline.tokenizer.chat_template is None:  # type: ignore
        return input[0]["content"]

    prompt: str = (
        self._pipeline.tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
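As a quick illustration (a hedged sketch, not part of the official docs), the snippet below shows what `prepare_input` returns for a chat-formatted input once the model is loaded; the exact prompt string depends on the model's chat template:

```python
from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")
llm.load()

# Applies the tokenizer's chat template and returns the flat prompt string
# that will be fed to the text generation pipeline.
prompt = llm.prepare_input([{"role": "user", "content": "Hello world!"}])
print(prompt)  # the rendered prompt, template-specific
```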
generate(inputs, num_generations=1, max_new_tokens=128, temperature=0.1, repetition_penalty=1.1, top_p=1.0, top_k=0, do_sample=True)

Generates num_generations responses for each input using the text generation pipeline.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | List[StandardInput] | A list of inputs in chat format to generate responses for. | required |
| num_generations | int | The number of generations to create per input. Defaults to 1. | 1 |
| max_new_tokens | int | The maximum number of new tokens that the model will generate. Defaults to 128. | 128 |
| temperature | float | The temperature to use for the generation. Defaults to 0.1. | 0.1 |
| repetition_penalty | float | The repetition penalty to use for the generation. Defaults to 1.1. | 1.1 |
| top_p | float | The top-p value to use for the generation. Defaults to 1.0. | 1.0 |
| top_k | int | The top-k value to use for the generation. Defaults to 0. | 0 |
| do_sample | bool | Whether to use sampling or not. Defaults to True. | True |

Returns

| Type | Description |
| --- | --- |
| List[GenerateOutput] | A list of lists of strings containing the generated responses for each input. |

Source code in src/distilabel/models/llms/huggingface/transformers.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[StandardInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    temperature: float = 0.1,
    repetition_penalty: float = 1.1,
    top_p: float = 1.0,
    top_k: int = 0,
    do_sample: bool = True,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults
            to `1.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `0`.
        do_sample: whether to use sampling or not. Defaults to `True`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    prepared_inputs = [self.prepare_input(input=input) for input in inputs]

    outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
        prepared_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        num_return_sequences=num_generations,
        prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
        pad_token_id=self._pipeline.tokenizer.eos_token_id,
        logits_processor=self._logits_processor,
    )
    llm_output = [
        [generation["generated_text"] for generation in output]
        for output in outputs
    ]

    result = []
    for input, output in zip(inputs, llm_output):
        result.append(
            prepare_output(
                output,
                input_tokens=[
                    compute_tokens(input, self._pipeline.tokenizer.encode)
                ],
                output_tokens=[
                    compute_tokens(row, self._pipeline.tokenizer.encode)
                    for row in output
                ],
            )
        )

    return result
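A minimal usage sketch for this method, assuming the model used in the class example above; the `"generations"` key of each result is an assumption based on the `prepare_output` helper used in the source, alongside token statistics:

```python
from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")
llm.load()

# Two candidate completions per input with custom sampling settings.
results = llm.generate(
    inputs=[[{"role": "user", "content": "Write a haiku about GPUs."}]],
    num_generations=2,
    max_new_tokens=64,
    temperature=0.7,
    top_p=0.9,
)

# One result per input; "generations" is assumed to hold the generated strings.
for result in results:
    for generation in result["generations"]:
        print(generation)
```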
get_last_hidden_states(inputs)

Gets the last hidden_states of the model for the given inputs. It doesn't execute the task head.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | List[StandardInput] | A list of inputs in chat format to generate the embeddings for. | required |

Returns

| Type | Description |
| --- | --- |
| List[HiddenState] | A list containing the last hidden state for each sequence, as a NumPy array with shape [num_tokens, hidden_size]. |

Source code in src/distilabel/models/llms/huggingface/transformers.py
def get_last_hidden_states(
    self, inputs: List["StandardInput"]
) -> List["HiddenState"]:
    """Gets the last `hidden_states` of the model for the given inputs. It doesn't
    execute the task head.

    Args:
        inputs: a list of inputs in chat format to generate the embeddings for.

    Returns:
        A list containing the last hidden state for each sequence using a NumPy array
        with shape [num_tokens, hidden_size].
    """
    model: "PreTrainedModel" = (
        self._pipeline.model.model  # type: ignore
        if hasattr(self._pipeline.model, "model")  # type: ignore
        else next(self._pipeline.model.children())  # type: ignore
    )
    tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
    input_ids = tokenizer(
        [self.prepare_input(input) for input in inputs],  # type: ignore
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    last_hidden_states = model(**input_ids)["last_hidden_state"]

    return [
        seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
        for seq_last_hidden_state, attention_mask in zip(
            last_hidden_states,
            input_ids["attention_mask"],  # type: ignore
        )
    ]
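A hedged sketch of turning these per-token hidden states into fixed-size sequence embeddings via mean pooling (the pooling strategy is an illustrative choice, not prescribed by distilabel):

```python
from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")
llm.load()

hidden_states = llm.get_last_hidden_states(
    inputs=[[{"role": "user", "content": "Hello world!"}]]
)

# Each element is a NumPy array of shape [num_tokens, hidden_size]; mean-pooling
# over the token axis yields a simple per-sequence embedding.
embeddings = [state.mean(axis=0) for state in hidden_states]
print(embeddings[0].shape)  # (hidden_size,)
```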
_prepare_structured_output(structured_output=None)

Creates the appropriate function to filter tokens in order to generate structured outputs.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| structured_output | Optional[OutlinesStructuredOutputType] | The configuration dict to prepare the structured output. | None |

Returns

| Type | Description |
| --- | --- |
| Union[Callable, List[Callable]] | The callable that will be used to guide the generation of the model. |

Source code in src/distilabel/models/llms/huggingface/transformers.py
def _prepare_structured_output(
    self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, List[Callable]]:
    """Creates the appropriate function to filter tokens to generate structured outputs.

    Args:
        structured_output: the configuration dict to prepare the structured output.

    Returns:
        The callable that will be used to guide the generation of the model.
    """
    from distilabel.steps.tasks.structured_outputs.outlines import (
        prepare_guided_output,
    )

    result = prepare_guided_output(
        structured_output, "transformers", self._pipeline
    )
    if schema := result.get("schema"):
        self.structured_output["schema"] = schema
    return result["processor"]
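The class example above only covers plain text generation. As a hedged sketch, assuming `TransformersLLM` accepts the same `{"format": ..., "schema": ...}` dictionary shown for the other outlines-based LLMs on this page, structured generation would look like:

```python
from pydantic import BaseModel

from distilabel.models.llms import TransformersLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = TransformersLLM(
    model="microsoft/Phi-3-mini-4k-instruct",
    # Assumed to follow the OutlinesStructuredOutputType structure used by LlamaCppLLM.
    structured_output={"format": "json", "schema": User},
)

llm.load()

output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]
)
```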

LiteLLM

Bases: AsyncLLM

LiteLLM implementation running the async API client.

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | The model name to use for the LLM, e.g. "gpt-3.5-turbo" or "mistral/mistral-large", etc. |
| verbose | RuntimeParameter[bool] | Whether to log the LiteLLM client's logs. Defaults to False. |
| structured_output | Optional[RuntimeParameter[InstructorStructuredOutputType]] | A dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor. |

Runtime parameters
  • verbose: whether to log the LiteLLM client's logs. Defaults to False.

Examples

Generate text:

from distilabel.models.llms import LiteLLM

llm = LiteLLM(model="gpt-3.5-turbo")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import LiteLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = LiteLLM(
    model="gpt-3.5-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])

Source code in src/distilabel/models/llms/litellm.py
class LiteLLM(AsyncLLM):
    """LiteLLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large",
            etc.
        verbose: whether to log the LiteLLM client's logs. Defaults to `False`.
        structured_output: a dictionary containing the structured output configuration configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.

    Runtime parameters:
        - `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import LiteLLM

        llm = LiteLLM(model="gpt-3.5-turbo")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import LiteLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = LiteLLM(
            model="gpt-3.5-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    verbose: RuntimeParameter[bool] = Field(
        default=False, description="Whether to log the LiteLLM client's logs."
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _aclient: Optional[Callable] = PrivateAttr(...)

    def load(self) -> None:
        """
        Loads the `acompletion` LiteLLM client to benefit from async requests.
        """
        super().load()

        try:
            import litellm

            litellm.telemetry = False
        except ImportError as e:
            raise ImportError(
                "LiteLLM Python client is not installed. Please install it using"
                " `pip install 'distilabel[litellm]'`."
            ) from e
        self._aclient = litellm.acompletion

        if not self.verbose:
            litellm.suppress_debug_info = True
            for key in logging.Logger.manager.loggerDict.keys():
                if "litellm" not in key.lower():
                    continue
                logging.getLogger(key).setLevel(logging.CRITICAL)

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="litellm",
            )
            self._aclient = result.get("client").messages.create
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore # noqa: C901
        self,
        input: FormattedInput,
        num_generations: int = 1,
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        stop: Optional[Union[str, list]] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        metadata: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,
        mock_response: Optional[str] = None,
        force_timeout: Optional[int] = 600,
        custom_llm_provider: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            functions: a list of functions to apply to the conversation messages. Defaults to
                `None`.
            function_call: the name of the function to call within the conversation. Defaults
                to `None`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: Up to 4 sequences where the LLM API will stop generating further tokens.
                Defaults to `None`.
            max_tokens: The maximum number of tokens in the generated completion. Defaults to
                `None`.
            presence_penalty: It is used to penalize new tokens based on their existence in the
                text so far. Defaults to `None`.
            frequency_penalty: It is used to penalize new tokens based on their frequency in the
                text so far. Defaults to `None`.
            logit_bias: Used to modify the probability of specific tokens appearing in the
                completion. Defaults to `None`.
            user: A unique identifier representing your end-user. This can help the LLM provider
                to monitor and detect abuse. Defaults to `None`.
            metadata: Pass in additional metadata to tag your completion calls - eg. prompt
                version, details, etc. Defaults to `None`.
            api_base: Base URL for the API. Defaults to `None`.
            api_version: API version. Defaults to `None`.
            api_key: API key. Defaults to `None`.
            model_list: List of api base, version, keys. Defaults to `None`.
            mock_response: If provided, return a mock completion response for testing or debugging
                purposes. Defaults to `None`.
            force_timeout: The maximum execution time in seconds for the completion request.
                Defaults to `600`.
            custom_llm_provider: Used for Non-OpenAI LLMs, Example usage for bedrock, set(iterable)
                model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to
                `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        import litellm
        from litellm import token_counter

        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="litellm",
            )
            self._aclient = result.get("client").messages.create

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "model": self.model,
            "messages": input,
            "n": num_generations,
            "functions": functions,
            "function_call": function_call,
            "temperature": temperature,
            "top_p": top_p,
            "stream": False,
            "stop": stop,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "user": user,
            "metadata": metadata,
            "api_base": api_base,
            "api_version": api_version,
            "api_key": api_key,
            "model_list": model_list,
            "mock_response": mock_response,
            "force_timeout": force_timeout,
            "custom_llm_provider": custom_llm_provider,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        async def _call_aclient_until_n_choices() -> List["Choices"]:
            choices = []
            while len(choices) < num_generations:
                completion: Union["ModelResponse", "BaseModel"] = await self._aclient(
                    **kwargs
                )  # type: ignore
                if self.structured_output:
                    # Prevent pydantic model from being cast to list during list extension
                    completion = [completion]
                else:
                    completion = completion.choices
                choices.extend(completion)
            return choices

        # litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used
        try:
            litellm.drop_params = False
            choices = await _call_aclient_until_n_choices()
        except litellm.exceptions.APIError as e:
            if "does not support parameters" in str(e):
                litellm.drop_params = True
                choices = await _call_aclient_until_n_choices()
            else:
                raise e

        generations = []
        input_tokens = [
            token_counter(model=self.model, messages=input)
        ] * num_generations
        output_tokens = []

        if self.structured_output:
            for choice in choices:
                generations.append(choice.model_dump_json())
                output_tokens.append(
                    token_counter(
                        model=self.model,
                        text=orjson.dumps(choice.model_dump_json()).decode("utf-8"),
                    )
                )
            return prepare_output(
                generations,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )

        for choice in choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using LiteLLM client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
            output_tokens.append(token_counter(model=self.model, text=content))

        return prepare_output(
            generations, input_tokens=input_tokens, output_tokens=output_tokens
        )
model_name property

Returns the model name used for the LLM.

load()

Loads the acompletion LiteLLM client to benefit from async requests.

Source code in src/distilabel/models/llms/litellm.py
def load(self) -> None:
    """
    Loads the `acompletion` LiteLLM client to benefit from async requests.
    """
    super().load()

    try:
        import litellm

        litellm.telemetry = False
    except ImportError as e:
        raise ImportError(
            "LiteLLM Python client is not installed. Please install it using"
            " `pip install 'distilabel[litellm]'`."
        ) from e
    self._aclient = litellm.acompletion

    if not self.verbose:
        litellm.suppress_debug_info = True
        for key in logging.Logger.manager.loggerDict.keys():
            if "litellm" not in key.lower():
                continue
            logging.getLogger(key).setLevel(logging.CRITICAL)

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="litellm",
        )
        self._aclient = result.get("client").messages.create
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
agenerate(input, num_generations=1, functions=None, function_call=None, temperature=1.0, top_p=1.0, stop=None, max_tokens=None, presence_penalty=None, frequency_penalty=None, logit_bias=None, user=None, metadata=None, api_base=None, api_version=None, api_key=None, model_list=None, mock_response=None, force_timeout=600, custom_llm_provider=None) async

Generates num_generations responses for the given input using the LiteLLM async client.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | FormattedInput | A single input in chat format to generate responses for. | required |
| num_generations | int | The number of generations to create per input. Defaults to 1. | 1 |
| functions | Optional[List] | A list of functions to apply to the conversation messages. Defaults to None. | None |
| function_call | Optional[str] | The name of the function to call within the conversation. Defaults to None. | None |
| temperature | Optional[float] | The temperature to use for the generation. Defaults to 1.0. | 1.0 |
| top_p | Optional[float] | The top-p value to use for the generation. Defaults to 1.0. | 1.0 |
| stop | Optional[Union[str, list]] | Up to 4 sequences where the LLM API will stop generating further tokens. Defaults to None. | None |
| max_tokens | Optional[int] | The maximum number of tokens in the generated completion. Defaults to None. | None |
| presence_penalty | Optional[float] | Used to penalize new tokens based on their existence in the text so far. Defaults to None. | None |
| frequency_penalty | Optional[float] | Used to penalize new tokens based on their frequency in the text so far. Defaults to None. | None |
| logit_bias | Optional[dict] | Used to modify the probability of specific tokens appearing in the completion. Defaults to None. | None |
| user | Optional[str] | A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. Defaults to None. | None |
| metadata | Optional[dict] | Additional metadata to tag your completion calls, e.g. prompt version, details, etc. Defaults to None. | None |
| api_base | Optional[str] | Base URL for the API. Defaults to None. | None |
| api_version | Optional[str] | API version. Defaults to None. | None |
| api_key | Optional[str] | API key. Defaults to None. | None |
| model_list | Optional[list] | List of API base URLs, versions and keys. Defaults to None. | None |
| mock_response | Optional[str] | If provided, return a mock completion response for testing or debugging purposes. Defaults to None. | None |
| force_timeout | Optional[int] | The maximum execution time in seconds for the completion request. Defaults to 600. | 600 |
| custom_llm_provider | Optional[str] | Used for non-OpenAI LLMs; e.g. for Bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to None. | None |

Returns

| Type | Description |
| --- | --- |
| GenerateOutput | A list of lists of strings containing the generated responses for each input. |

Source code in src/distilabel/models/llms/litellm.py
@validate_call
async def agenerate(  # type: ignore # noqa: C901
    self,
    input: FormattedInput,
    num_generations: int = 1,
    functions: Optional[List] = None,
    function_call: Optional[str] = None,
    temperature: Optional[float] = 1.0,
    top_p: Optional[float] = 1.0,
    stop: Optional[Union[str, list]] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    logit_bias: Optional[dict] = None,
    user: Optional[str] = None,
    metadata: Optional[dict] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    api_key: Optional[str] = None,
    model_list: Optional[list] = None,
    mock_response: Optional[str] = None,
    force_timeout: Optional[int] = 600,
    custom_llm_provider: Optional[str] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        functions: a list of functions to apply to the conversation messages. Defaults to
            `None`.
        function_call: the name of the function to call within the conversation. Defaults
            to `None`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: Up to 4 sequences where the LLM API will stop generating further tokens.
            Defaults to `None`.
        max_tokens: The maximum number of tokens in the generated completion. Defaults to
            `None`.
        presence_penalty: It is used to penalize new tokens based on their existence in the
            text so far. Defaults to `None`.
        frequency_penalty: It is used to penalize new tokens based on their frequency in the
            text so far. Defaults to `None`.
        logit_bias: Used to modify the probability of specific tokens appearing in the
            completion. Defaults to `None`.
        user: A unique identifier representing your end-user. This can help the LLM provider
            to monitor and detect abuse. Defaults to `None`.
        metadata: Pass in additional metadata to tag your completion calls - eg. prompt
            version, details, etc. Defaults to `None`.
        api_base: Base URL for the API. Defaults to `None`.
        api_version: API version. Defaults to `None`.
        api_key: API key. Defaults to `None`.
        model_list: List of api base, version, keys. Defaults to `None`.
        mock_response: If provided, return a mock completion response for testing or debugging
            purposes. Defaults to `None`.
        force_timeout: The maximum execution time in seconds for the completion request.
            Defaults to `600`.
        custom_llm_provider: Used for Non-OpenAI LLMs, Example usage for bedrock, set(iterable)
            model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to
            `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    import litellm
    from litellm import token_counter

    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="litellm",
        )
        self._aclient = result.get("client").messages.create

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "model": self.model,
        "messages": input,
        "n": num_generations,
        "functions": functions,
        "function_call": function_call,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
        "stop": stop,
        "max_tokens": max_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "logit_bias": logit_bias,
        "user": user,
        "metadata": metadata,
        "api_base": api_base,
        "api_version": api_version,
        "api_key": api_key,
        "model_list": model_list,
        "mock_response": mock_response,
        "force_timeout": force_timeout,
        "custom_llm_provider": custom_llm_provider,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    async def _call_aclient_until_n_choices() -> List["Choices"]:
        choices = []
        while len(choices) < num_generations:
            completion: Union["ModelResponse", "BaseModel"] = await self._aclient(
                **kwargs
            )  # type: ignore
            if self.structured_output:
                # Prevent pydantic model from being cast to list during list extension
                completion = [completion]
            else:
                completion = completion.choices
            choices.extend(completion)
        return choices

    # litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used
    try:
        litellm.drop_params = False
        choices = await _call_aclient_until_n_choices()
    except litellm.exceptions.APIError as e:
        if "does not support parameters" in str(e):
            litellm.drop_params = True
            choices = await _call_aclient_until_n_choices()
        else:
            raise e

    generations = []
    input_tokens = [
        token_counter(model=self.model, messages=input)
    ] * num_generations
    output_tokens = []

    if self.structured_output:
        for choice in choices:
            generations.append(choice.model_dump_json())
            output_tokens.append(
                token_counter(
                    model=self.model,
                    text=orjson.dumps(choice.model_dump_json()).decode("utf-8"),
                )
            )
        return prepare_output(
            generations,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )

    for choice in choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using LiteLLM client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
        output_tokens.append(token_counter(model=self.model, text=content))

    return prepare_output(
        generations, input_tokens=input_tokens, output_tokens=output_tokens
    )
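Since `agenerate` is a coroutine, calling it directly (outside of a distilabel pipeline) requires an event loop. A minimal sketch, assuming the model is reachable through LiteLLM and the provider's API key is set in the environment:

```python
import asyncio

from distilabel.models.llms import LiteLLM

llm = LiteLLM(model="gpt-3.5-turbo")
llm.load()

async def main() -> None:
    # Parameters mirror the `agenerate` signature documented above.
    output = await llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        num_generations=1,
        temperature=0.7,
        max_tokens=256,
    )
    print(output)

asyncio.run(main())
```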

LlamaCppLLM

Bases: LLM, MagpieChatTemplateMixin

llama.cpp LLM implementation running the Python bindings for the C++ code.

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model_path | RuntimeParameter[FilePath] | Contains the path to the GGUF quantized model, compatible with the installed version of the llama.cpp Python bindings. |
| n_gpu_layers | RuntimeParameter[int] | The number of layers to use for the GPU. Defaults to -1, meaning that the available GPU device will be used. |
| chat_format | Optional[RuntimeParameter[str]] | The chat format to use for the model. Defaults to None, which means the Llama format will be used. |
| n_ctx | int | The context size to use for the model. Defaults to 512. |
| n_batch | int | The prompt processing maximum batch size to use for the model. Defaults to 512. |
| seed | int | Random seed to use for the generation. Defaults to 4294967295. |
| verbose | RuntimeParameter[bool] | Whether to print verbose output. Defaults to False. |
| structured_output | Optional[RuntimeParameter[OutlinesStructuredOutputType]] | A dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None. |
| extra_kwargs | Optional[RuntimeParameter[Dict[str, Any]]] | Additional dictionary of keyword arguments that will be passed to the Llama class of the llama_cpp library. Defaults to {}. |
| tokenizer_id | Optional[RuntimeParameter[str]] | The tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer config files. If not provided, the one associated to the model will be used. Defaults to None. |
| use_magpie_template | bool | A flag used to enable/disable applying the Magpie pre-query template. Defaults to False. |
| magpie_pre_query_template | Union[MagpieAvailablePreQueryTemplates, str, None] | The pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow up user message. Valid values are "llama3", "qwen2" or another pre-query template provided. Defaults to None. |
| _model | Optional[Llama] | The Llama model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method. |

Runtime parameters
  • model_path: the path to the GGUF quantized model.
  • n_gpu_layers: the number of layers to use for the GPU. Defaults to -1.
  • chat_format: the chat format to use for the model. Defaults to None.
  • verbose: whether to print verbose output. Defaults to False.
  • extra_kwargs: additional dictionary of keyword arguments that will be passed to the Llama class of the llama_cpp library. Defaults to {}.

References
  • llama.cpp: https://github.com/ggerganov/llama.cpp
  • llama-cpp-python: https://github.com/abetlen/llama-cpp-python

Examples

Generate text:

from pathlib import Path
from distilabel.models.llms import LlamaCppLLM

# You can follow along this example downloading the following model running the following
# command in the terminal, that will download the model to the `Downloads` folder:
# curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://hugging-face.cn/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf

model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

llm = LlamaCppLLM(
    model_path=str(Path.home() / model_path),
    n_gpu_layers=-1,  # To use the GPU if available
    n_ctx=1024,       # Set the context size
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pathlib import Path
from pydantic import BaseModel
from distilabel.models.llms import LlamaCppLLM

model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = LlamaCppLLM(
    model_path=str(Path.home() / model_path),  # type: ignore
    n_gpu_layers=-1,
    n_ctx=1024,
    structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
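The `extra_kwargs` runtime parameter forwards additional arguments straight to the `Llama` constructor. A hedged sketch follows; the `n_threads` argument comes from the `llama-cpp-python` API rather than from this page:

```python
from pathlib import Path
from distilabel.models.llms import LlamaCppLLM

llm = LlamaCppLLM(
    model_path=str(Path.home() / "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"),
    n_gpu_layers=-1,
    n_ctx=1024,
    # Forwarded verbatim to `llama_cpp.Llama(...)`.
    extra_kwargs={"n_threads": 8},
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```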
Source code in src/distilabel/models/llms/llamacpp.py
class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
    """llama.cpp LLM implementation running the Python bindings for the C++ code.

    Attributes:
        model_path: contains the path to the GGUF quantized model, compatible with the
            installed version of the `llama.cpp` Python bindings.
        n_gpu_layers: the number of layers to use for the GPU. Defaults to `-1`, meaning that
            the available GPU device will be used.
        chat_format: the chat format to use for the model. Defaults to `None`, which means the
            Llama format will be used.
        n_ctx: the context size to use for the model. Defaults to `512`.
        n_batch: the prompt processing maximum batch size to use for the model. Defaults to `512`.
        seed: random seed to use for the generation. Defaults to `4294967295`.
        verbose: whether to print verbose output. Defaults to `False`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.
        tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.
        _model: the Llama model instance. This attribute is meant to be used internally and
            should not be accessed directly. It will be set in the `load` method.

    Runtime parameters:
        - `model_path`: the path to the GGUF quantized model.
        - `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`.
        - `chat_format`: the chat format to use for the model. Defaults to `None`.
        - `verbose`: whether to print verbose output. Defaults to `False`.
        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.

    References:
        - [`llama.cpp`](https://github.com/ggerganov/llama.cpp)
        - [`llama-cpp-python`](https://github.com/abetlen/llama-cpp-python)

    Examples:
        Generate text:

        ```python
        from pathlib import Path
        from distilabel.models.llms import LlamaCppLLM

        # You can follow along this example downloading the following model running the following
        # command in the terminal, that will download the model to the `Downloads` folder:
        # curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://hugging-face.cn/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf

        model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

        llm = LlamaCppLLM(
            model_path=str(Path.home() / model_path),
            n_gpu_layers=-1,  # To use the GPU if available
            n_ctx=1024,       # Set the context size
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pathlib import Path
        from distilabel.models.llms import LlamaCppLLM

        model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = LlamaCppLLM(
            model_path=str(Path.home() / model_path),  # type: ignore
            n_gpu_layers=-1,
            n_ctx=1024,
            structured_output={"format": "json", "schema": Character},
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model_path: RuntimeParameter[FilePath] = Field(
        default=None, description="The path to the GGUF quantized model.", exclude=True
    )
    n_gpu_layers: RuntimeParameter[int] = Field(
        default=-1,
        description="The number of layers that will be loaded in the GPU.",
    )
    chat_format: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The chat format to use for the model. Defaults to `None`, which means the Llama format will be used.",
    )

    n_ctx: int = 512
    n_batch: int = 512
    seed: int = 4294967295
    verbose: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to print verbose output from llama.cpp library.",
    )
    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `Llama` class of `llama_cpp` library. See all the supported arguments at: "
        "https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__",
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )
    tokenizer_id: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The Hugging Face Hub repo id or a path to a directory containing"
        " the tokenizer config files. If not provided, the one associated to the `model`"
        " will be used.",
    )
    _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
    _model: Optional["Llama"] = PrivateAttr(...)

    @model_validator(mode="after")
    def validate_magpie_usage(
        self,
    ) -> "LlamaCppLLM":
        """Validates that magpie usage is valid."""

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

    def load(self) -> None:
        """Loads the `Llama` model from the `model_path`."""
        try:
            from llama_cpp import Llama
        except ImportError as ie:
            raise ImportError(
                "The `llama_cpp` package is required to use the `LlamaCppLLM` class."
            ) from ie

        self._model = Llama(
            model_path=self.model_path.as_posix(),
            seed=self.seed,
            n_ctx=self.n_ctx,
            n_batch=self.n_batch,
            chat_format=self.chat_format,
            n_gpu_layers=self.n_gpu_layers,
            verbose=self.verbose,
            **self.extra_kwargs,
        )

        if self.structured_output:
            self._logits_processor = self._prepare_structured_output(
                self.structured_output
            )

        if self.use_magpie_template or self.magpie_pre_query_template:
            if not self.tokenizer_id:
                raise ValueError(
                    "The Hugging Face Hub repo id or a path to a directory containing"
                    " the tokenizer config files is required when using the `use_magpie_template`"
                    " or `magpie_pre_query_template` runtime parameters."
                )

        if self.tokenizer_id:
            try:
                from transformers import AutoTokenizer
            except ImportError as ie:
                raise ImportError(
                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                ) from ie
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
            if self._tokenizer.chat_template is None:
                raise ValueError(
                    "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
                )

        # NOTE: Here because of the custom `logging` interface used, since it will create the logging name
        # out of the model name, which won't be available until the `Llama` instance is created.
        super().load()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self._model.model_path  # type: ignore

    def _generate_chat_completion(
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "CreateChatCompletionResponse":
        return self._model.create_chat_completion(  # type: ignore
            messages=input,  # type: ignore
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            logits_processor=self._logits_processor,
            **(extra_generation_kwargs or {}),
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _generate_with_text_generation(
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "CreateChatCompletionResponse":
        prompt = self.prepare_input(input)
        return self._model.create_completion(
            prompt=prompt,
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            logits_processor=self._logits_processor,
            **(extra_generation_kwargs or {}),
        )

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[FormattedInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for the given input using the Llama model.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            extra_generation_kwargs: dictionary with additional arguments to be passed to
                the `create_chat_completion` method. Reference at
                https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        structured_output = None
        batch_outputs = []
        for input in inputs:
            if isinstance(input, tuple):
                input, structured_output = input
            elif self.structured_output:
                structured_output = self.structured_output

            outputs = []
            output_tokens = []
            for _ in range(num_generations):
                # NOTE(plaguss): There seems to be a bug in how the logits processor
                # is used. Basically it consumes the FSM internally, and it isn't reinitialized
                # after each generation, so subsequent calls yield nothing. This is a workaround
                # until is fixed in the `llama_cpp` or `outlines` libraries.
                if structured_output:
                    self._logits_processor = self._prepare_structured_output(
                        structured_output
                    )
                if self.tokenizer_id is None:
                    completion = self._generate_chat_completion(
                        input,
                        max_new_tokens,
                        frequency_penalty,
                        presence_penalty,
                        temperature,
                        top_p,
                        extra_generation_kwargs,
                    )
                    outputs.append(completion["choices"][0]["message"]["content"])
                    output_tokens.append(completion["usage"]["completion_tokens"])
                else:
                    completion: "CreateChatCompletionResponse" = (
                        self._generate_with_text_generation(  # type: ignore
                            input,
                            max_new_tokens,
                            frequency_penalty,
                            presence_penalty,
                            temperature,
                            top_p,
                            extra_generation_kwargs,
                        )
                    )
                    outputs.append(completion["choices"][0]["text"])
                    output_tokens.append(completion["usage"]["completion_tokens"])
            batch_outputs.append(
                prepare_output(
                    outputs,
                    input_tokens=[completion["usage"]["prompt_tokens"]]
                    * num_generations,
                    output_tokens=output_tokens,
                )
            )

        return batch_outputs

    def _prepare_structured_output(
        self, structured_output: Optional[OutlinesStructuredOutputType] = None
    ) -> Union["LogitsProcessorList", "LogitsProcessor"]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        result = prepare_guided_output(structured_output, "llamacpp", self._model)
        if (schema := result.get("schema")) and self.structured_output:
            self.structured_output["schema"] = schema
        return [result["processor"]]
model_name property

Returns the model name used for the LLM.

validate_magpie_usage()

Validates that the Magpie usage is valid.

Source code in src/distilabel/models/llms/llamacpp.py
@model_validator(mode="after")
def validate_magpie_usage(
    self,
) -> "LlamaCppLLM":
    """Validates that magpie usage is valid."""

    if self.use_magpie_template and self.tokenizer_id is None:
        raise ValueError(
            "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
            " set a `tokenizer_id` and try again."
        )
load()

Loads the Llama model from the model_path.

Source code in src/distilabel/models/llms/llamacpp.py
def load(self) -> None:
    """Loads the `Llama` model from the `model_path`."""
    try:
        from llama_cpp import Llama
    except ImportError as ie:
        raise ImportError(
            "The `llama_cpp` package is required to use the `LlamaCppLLM` class."
        ) from ie

    self._model = Llama(
        model_path=self.model_path.as_posix(),
        seed=self.seed,
        n_ctx=self.n_ctx,
        n_batch=self.n_batch,
        chat_format=self.chat_format,
        n_gpu_layers=self.n_gpu_layers,
        verbose=self.verbose,
        **self.extra_kwargs,
    )

    if self.structured_output:
        self._logits_processor = self._prepare_structured_output(
            self.structured_output
        )

    if self.use_magpie_template or self.magpie_pre_query_template:
        if not self.tokenizer_id:
            raise ValueError(
                "The Hugging Face Hub repo id or a path to a directory containing"
                " the tokenizer config files is required when using the `use_magpie_template`"
                " or `magpie_pre_query_template` runtime parameters."
            )

    if self.tokenizer_id:
        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
        if self._tokenizer.chat_template is None:
            raise ValueError(
                "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
            )

    # NOTE: Here because of the custom `logging` interface used, since it will create the logging name
    # out of the model name, which won't be available until the `Llama` instance is created.
    super().load()
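
下面是一个最小配置示意(并非官方示例),展示在启用 Magpie 预查询模板时,按上面 load() 的要求同时提供 tokenizer_id;其中的 GGUF 文件路径与 tokenizer 仓库 id 均为假设的占位值:

from distilabel.models.llms import LlamaCppLLM

llm = LlamaCppLLM(
    model_path="./models/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",  # 假设的本地 GGUF 文件路径
    n_gpu_layers=-1,
    tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",  # 提供聊天模板的 tokenizer
    use_magpie_template=True,
    magpie_pre_query_template="llama3",
)

llm.load()
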
prepare_input(input)

为提供的输入准备输入(应用聊天模板和分词)。

参数

名称 类型 描述 默认值
input StandardInput

包含聊天项的输入列表。

必需

返回

类型 描述
str

要发送给 LLM 的提示。

源代码位于 src/distilabel/models/llms/llamacpp.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            conversation=input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
generate(inputs, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, extra_generation_kwargs=None)

使用 Llama 模型为给定输入生成 num_generations 个响应。

参数

名称 类型 描述 默认值
inputs List[FormattedInput]

聊天格式的输入列表,用于生成响应。

必需
num_generations int

每个输入要创建的代数。默认为 1

1
max_new_tokens int

模型将生成的最大新 token 数。默认为 128

128
frequency_penalty float

用于生成的重复惩罚。默认为 0.0

0.0
presence_penalty float

用于存在的惩罚。默认为 0.0

0.0
temperature float

用于生成的温度。默认为 1.0

1.0
top_p float

用于生成的 top-p 值。默认为 1.0

1.0
extra_generation_kwargs Optional[Dict[str, Any]]

包含要传递给 create_chat_completion 方法的其他参数的字典。参考 https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion

None

返回

类型 描述
List[GenerateOutput]

包含每个输入的生成响应的字符串列表的列表。

源代码位于 src/distilabel/models/llms/llamacpp.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[FormattedInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    extra_generation_kwargs: Optional[Dict[str, Any]] = None,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for the given input using the Llama model.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        extra_generation_kwargs: dictionary with additional arguments to be passed to
            the `create_chat_completion` method. Reference at
            https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    structured_output = None
    batch_outputs = []
    for input in inputs:
        if isinstance(input, tuple):
            input, structured_output = input
        elif self.structured_output:
            structured_output = self.structured_output

        outputs = []
        output_tokens = []
        for _ in range(num_generations):
            # NOTE(plaguss): There seems to be a bug in how the logits processor
            # is used. Basically it consumes the FSM internally, and it isn't reinitialized
            # after each generation, so subsequent calls yield nothing. This is a workaround
            # until is fixed in the `llama_cpp` or `outlines` libraries.
            if structured_output:
                self._logits_processor = self._prepare_structured_output(
                    structured_output
                )
            if self.tokenizer_id is None:
                completion = self._generate_chat_completion(
                    input,
                    max_new_tokens,
                    frequency_penalty,
                    presence_penalty,
                    temperature,
                    top_p,
                    extra_generation_kwargs,
                )
                outputs.append(completion["choices"][0]["message"]["content"])
                output_tokens.append(completion["usage"]["completion_tokens"])
            else:
                completion: "CreateChatCompletionResponse" = (
                    self._generate_with_text_generation(  # type: ignore
                        input,
                        max_new_tokens,
                        frequency_penalty,
                        presence_penalty,
                        temperature,
                        top_p,
                        extra_generation_kwargs,
                    )
                )
                outputs.append(completion["choices"][0]["text"])
                output_tokens.append(completion["usage"]["completion_tokens"])
        batch_outputs.append(
            prepare_output(
                outputs,
                input_tokens=[completion["usage"]["prompt_tokens"]]
                * num_generations,
                output_tokens=output_tokens,
            )
        )

    return batch_outputs
_prepare_structured_output(structured_output=None)

创建适当的函数来过滤 token,以生成结构化输出。

参数

名称 类型 描述 默认值
structured_output Optional[OutlinesStructuredOutputType]

配置字典,用于准备结构化输出。

None

返回

类型 描述
Union[LogitsProcessorList, LogitsProcessor]

将用于指导模型生成的 callable。

源代码位于 src/distilabel/models/llms/llamacpp.py
def _prepare_structured_output(
    self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union["LogitsProcessorList", "LogitsProcessor"]:
    """Creates the appropriate function to filter tokens to generate structured outputs.

    Args:
        structured_output: the configuration dict to prepare the structured output.

    Returns:
        The callable that will be used to guide the generation of the model.
    """
    from distilabel.steps.tasks.structured_outputs.outlines import (
        prepare_guided_output,
    )

    result = prepare_guided_output(structured_output, "llamacpp", self._model)
    if (schema := result.get("schema")) and self.structured_output:
        self.structured_output["schema"] = schema
    return [result["processor"]]

MistralLLM

基类: AsyncLLM

运行异步 API 客户端的 Mistral LLM 实现。

属性

名称 类型 描述
model str

用于 LLM 的模型名称,例如 "mistral-tiny"、"mistral-large" 等。

endpoint str

用于 Mistral API 的端点。默认为 "https://api.mistral.ai"。

api_key 可选[RuntimeParameter[SecretStr]]

用于验证对 Mistral API 请求的 API 密钥。默认为 None,表示将使用为 Mistral API 密钥环境变量(即代码中 _MISTRALAI_API_KEY_ENV_VAR_NAME 指向的变量)设置的值;如果未设置,则为 None。

max_retries RuntimeParameter[int]

请求失败时尝试的最大重试次数。默认为 6。

timeout RuntimeParameter[int]

等待响应的最长时间(秒)。默认为 120

max_concurrent_requests RuntimeParameter[int]

要发送的最大并发请求数。默认为 64

structured_output 可选[RuntimeParameter[InstructorStructuredOutputType]]

一个字典,包含使用 instructor 的结构化输出配置。您可以在 distilabel.steps.tasks.structured_outputs.instructorInstructorStructuredOutputType 中查看字典结构。

_api_key_env_var str

用于 API 密钥的环境变量名称。它旨在内部使用。

_aclient Optional[Mistral]

用于 Mistral API 的 Mistral 客户端。它旨在内部使用。在 load 方法中设置。

运行时参数
  • api_key:用于验证对 Mistral API 请求的 API 密钥。
  • max_retries:请求失败时尝试的最大重试次数。默认为 6。
  • timeout:等待响应的最长时间(秒)。默认为 120。
  • max_concurrent_requests:要发送的最大并发请求数。默认为 64。

示例

生成文本

from distilabel.models.llms import MistralLLM

llm = MistralLLM(model="open-mixtral-8x22b")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import MistralLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = MistralLLM(
    model="open-mixtral-8x22b",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
源代码位于 src/distilabel/models/llms/mistral.py
class MistralLLM(AsyncLLM):
    """Mistral LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.
        endpoint: the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".
        api_key: the API key to authenticate the requests to the Mistral API. Defaults to `None` which
            means that the value set for the environment variable `OPENAI_API_KEY` will be used, or
            `None` if not set.
        max_retries: the maximum number of retries to attempt when a request fails. Defaults to `5`.
        timeout: the maximum time in seconds to wait for a response. Defaults to `120`.
        max_concurrent_requests: the maximum number of concurrent requests to send. Defaults
            to `64`.
        structured_output: a dictionary containing the structured output configuration configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key. It is meant to
            be used internally.
        _aclient: the `Mistral` to use for the Mistral API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Mistral API.
        - `max_retries`: the maximum number of retries to attempt when a request fails.
            Defaults to `5`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`.
        - `max_concurrent_requests`: the maximum number of concurrent requests to send.
            Defaults to `64`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import MistralLLM

        llm = MistralLLM(model="open-mixtral-8x22b")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import MistralLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = MistralLLM(
            model="open-mixtral-8x22b",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    endpoint: str = "https://api.mistral.ai"
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_MISTRALAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Mistral API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_concurrent_requests: RuntimeParameter[int] = Field(
        default=64, description="The maximum number of concurrent requests to send."
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["Mistral"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `Mistral` client to benefit from async requests."""
        super().load()

        try:
            from mistralai import Mistral
        except ImportError as ie:
            raise ImportError(
                "MistralAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[mistralai]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = Mistral(
            api_key=self.api_key.get_secret_value(),
            endpoint=self.endpoint,
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,  # type: ignore
            max_concurrent_requests=self.max_concurrent_requests,  # type: ignore
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="mistral",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    # TODO: add `num_generations` parameter once Mistral client allows `n` parameter
    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the MistralAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="mistral",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        generations = []
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)
            # TODO: This should work just with the _aclient.chat method, but it's not working.
            # We need to check instructor and see if we can create a PR.
            completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
        else:
            # completion = await self._aclient.chat(**kwargs)  # type: ignore
            completion = await self._aclient.chat.complete_async(**kwargs)  # type: ignore

        if structured_output:
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using MistralAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)

        return prepare_output(generations, **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "ChatCompletionResponse") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.prompt_tokens],
            "output_tokens": [completion.usage.completion_tokens],
        }
model_name 属性

返回用于 LLM 的模型名称。

load()

加载 Mistral 客户端以利用异步请求的优势。

源代码位于 src/distilabel/models/llms/mistral.py
def load(self) -> None:
    """Loads the `Mistral` client to benefit from async requests."""
    super().load()

    try:
        from mistralai import Mistral
    except ImportError as ie:
        raise ImportError(
            "MistralAI Python client is not installed. Please install it using"
            " `pip install 'distilabel[mistralai]'`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = Mistral(
        api_key=self.api_key.get_secret_value(),
        endpoint=self.endpoint,
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,  # type: ignore
        max_concurrent_requests=self.max_concurrent_requests,  # type: ignore
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="mistral",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
agenerate(input, max_new_tokens=None, temperature=None, top_p=None) 异步

为给定输入生成 num_generations 个响应,使用 MistralAI 异步客户端。

参数

名称 类型 描述 默认值
input FormattedInput

以聊天格式的单个输入,用于生成响应。

必需
max_new_tokens Optional[int]

模型将生成的最大新 token 数。默认为 128

None
temperature Optional[float]

用于生成的温度。默认为 0.1

None
top_p Optional[float]

用于生成的 top-p 值。默认为 1.0

None

返回

类型 描述
GenerateOutput

包含每个输入的生成响应的字符串列表的列表。

源代码位于 src/distilabel/models/llms/mistral.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the MistralAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="mistral",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    generations = []
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)
        # TODO: This should work just with the _aclient.chat method, but it's not working.
        # We need to check instructor and see if we can create a PR.
        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
    else:
        # completion = await self._aclient.chat(**kwargs)  # type: ignore
        completion = await self._aclient.chat.complete_async(**kwargs)  # type: ignore

    if structured_output:
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using MistralAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)

    return prepare_output(generations, **self._get_llm_statistics(completion))
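
下面是一个最小调用示意(并非官方示例),演示如何在调用时传入与上面 agenerate 签名一致的生成参数;模型名称仅作占位值:

from distilabel.models.llms import MistralLLM

llm = MistralLLM(model="mistral-large-latest", api_key="api.key")  # 模型名称为假设值

llm.load()

# 生成参数会透传给 `agenerate`
output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
)
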

MlxLLM

基类: LLM、MagpieChatTemplateMixin

Apple MLX LLM 实现。

属性

名称 类型 描述
path_or_hf_repo str

模型路径或 Hugging Face Hub 仓库 ID。

tokenizer_config Dict[str, Any]

tokenizer 配置。

mlx_model_config Dict[str, Any]

MLX 模型配置。

adapter_path Optional[str]

adapter 的路径。

use_magpie_template bool

用于启用/禁用应用 Magpie 预查询模板的标志。默认为 False

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

要应用于提示或发送到 LLM 以生成指令或后续用户消息的预查询模板。有效值为 "llama3"、"qwen2" 或提供的另一个预查询模板。默认为 None

图标

:apple:

示例

生成文本

from distilabel.models.llms import MlxLLM

llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
源代码位于 src/distilabel/models/llms/mlx.py
class MlxLLM(LLM, MagpieChatTemplateMixin):
    """Apple MLX LLM implementation.

    Attributes:
        path_or_hf_repo: the path to the model or the Hugging Face Hub repo id.
        tokenizer_config: the tokenizer configuration.
        mlx_model_config: the MLX model configuration.
        adapter_path: the path to the adapter.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.

    Icon:
        `:apple:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import MlxLLM

        llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    path_or_hf_repo: str
    tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
    mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
    adapter_path: Optional[str] = None

    _model: Optional["nn.Module"] = PrivateAttr(None)
    _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
    _mlx_generate: Optional[Callable] = PrivateAttr(None)
    _make_sampler: Optional[Callable] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the model and tokenizer and creates the text generation pipeline. In addition,
        it will configure the tokenizer chat template."""
        try:
            import mlx  # noqa
            from mlx_lm.utils import generate, load
            from mlx_lm.sample_utils import make_sampler
        except ImportError as ie:
            raise ImportError(
                "MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
            ) from ie

        self._model, self._tokenizer = load(
            self.path_or_hf_repo,
            tokenizer_config=self.tokenizer_config,
            model_config=self.mlx_model_config,
            adapter_path=self.adapter_path,
        )

        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._mlx_generate = generate
        self._make_sampler = make_sampler
        super().load()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.path_or_hf_repo

    def prepare_input(self, input: Union["StandardInput", str]) -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if isinstance(input, str):
            return input

        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                input,
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[Union[StandardInput, str]],
        num_generations: int = 1,
        max_tokens: int = 256,
        logits_processors: Optional[List[Callable]] = None,
        max_kv_size: Optional[int] = None,
        prompt_cache: Optional[Any] = None,
        prefill_step_size: int = 512,
        kv_bits: Optional[int] = None,
        kv_group_size: int = 64,
        quantized_kv_start: int = 0,
        prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
        temp: float = 0.0,
        top_p: float = 0.0,
        min_p: float = 0.0,
        min_tokens_to_keep: int = 1,
        top_k: int = -1,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: the inputs to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            logits_processors: the logits processors to use for the generation. Defaults to
                `None`.
            max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
            prompt_cache: the prompt cache to use for the generation. Defaults to `None`.
            prefill_step_size: the prefill step size. Defaults to `512`.
            kv_bits: the number of bits to use for the key-value cache. Defaults to `None`.
            kv_group_size: the group size for the key-value cache. Defaults to `64`.
            quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`.
            prompt_progress_callback: the callback to use for the generation. Defaults to
                `None`.
            temp: The temperature for text generation. Defaults to `0.0`.
            top_p: The top-p value used for the generation. Defaults to `0.0`.
            min_p: The min-p value used for the generation. Defaults to `0.0`.
            min_tokens_to_keep: Minimum number of tokens to keep for sampling after
                filtering. Must be at least 1. Defaults to `1`.
            top_k: The top-k value used for the generation. Defaults to `-1`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """

        sampler = self._make_sampler(  # type: ignore
            temp=temp,
            top_p=top_p,
            min_p=min_p,
            min_tokens_to_keep=min_tokens_to_keep,
            top_k=top_k,
        )
        structured_output = None
        result = []
        for input in inputs:
            if isinstance(input, tuple):
                input, structured_output = input

            output: List[str] = []
            for _ in range(num_generations):
                if structured_output:  # will raise a NotImplementedError
                    self._prepare_structured_output(structured_output)
                prompt = self.prepare_input(input)
                generation = self._mlx_generate(  # type: ignore
                    prompt=prompt,
                    model=self._model,
                    tokenizer=self._tokenizer,
                    logits_processors=logits_processors,
                    max_tokens=max_tokens,
                    sampler=sampler,
                    max_kv_size=max_kv_size,
                    prompt_cache=prompt_cache,
                    prefill_step_size=prefill_step_size,
                    kv_bits=kv_bits,
                    kv_group_size=kv_group_size,
                    quantized_kv_start=quantized_kv_start,
                    prompt_progress_callback=prompt_progress_callback,
                )

                output.append(generation)

            result.append(
                prepare_output(
                    generations=output,
                    input_tokens=[compute_tokens(input, self._tokenizer.encode)],  # type: ignore
                    output_tokens=[
                        compute_tokens(
                            text_or_messages=generation,
                            tokenizer=self._tokenizer.encode,  # type: ignore
                        )
                        for generation in output
                    ],
                )
            )
        return result
model_name 属性

返回用于 LLM 的模型名称。

load()

加载模型和 tokenizer,并创建文本生成 pipeline。 此外,它还将配置 tokenizer 聊天模板。

源代码位于 src/distilabel/models/llms/mlx.py
def load(self) -> None:
    """Loads the model and tokenizer and creates the text generation pipeline. In addition,
    it will configure the tokenizer chat template."""
    try:
        import mlx  # noqa
        from mlx_lm.utils import generate, load
        from mlx_lm.sample_utils import make_sampler
    except ImportError as ie:
        raise ImportError(
            "MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
        ) from ie

    self._model, self._tokenizer = load(
        self.path_or_hf_repo,
        tokenizer_config=self.tokenizer_config,
        model_config=self.mlx_model_config,
        adapter_path=self.adapter_path,
    )

    if self._tokenizer.pad_token is None:
        self._tokenizer.pad_token = self._tokenizer.eos_token

    self._mlx_generate = generate
    self._make_sampler = make_sampler
    super().load()
prepare_input(input)

为提供的输入准备输入(应用聊天模板和分词)。

参数

名称 类型 描述 默认值
input Union[StandardInput, str]

包含聊天项的输入列表。

必需

返回

类型 描述
str

要发送给 LLM 的提示。

源代码位于 src/distilabel/models/llms/mlx.py
def prepare_input(self, input: Union["StandardInput", str]) -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if isinstance(input, str):
        return input

    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            input,
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
generate(inputs, num_generations=1, max_tokens=256, logits_processors=None, max_kv_size=None, prompt_cache=None, prefill_step_size=512, kv_bits=None, kv_group_size=64, quantized_kv_start=0, prompt_progress_callback=None, temp=0.0, top_p=0.0, min_p=0.0, min_tokens_to_keep=1, top_k=-1)

使用文本生成 pipeline 为每个输入生成 num_generations 个响应。

参数

名称 类型 描述 默认值
inputs List[Union[StandardInput, str]]

要为其生成响应的输入。

必需
num_generations int

每个输入要创建的代数。默认为 1

1
max_tokens int

模型将生成的最大新 token 数。默认为 256。

256
logits_processors Optional[List[Callable]]

用于生成的 logits 处理器。默认为 `None`。

None
max_kv_size Optional[int]

键值缓存的最大大小。默认为 `None`。

None
prompt_cache Optional[Any]

用于生成的 prompt 缓存。默认为 `None`。

None
prefill_step_size int

预填充步长。默认为 `512`。

512
kv_bits Optional[int]

用于键值缓存的位数。默认为 `None`。

None
kv_group_size int

键值缓存的组大小。默认为 `64`。

64
quantized_kv_start int

量化键值缓存的起始位置。默认为 `0`。

0
prompt_progress_callback Optional[Callable[[int, int], None]]

用于生成的回调。默认为 `None`。

None
temp float

文本生成的温度。默认为 `0.0`。

0.0
top_p float

用于生成的 top-p 值。默认为 `0.0`。

0.0
min_p float

用于生成的 min-p 值。默认为 `0.0`。

0.0
min_tokens_to_keep int

过滤后用于采样的最小 token 数。必须至少为 1。默认为 `1`。

1
top_k int

用于生成的 top-k 值。默认为 `-1`。

-1

返回

类型 描述
List[GenerateOutput]

包含每个输入的生成响应的字符串列表的列表。

源代码位于 src/distilabel/models/llms/mlx.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[Union[StandardInput, str]],
    num_generations: int = 1,
    max_tokens: int = 256,
    logits_processors: Optional[List[Callable]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    temp: float = 0.0,
    top_p: float = 0.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    top_k: int = -1,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: the inputs to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        logits_processors: the logits processors to use for the generation. Defaults to
            `None`.
        max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
        prompt_cache: the prompt cache to use for the generation. Defaults to `None`.
        prefill_step_size: the prefill step size. Defaults to `512`.
        kv_bits: the number of bits to use for the key-value cache. Defaults to `None`.
        kv_group_size: the group size for the key-value cache. Defaults to `64`.
        quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`.
        prompt_progress_callback: the callback to use for the generation. Defaults to
            `None`.
        temp: The temperature for text generation. Defaults to `0.0`.
        top_p: The top-p value used for the generation. Defaults to `0.0`.
        min_p: The min-p value used for the generation. Defaults to `0.0`.
        min_tokens_to_keep: Minimum number of tokens to keep for sampling after
            filtering. Must be at least 1. Defaults to `1`.
        top_k: The top-k value used for the generation. Defaults to `-1`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """

    sampler = self._make_sampler(  # type: ignore
        temp=temp,
        top_p=top_p,
        min_p=min_p,
        min_tokens_to_keep=min_tokens_to_keep,
        top_k=top_k,
    )
    structured_output = None
    result = []
    for input in inputs:
        if isinstance(input, tuple):
            input, structured_output = input

        output: List[str] = []
        for _ in range(num_generations):
            if structured_output:  # will raise a NotImplementedError
                self._prepare_structured_output(structured_output)
            prompt = self.prepare_input(input)
            generation = self._mlx_generate(  # type: ignore
                prompt=prompt,
                model=self._model,
                tokenizer=self._tokenizer,
                logits_processors=logits_processors,
                max_tokens=max_tokens,
                sampler=sampler,
                max_kv_size=max_kv_size,
                prompt_cache=prompt_cache,
                prefill_step_size=prefill_step_size,
                kv_bits=kv_bits,
                kv_group_size=kv_group_size,
                quantized_kv_start=quantized_kv_start,
                prompt_progress_callback=prompt_progress_callback,
            )

            output.append(generation)

        result.append(
            prepare_output(
                generations=output,
                input_tokens=[compute_tokens(input, self._tokenizer.encode)],  # type: ignore
                output_tokens=[
                    compute_tokens(
                        text_or_messages=generation,
                        tokenizer=self._tokenizer.encode,  # type: ignore
                    )
                    for generation in output
                ],
            )
        )
    return result
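
下面是一个最小示意(并非官方示例),演示如何通过上述 generate 签名中的采样参数控制生成;具体数值仅作占位:

from distilabel.models.llms import MlxLLM

llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

llm.load()

output = llm.generate(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    num_generations=2,
    max_tokens=256,
    temp=0.7,   # 采样温度,会传给 make_sampler
    top_p=0.9,
    top_k=40,
)
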

MixtureOfAgentsLLM

基类: AsyncLLM

Mixture-of-Agents 实现。

一个 LLM 类,它利用 LLM 的集体优势来生成响应,如 “Mixture-of-Agents Enhances Large Language Model Capabilities” 论文中所述。其中有一系列 LLM 提出/生成输出,下一轮/层的 LLM 可以将其用作辅助信息。最后,有一个 LLM 聚合输出以生成最终响应。

属性

名称 类型 描述
aggregator_llm LLM

聚合提议者 LLM 输出的 LLM

proposers_llms List[AsyncLLM]

提出要聚合的输出的 LLM 列表。

rounds int

`proposers_llms` 生成输出的层数或轮数。默认为 `1`。

参考
  • Mixture-of-Agents Enhances Large Language Model Capabilities: https://arxiv.org/abs/2406.04692

示例

生成文本

from distilabel.models.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM

llm = MixtureOfAgentsLLM(
    aggregator_llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    ),
    proposers_llms=[
        InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        ),
        InferenceEndpointsLLM(
            model_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            tokenizer_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
        ),
        InferenceEndpointsLLM(
            model_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
            tokenizer_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
        ),
    ],
    rounds=2,
)

llm.load()

output = llm.generate_outputs(
    inputs=[
        [
            {
                "role": "user",
                "content": "My favorite witty review of The Rings of Power series is this: Input:",
            }
        ]
    ]
)
源代码位于 src/distilabel/models/llms/moa.py
class MixtureOfAgentsLLM(AsyncLLM):
    """`Mixture-of-Agents` implementation.

    An `LLM` class that leverages `LLM`s collective strenghts to generate a response,
    as described in the "Mixture-of-Agents Enhances Large Language model Capabilities"
    paper. There is a list of `LLM`s proposing/generating outputs that `LLM`s from the next
    round/layer can use as auxiliary information. Finally, there is an `LLM` that aggregates
    the outputs to generate the final response.

    Attributes:
        aggregator_llm: The `LLM` that aggregates the outputs of the proposer `LLM`s.
        proposers_llms: The list of `LLM`s that propose outputs to be aggregated.
        rounds: The number of layers or rounds that the `proposers_llms` will generate
            outputs. Defaults to `1`.

    References:
        - [Mixture-of-Agents Enhances Large Language Model Capabilities](https://arxiv.org/abs/2406.04692)

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM

        llm = MixtureOfAgentsLLM(
            aggregator_llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            proposers_llms=[
                InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
                ),
                InferenceEndpointsLLM(
                    model_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    tokenizer_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                ),
                InferenceEndpointsLLM(
                    model_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
                    tokenizer_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
                ),
            ],
            rounds=2,
        )

        llm.load()

        output = llm.generate_outputs(
            inputs=[
                [
                    {
                        "role": "user",
                        "content": "My favorite witty review of The Rings of Power series is this: Input:",
                    }
                ]
            ]
        )
        ```
    """

    aggregator_llm: LLM
    proposers_llms: List[AsyncLLM] = Field(default_factory=list)
    rounds: int = 1

    @property
    def runtime_parameters_names(self) -> "RuntimeParametersNames":
        """Returns the runtime parameters of the `LLM`, which are a combination of the
        `RuntimeParameter`s of the `LLM`, the `aggregator_llm` and the `proposers_llms`.

        Returns:
            The runtime parameters of the `LLM`.
        """
        runtime_parameters_names = super().runtime_parameters_names
        del runtime_parameters_names["generation_kwargs"]
        return runtime_parameters_names

    def load(self) -> None:
        """Loads all the `LLM`s in the `MixtureOfAgents`."""
        super().load()

        for llm in self.proposers_llms:
            self._logger.debug(f"Loading proposer LLM in MoA: {llm}")  # type: ignore
            llm.load()

        self._logger.debug(f"Loading aggregator LLM in MoA: {self.aggregator_llm}")  # type: ignore
        self.aggregator_llm.load()

    @property
    def model_name(self) -> str:
        """Returns the aggregated model name."""
        return f"moa-{self.aggregator_llm.model_name}-{'-'.join([llm.model_name for llm in self.proposers_llms])}"

    def get_generation_kwargs(self) -> Dict[str, Any]:
        """Returns the generation kwargs of the `MixtureOfAgents` as a dictionary.

        Returns:
            The generation kwargs of the `MixtureOfAgents`.
        """
        return {
            "aggregator_llm": self.aggregator_llm.get_generation_kwargs(),
            "proposers_llms": [
                llm.get_generation_kwargs() for llm in self.proposers_llms
            ],
        }

    # `abstractmethod`, had to be implemented but not used
    async def agenerate(
        self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
    ) -> List[Union[str, None]]:
        raise NotImplementedError(
            "`agenerate` method is not implemented for `MixtureOfAgents`"
        )

    def _build_moa_system_prompt(self, prev_outputs: List[str]) -> str:
        """Builds the Mixture-of-Agents system prompt.

        Args:
            prev_outputs: The list of previous outputs to use as references.

        Returns:
            The Mixture-of-Agents system prompt.
        """
        moa_system_prompt = MOA_SYSTEM_PROMPT
        for i, prev_output in enumerate(prev_outputs):
            if prev_output is not None:
                moa_system_prompt += f"\n{i + 1}. {prev_output}"
        return moa_system_prompt

    def _inject_moa_system_prompt(
        self, input: "StandardInput", prev_outputs: List[str]
    ) -> "StandardInput":
        """Injects the Mixture-of-Agents system prompt into the input.

        Args:
            input: The input to inject the system prompt into.
            prev_outputs: The list of previous outputs to use as references.

        Returns:
            The input with the Mixture-of-Agents system prompt injected.
        """
        if len(prev_outputs) == 0:
            return input

        moa_system_prompt = self._build_moa_system_prompt(prev_outputs)

        system = next((item for item in input if item["role"] == "system"), None)
        if system:
            original_system_prompt = system["content"]
            system["content"] = f"{moa_system_prompt}\n\n{original_system_prompt}"
        else:
            input.insert(0, {"role": "system", "content": moa_system_prompt})

        return input

    async def _agenerate(
        self,
        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Internal function to concurrently generate responses for a list of inputs.

        Args:
            inputs: the list of inputs to generate responses for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        aggregator_llm_kwargs: Dict[str, Any] = kwargs.get("aggregator_llm", {})
        proposers_llms_kwargs: List[Dict[str, Any]] = kwargs.get(
            "proposers_llms", [{}] * len(self.proposers_llms)
        )

        prev_outputs = []
        for round in range(self.rounds):
            self._logger.debug(f"Generating round {round + 1}/{self.rounds} in MoA")  # type: ignore

            # Generate `num_generations` with each proposer LLM for each input
            tasks = [
                asyncio.create_task(
                    llm._agenerate(
                        inputs=[
                            self._inject_moa_system_prompt(
                                cast("StandardInput", input), prev_input_outputs
                            )
                            for input, prev_input_outputs in itertools.zip_longest(
                                inputs, prev_outputs, fillvalue=[]
                            )
                        ],
                        num_generations=1,
                        **generation_kwargs,
                    )
                )
                for llm, generation_kwargs in zip(
                    self.proposers_llms, proposers_llms_kwargs
                )
            ]

            # Group generations per input
            outputs: List[List["GenerateOutput"]] = await asyncio.gather(*tasks)
            prev_outputs = [
                list(itertools.chain(*input_outputs)) for input_outputs in zip(*outputs)
            ]

        self._logger.debug("Aggregating outputs in MoA")  # type: ignore
        if isinstance(self.aggregator_llm, AsyncLLM):
            return await self.aggregator_llm._agenerate(
                inputs=[
                    self._inject_moa_system_prompt(
                        cast("StandardInput", input), prev_input_outputs
                    )
                    for input, prev_input_outputs in zip(inputs, prev_outputs)
                ],
                num_generations=num_generations,
                **aggregator_llm_kwargs,
            )

        return self.aggregator_llm.generate(
            inputs=[
                self._inject_moa_system_prompt(
                    cast("StandardInput", input), prev_input_outputs
                )
                for input, prev_input_outputs in zip(inputs, prev_outputs)
            ],
            num_generations=num_generations,
            **aggregator_llm_kwargs,
        )
runtime_parameters_names 属性

返回 `LLM` 的运行时参数,这些参数是 `LLM`、`aggregator_llm` 和 `proposers_llms` 的 `RuntimeParameter` 的组合。

返回

类型 描述
RuntimeParametersNames

`LLM` 的运行时参数。

model_name 属性

返回聚合的模型名称。

load()

加载 `MixtureOfAgents` 中的所有 `LLM`。

源代码位于 src/distilabel/models/llms/moa.py
def load(self) -> None:
    """Loads all the `LLM`s in the `MixtureOfAgents`."""
    super().load()

    for llm in self.proposers_llms:
        self._logger.debug(f"Loading proposer LLM in MoA: {llm}")  # type: ignore
        llm.load()

    self._logger.debug(f"Loading aggregator LLM in MoA: {self.aggregator_llm}")  # type: ignore
    self.aggregator_llm.load()
get_generation_kwargs()

以字典形式返回 `MixtureOfAgents` 的 generation kwargs。

返回

类型 描述
Dict[str, Any]

`MixtureOfAgents` 的 generation kwargs。

源代码位于 src/distilabel/models/llms/moa.py
def get_generation_kwargs(self) -> Dict[str, Any]:
    """Returns the generation kwargs of the `MixtureOfAgents` as a dictionary.

    Returns:
        The generation kwargs of the `MixtureOfAgents`.
    """
    return {
        "aggregator_llm": self.aggregator_llm.get_generation_kwargs(),
        "proposers_llms": [
            llm.get_generation_kwargs() for llm in self.proposers_llms
        ],
    }
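
下面是一个示意性的生成参数结构(并非官方示例),其形状与 get_generation_kwargs 的返回值一致,也对应下文 _agenerate 从 **kwargs 中读取 "aggregator_llm" 与 "proposers_llms" 的方式;具体键值仅作占位:

generation_kwargs = {
    # 传给聚合 LLM 的生成参数
    "aggregator_llm": {"max_new_tokens": 512, "temperature": 0.7},
    # 按顺序分别传给每个提议 LLM 的生成参数
    "proposers_llms": [
        {"max_new_tokens": 256},
        {"max_new_tokens": 256},
        {"max_new_tokens": 256},
    ],
}
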
_build_moa_system_prompt(prev_outputs)

构建 Mixture-of-Agents 系统提示。

参数

名称 类型 描述 默认值
prev_outputs List[str]

用作参考的先前输出的列表。

必需

返回

类型 描述
str

Mixture-of-Agents 系统提示。

源代码位于 src/distilabel/models/llms/moa.py
def _build_moa_system_prompt(self, prev_outputs: List[str]) -> str:
    """Builds the Mixture-of-Agents system prompt.

    Args:
        prev_outputs: The list of previous outputs to use as references.

    Returns:
        The Mixture-of-Agents system prompt.
    """
    moa_system_prompt = MOA_SYSTEM_PROMPT
    for i, prev_output in enumerate(prev_outputs):
        if prev_output is not None:
            moa_system_prompt += f"\n{i + 1}. {prev_output}"
    return moa_system_prompt
_inject_moa_system_prompt(input, prev_outputs)

将 Mixture-of-Agents 系统提示注入到输入中。

参数

名称 类型 描述 默认值
input StandardInput

要将系统提示注入到的输入。

必需
prev_outputs List[str]

用作参考的先前输出的列表。

必需

返回

类型 描述
StandardInput

注入了 Mixture-of-Agents 系统提示的输入。

源代码位于 src/distilabel/models/llms/moa.py
def _inject_moa_system_prompt(
    self, input: "StandardInput", prev_outputs: List[str]
) -> "StandardInput":
    """Injects the Mixture-of-Agents system prompt into the input.

    Args:
        input: The input to inject the system prompt into.
        prev_outputs: The list of previous outputs to use as references.

    Returns:
        The input with the Mixture-of-Agents system prompt injected.
    """
    if len(prev_outputs) == 0:
        return input

    moa_system_prompt = self._build_moa_system_prompt(prev_outputs)

    system = next((item for item in input if item["role"] == "system"), None)
    if system:
        original_system_prompt = system["content"]
        system["content"] = f"{moa_system_prompt}\n\n{original_system_prompt}"
    else:
        input.insert(0, {"role": "system", "content": moa_system_prompt})

    return input
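
下面用一个小例子示意注入效果(并非官方示例),其中 MOA_SYSTEM_PROMPT 为库内置的系统提示常量:

prev_outputs = ["Output A", "Output B"]
input = [{"role": "user", "content": "Write a haiku."}]

# 注入后,输入开头会新增一条 system 消息(若已有 system 消息则前置到其内容之前),
# 内容为 MOA_SYSTEM_PROMPT 加上按序号列出的先前输出,形如:
# [
#     {"role": "system", "content": "<MOA_SYSTEM_PROMPT>\n1. Output A\n2. Output B"},
#     {"role": "user", "content": "Write a haiku."},
# ]
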
_agenerate(inputs, num_generations=1, **kwargs) 异步

内部函数,用于并发地为输入列表生成响应。

参数

名称 类型 描述 默认值
inputs List[FormattedInput]

要为其生成响应的输入列表。

必需
num_generations int

每个输入要生成的代数。

1
**kwargs Any

用于生成的附加 kwargs。

{}

返回

类型 描述
List[GenerateOutput]

包含每个输入的代数的列表。

源代码位于 src/distilabel/models/llms/moa.py
async def _agenerate(
    self,
    inputs: List["FormattedInput"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Internal function to concurrently generate responses for a list of inputs.

    Args:
        inputs: the list of inputs to generate responses for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the generations for each input.
    """
    aggregator_llm_kwargs: Dict[str, Any] = kwargs.get("aggregator_llm", {})
    proposers_llms_kwargs: List[Dict[str, Any]] = kwargs.get(
        "proposers_llms", [{}] * len(self.proposers_llms)
    )

    prev_outputs = []
    for round in range(self.rounds):
        self._logger.debug(f"Generating round {round + 1}/{self.rounds} in MoA")  # type: ignore

        # Generate `num_generations` with each proposer LLM for each input
        tasks = [
            asyncio.create_task(
                llm._agenerate(
                    inputs=[
                        self._inject_moa_system_prompt(
                            cast("StandardInput", input), prev_input_outputs
                        )
                        for input, prev_input_outputs in itertools.zip_longest(
                            inputs, prev_outputs, fillvalue=[]
                        )
                    ],
                    num_generations=1,
                    **generation_kwargs,
                )
            )
            for llm, generation_kwargs in zip(
                self.proposers_llms, proposers_llms_kwargs
            )
        ]

        # Group generations per input
        outputs: List[List["GenerateOutput"]] = await asyncio.gather(*tasks)
        prev_outputs = [
            list(itertools.chain(*input_outputs)) for input_outputs in zip(*outputs)
        ]

    self._logger.debug("Aggregating outputs in MoA")  # type: ignore
    if isinstance(self.aggregator_llm, AsyncLLM):
        return await self.aggregator_llm._agenerate(
            inputs=[
                self._inject_moa_system_prompt(
                    cast("StandardInput", input), prev_input_outputs
                )
                for input, prev_input_outputs in zip(inputs, prev_outputs)
            ],
            num_generations=num_generations,
            **aggregator_llm_kwargs,
        )

    return self.aggregator_llm.generate(
        inputs=[
            self._inject_moa_system_prompt(
                cast("StandardInput", input), prev_input_outputs
            )
            for input, prev_input_outputs in zip(inputs, prev_outputs)
        ],
        num_generations=num_generations,
        **aggregator_llm_kwargs,
    )

OllamaLLM

基类: AsyncLLM、MagpieChatTemplateMixin

运行 Async API 客户端的 Ollama LLM 实现。

属性

名称 类型 描述
model str

用于 LLM 的模型名称,例如 "notus"。

host 可选[RuntimeParameter[str]]

Ollama 服务器主机。

timeout RuntimeParameter[int]

LLM 的超时。默认为 `120`。

follow_redirects bool

是否跟随重定向。默认为 `True`。

structured_output 可选[RuntimeParameter[InstructorStructuredOutputType]]

一个字典,包含结构化输出配置;或者,如果需要更细粒度的控制,则为 OutlinesStructuredOutput 的实例。默认为 None。

tokenizer_id 可选[RuntimeParameter[str]]

tokenizer Hugging Face Hub repo id 或包含 tokenizer 配置文件目录的路径。如果未提供,将使用与 model 关联的 tokenizer。默认为 None

use_magpie_template bool

用于启用/禁用应用 Magpie 预查询模板的标志。默认为 False

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

要应用于提示或发送到 LLM 以生成指令或后续用户消息的预查询模板。有效值为 "llama3"、"qwen2" 或提供的另一个预查询模板。默认为 None

_aclient Optional[AsyncClient]

用于 Ollama API 的 `AsyncClient`。它旨在内部使用。在 `load` 方法中设置。

运行时参数
  • host:Ollama 服务器主机。
  • timeout:Ollama API 的客户端超时。默认为 `120`。

示例

生成文本

from distilabel.models.llms import OllamaLLM

llm = OllamaLLM(model="llama3")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
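
也可以在实例化时显式指定运行时参数 host 与 timeout,下面是一个最小示意(并非官方示例;主机地址为 Ollama 常见的本地默认地址,仅作占位):

from distilabel.models.llms import OllamaLLM

llm = OllamaLLM(
    model="llama3",
    host="http://localhost:11434",  # 假设的 Ollama 服务地址
    timeout=300,
)

llm.load()

output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
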
源代码位于 src/distilabel/models/llms/ollama.py
class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
    """Ollama LLM implementation running the Async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "notus".
        host: the Ollama server host.
        timeout: the timeout for the LLM. Defaults to `120`.
        follow_redirects: whether to follow redirects. Defaults to `True`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.
        _aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `host`: the Ollama server host.
        - `timeout`: the client timeout for the Ollama API. Defaults to `120`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import OllamaLLM

        llm = OllamaLLM(model="llama3")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str
    host: Optional[RuntimeParameter[str]] = Field(
        default=None, description="The host of the Ollama API."
    )
    timeout: RuntimeParameter[int] = Field(
        default=120, description="The timeout for the Ollama API."
    )
    follow_redirects: bool = True
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )
    tokenizer_id: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The Hugging Face Hub repo id or a path to a directory containing"
        " the tokenizer config files. If not provided, the one associated to the `model`"
        " will be used.",
    )
    _num_generations_param_supported = False
    _aclient: Optional["AsyncClient"] = PrivateAttr(...)  # type: ignore

    @model_validator(mode="after")  # type: ignore
    def validate_magpie_usage(
        self,
    ) -> "OllamaLLM":
        """Validates that magpie usage is valid."""

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

    def load(self) -> None:
        """Loads the `AsyncClient` to use Ollama async API."""
        super().load()

        try:
            from ollama import AsyncClient

            self._aclient = AsyncClient(
                host=self.host,
                timeout=self.timeout,
                follow_redirects=self.follow_redirects,
            )
        except ImportError as e:
            raise ImportError(
                "Ollama Python client is not installed. Please install it using"
                " `pip install 'distilabel[ollama]'`."
            ) from e

        if self.tokenizer_id:
            try:
                from transformers import AutoTokenizer
            except ImportError as ie:
                raise ImportError(
                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                ) from ie
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
            if self._tokenizer.chat_template is None:
                raise ValueError(
                    "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
                )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    async def _generate_chat_completion(
        self,
        input: "StandardInput",
        format: Literal["", "json"] = "",
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> "ChatResponse":
        return await self._aclient.chat(
            model=self.model,
            messages=input,
            stream=False,
            format=format,
            options=options,
            keep_alive=keep_alive,
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(
                conversation=input,
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    async def _generate_with_text_generation(
        self,
        input: "StandardInput",
        format: Literal["", "json"] = None,
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> "GenerateResponse":
        input = self.prepare_input(input)
        return await self._aclient.generate(
            model=self.model,
            prompt=input,
            format=format,
            options=options,
            keep_alive=keep_alive,
            raw=True,
        )

    @validate_call
    async def agenerate(
        self,
        input: StandardInput,
        format: Literal["", "json"] = "",
        # TODO: include relevant options from `Options` in `agenerate` method.
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> GenerateOutput:
        """
        Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).

        Args:
            input: the input to use for the generation.
            format: the format to use for the generation. Defaults to `""`.
            options: the options to use for the generation. Defaults to `None`.
            keep_alive: whether to keep the connection alive. Defaults to `None`.

        Returns:
            A list of strings as completion for the given input.
        """
        text = None
        try:
            if not format:
                format = None
            if self.tokenizer_id is None:
                completion = await self._generate_chat_completion(
                    input, format, options, keep_alive
                )
                text = completion["message"]["content"]
            else:
                completion = await self._generate_with_text_generation(
                    input, format, options, keep_alive
                )
                text = completion.response
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )

        return prepare_output([text], **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: Dict[str, Any]) -> "LLMStatistics":
        return {
            "input_tokens": [completion["prompt_eval_count"]],
            "output_tokens": [completion["eval_count"]],
        }
model_name 属性

返回用于 LLM 的模型名称。

validate_magpie_usage()

验证 magpie 用法是否有效。

源代码位于 src/distilabel/models/llms/ollama.py
@model_validator(mode="after")  # type: ignore
def validate_magpie_usage(
    self,
) -> "OllamaLLM":
    """Validates that magpie usage is valid."""

    if self.use_magpie_template and self.tokenizer_id is None:
        raise ValueError(
            "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
            " set a `tokenizer_id` and try again."
        )
load()

加载 `AsyncClient` 以使用 Ollama 异步 API。

源代码位于 src/distilabel/models/llms/ollama.py
def load(self) -> None:
    """Loads the `AsyncClient` to use Ollama async API."""
    super().load()

    try:
        from ollama import AsyncClient

        self._aclient = AsyncClient(
            host=self.host,
            timeout=self.timeout,
            follow_redirects=self.follow_redirects,
        )
    except ImportError as e:
        raise ImportError(
            "Ollama Python client is not installed. Please install it using"
            " `pip install 'distilabel[ollama]'`."
        ) from e

    if self.tokenizer_id:
        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
        if self._tokenizer.chat_template is None:
            raise ValueError(
                "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
            )
prepare_input(input)

为提供的输入准备输入(应用聊天模板和分词)。

参数

名称 类型 描述 默认值
input StandardInput

包含聊天项的输入列表。

必需

返回

类型 描述
str

要发送给 LLM 的提示。
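
下面是一个与 distilabel 无关的小示意(假设已安装 transformers,且 tokenizer repo id 仅为示例),展示该方法内部所做的事情:用 Hugging Face tokenizer 的聊天模板把消息列表渲染成单个提示字符串。

```python
from transformers import AutoTokenizer

# 假设的 tokenizer repo id,仅用于演示聊天模板的渲染效果
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [{"role": "user", "content": "Hello world!"}]

# 与 prepare_input 相同的调用方式:不分词,仅渲染文本,并追加生成提示
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```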

源代码位于 src/distilabel/models/llms/ollama.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(
            conversation=input,
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
agenerate(input, format='', options=None, keep_alive=None) 异步

按照 Ollama Async API 定义异步生成响应。

参数

名称 类型 描述 默认值
input StandardInput

用于生成的输入。

必需
format Literal['', 'json']

用于生成的格式。默认为 `""`。

''
options Union[Options, None]

用于生成的选项。默认为 `None`。

None
keep_alive Union[bool, None]

是否保持连接活动。默认为 `None`。

None

返回

类型 描述
GenerateOutput

作为给定输入的补全的字符串列表。
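
一个最小示意,展示如何通过 generate 将上表中的 format 与 options 参数透传给 agenerate(假设 generate 会把额外关键字参数传给每次 agenerate 调用);其中 options 里的具体键值仅为假设示例,完整选项见 Ollama 客户端的 Options 定义。

```python
from distilabel.models.llms import OllamaLLM

llm = OllamaLLM(model="llama3")
llm.load()

# format="json" 要求模型输出 JSON;options 中的采样参数仅为示例
output = llm.generate(
    inputs=[[{"role": "user", "content": "Return a JSON object describing a user."}]],
    format="json",
    options={"temperature": 0.2, "num_predict": 128},
)
```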

源代码位于 src/distilabel/models/llms/ollama.py
@validate_call
async def agenerate(
    self,
    input: StandardInput,
    format: Literal["", "json"] = "",
    # TODO: include relevant options from `Options` in `agenerate` method.
    options: Union[Options, None] = None,
    keep_alive: Union[bool, None] = None,
) -> GenerateOutput:
    """
    Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).

    Args:
        input: the input to use for the generation.
        format: the format to use for the generation. Defaults to `""`.
        options: the options to use for the generation. Defaults to `None`.
        keep_alive: whether to keep the connection alive. Defaults to `None`.

    Returns:
        A list of strings as completion for the given input.
    """
    text = None
    try:
        if not format:
            format = None
        if self.tokenizer_id is None:
            completion = await self._generate_chat_completion(
                input, format, options, keep_alive
            )
            text = completion["message"]["content"]
        else:
            completion = await self._generate_with_text_generation(
                input, format, options, keep_alive
            )
            text = completion.response
    except Exception as e:
        self._logger.warning(  # type: ignore
            f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
            f" Finish reason was: {e}"
        )

    return prepare_output([text], **self._get_llm_statistics(completion))

OpenAILLM

Bases: OpenAIBaseClient, AsyncLLM

运行异步 API 客户端的 OpenAI LLM 实现。

属性

名称 类型 描述
model str

用于 LLM 的模型名称,例如 “gpt-3.5-turbo”、“gpt-4” 等。支持的模型可以在 此处 找到。

base_url 可选[RuntimeParameter[str]]

用于 OpenAI API 请求的基本 URL。默认为 `None`,这意味着将使用为环境变量 `OPENAI_BASE_URL` 设置的值,如果未设置,则使用 “https://api.openai.com/v1”。

api_key 可选[RuntimeParameter[SecretStr]]

用于验证 OpenAI API 请求的 API 密钥。默认为 `None`,这意味着将使用为环境变量 `OPENAI_API_KEY` 设置的值,如果未设置,则为 `None`。

default_headers Optional[RuntimeParameter[Dict[str, str]]]

用于 OpenAI API 请求的默认标头。

max_retries RuntimeParameter[int]

在失败之前重试 API 请求的最大次数。默认为 `6`。

timeout RuntimeParameter[int]

等待 API 响应的最长秒数。默认为 `120`。

structured_output 可选[RuntimeParameter[InstructorStructuredOutputType]]

一个字典,包含使用 instructor 的结构化输出配置。您可以在 distilabel.steps.tasks.structured_outputs.instructorInstructorStructuredOutputType 中查看字典结构。

运行时参数
  • base_url: 用于 OpenAI API 请求的基本 URL。默认为 `None`。
  • api_key: 用于验证 OpenAI API 请求的 API 密钥。默认为 `None`。
  • max_retries: 在失败之前重试 API 请求的最大次数。默认为 `6`。
  • timeout:等待 API 响应的最长秒数。默认为 `120`。
图标

`:simple-openai:`

示例

生成文本

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

从遵循 OpenAI API 的自定义端点生成文本

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    base_url=r"http://localhost:8080/v1"
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

生成结构化数据

from pydantic import BaseModel
from distilabel.models.llms import OpenAILLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = OpenAILLM(
    model="gpt-4-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])

使用 Batch API 生成(离线批量生成)

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(
    model="gpt-3.5-turbo",
    use_offline_batch_generation=True,
    offline_batch_generation_block_until_done=5,  # poll for results every 5 seconds
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
# [['Hello! How can I assist you today?']]
源代码位于 src/distilabel/models/llms/openai.py
class OpenAILLM(OpenAIBaseClient, AsyncLLM):
    """OpenAI LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "gpt-3.5-turbo", "gpt-4", etc.
            Supported models can be found [here](https://platform.openai.com/docs/guides/text-generation).
        base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `OPENAI_BASE_URL` will
            be used, or "https://api.openai.com/v1" if not set.
        api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
            `None` which means that the value set for the environment variable `OPENAI_API_KEY`
            will be used, or `None` if not set.
        default_headers: the default headers to use for the OpenAI API requests.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        structured_output: a dictionary containing the structured output configuration configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.

    Runtime parameters:
        - `base_url`: the base URL to use for the OpenAI API requests. Defaults to `None`.
        - `api_key`: the API key to authenticate the requests to the OpenAI API. Defaults
            to `None`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Icon:
        `:simple-openai:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import OpenAILLM

        llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate text from a custom endpoint following the OpenAI API:

        ```python
        from distilabel.models.llms import OpenAILLM

        llm = OpenAILLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            base_url=r"http://localhost:8080/v1"
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import OpenAILLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = OpenAILLM(
            model="gpt-4-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```

        Generate with Batch API (offline batch generation):

        ```python
        from distilabel.models.llms import OpenAILLM

        load = llm = OpenAILLM(
            model="gpt-3.5-turbo",
            use_offline_batch_generation=True,
            offline_batch_generation_block_until_done=5,  # poll for results every 5 seconds
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        # [['Hello! How can I assist you today?']]
        ```
    """

    def load(self) -> None:
        AsyncLLM.load(self)
        OpenAIBaseClient.load(self)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: NonNegativeInt = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        echo: bool = False,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[Dict[str, str]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            logprobs: whether to return the log probabilities or not. Defaults to `False`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. Defaults to `None`.
            echo: whether to echo the input in the response or not. It's only used if the
                `input` argument is an `str`. Defaults to `False`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of
                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use the JSON model from OpenAI. Defaults to None
                which returns text. To return JSON, use {"type": "json_object"}.
            extra_body: an optional dictionary containing extra body parameters that will
                be sent to the OpenAI API endpoint. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """

        if isinstance(input, str):
            return await self._generate_completion(
                input=input,
                num_generations=num_generations,
                max_new_tokens=max_new_tokens,
                echo=echo,
                top_logprobs=top_logprobs,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                temperature=temperature,
                top_p=top_p,
                extra_body=extra_body,
            )

        return await self._generate_chat_completion(
            input=input,
            num_generations=num_generations,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            response_format=response_format,
            extra_body=extra_body,
        )

    async def _generate_completion(
        self,
        input: str,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        echo: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        completion = await self._aclient.completions.create(
            prompt=input,
            echo=echo,
            model=self.model,
            n=num_generations,
            max_tokens=max_new_tokens,
            logprobs=top_logprobs,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            extra_body=extra_body,
        )

        generations = []
        logprobs = []
        for choice in completion.choices:
            generations.append(choice.text)
            if choice_logprobs := self._get_logprobs_from_completion_choice(choice):
                logprobs.append(choice_logprobs)

        statistics = self._get_llm_statistics(completion)
        return prepare_output(
            generations=generations,
            input_tokens=statistics["input_tokens"],
            output_tokens=statistics["output_tokens"],
            logprobs=logprobs,
        )

    def _get_logprobs_from_completion_choice(
        self, choice: "OpenAICompletionChoice"
    ) -> Union[List[Union[List["Logprob"], None]], None]:
        if choice.logprobs is None or choice.logprobs.top_logprobs is None:
            return None

        return [
            [
                {"token": token, "logprob": token_logprob}
                for token, token_logprob in logprobs.items()
            ]
            if logprobs is not None
            else None
            for logprobs in choice.logprobs.top_logprobs
        ]

    async def _generate_chat_completion(
        self,
        input: Union["StandardInput", "StructuredInput"],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[Dict[str, str]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,  # type: ignore
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")  # type: ignore

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "logprobs": logprobs,
            "top_logprobs": top_logprobs,
            "max_tokens": max_new_tokens,
            "n": num_generations,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
            "extra_body": extra_body,
        }

        # Checks if any message contains an image, in that case "stop" cannot be used or
        # raises an error in the API.
        if isinstance(
            [row for row in input if row["role"] == "user"][0]["content"], list
        ):
            kwargs.pop("stop")

        if response_format is not None:
            kwargs["response_format"] = response_format

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore

        if structured_output:
            # NOTE: `instructor` doesn't work with `n` parameter, so it will always return
            # only 1 choice.
            statistics = self._get_llm_statistics(completion._raw_response)
            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
                completion._raw_response.choices[0]
            ):
                output_logprobs = [choice_logprobs]
            else:
                output_logprobs = None
            return prepare_output(
                generations=[completion.model_dump_json()],
                input_tokens=statistics["input_tokens"],
                output_tokens=statistics["output_tokens"],
                logprobs=output_logprobs,
            )

        return self._generations_from_openai_completion(completion)

    def _generations_from_openai_completion(
        self, completion: "OpenAIChatCompletion"
    ) -> "GenerateOutput":
        """Get the generations from the OpenAI Chat Completion object.

        Args:
            completion: the completion object to get the generations from.

        Returns:
            A list of strings containing the generated responses for the input.
        """
        generations = []
        logprobs = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
                choice
            ):
                logprobs.append(choice_logprobs)

        statistics = self._get_llm_statistics(completion)
        return prepare_output(
            generations=generations,
            input_tokens=statistics["input_tokens"],
            output_tokens=statistics["output_tokens"],
            logprobs=logprobs,
        )

    def _get_logprobs_from_chat_completion_choice(
        self, choice: "OpenAIChatCompletionChoice"
    ) -> Union[List[List["Logprob"]], None]:
        if choice.logprobs is None or choice.logprobs.content is None:
            return None

        return [
            [
                {"token": top_logprob.token, "logprob": top_logprob.logprob}
                for top_logprob in token_logprobs.top_logprobs
            ]
            for token_logprobs in choice.logprobs.content
        ]

    def offline_batch_generate(
        self,
        inputs: Union[List["FormattedInput"], None] = None,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[str] = None,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Uses the OpenAI batch API to generate `num_generations` responses for the given
        inputs.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            logprobs: whether to return the log probabilities or not. Defaults to `False`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. Defaults to `None`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of
                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use the JSON model from OpenAI. Defaults to `text`.

        Returns:
            A list of lists of strings containing the generated responses for each input
            in `inputs`.

        Raises:
            DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
                is not finished yet.
            ValueError: if no job IDs were found to retrieve the results from.
        """
        if self.jobs_ids:
            return self._check_and_get_batch_results()

        if inputs:
            self.jobs_ids = self._create_jobs(
                inputs=inputs,
                **{
                    "model": self.model,
                    "logprobs": logprobs,
                    "top_logprobs": top_logprobs,
                    "max_tokens": max_new_tokens,
                    "n": num_generations,
                    "frequency_penalty": frequency_penalty,
                    "presence_penalty": presence_penalty,
                    "temperature": temperature,
                    "top_p": top_p,
                    "stop": stop,
                    "response_format": response_format,
                },
            )
            raise DistilabelOfflineBatchGenerationNotFinishedException(
                jobs_ids=self.jobs_ids
            )

        raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")

    def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
        """Checks the status of the batch jobs and retrieves the results from the OpenAI
        Batch API.

        Returns:
            A list of lists of strings containing the generated responses for each input.

        Raises:
            ValueError: if no job IDs were found to retrieve the results from.
            DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
                is not finished yet.
            RuntimeError: if the only batch job found failed.
        """
        if not self.jobs_ids:
            raise ValueError("No job IDs were found to retrieve the results from.")

        outputs = []
        for batch_id in self.jobs_ids:
            batch = self._get_openai_batch(batch_id)

            if batch.status in ("validating", "in_progress", "finalizing"):
                raise DistilabelOfflineBatchGenerationNotFinishedException(
                    jobs_ids=self.jobs_ids
                )

            if batch.status in ("failed", "expired", "cancelled", "cancelling"):
                self._logger.error(  # type: ignore
                    f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
                )
                if len(self.jobs_ids) == 1:
                    self.jobs_ids = None
                    raise RuntimeError(
                        f"The only OpenAI API Batch that was created with ID '{batch_id}'"
                        f" failed with status '{batch.status}'."
                    )

                continue

            outputs.extend(self._retrieve_batch_results(batch))

        # sort by `custom_id` to return the results in the same order as the inputs
        outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
        return [self._parse_output(output) for output in outputs]

    def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
        """Parses the output from the OpenAI Batch API into a list of strings.

        Args:
            output: the output to parse.

        Returns:
            A list of strings containing the generated responses for the input.
        """
        from openai.types.chat import ChatCompletion as OpenAIChatCompletion

        if "response" not in output:
            return []

        if output["response"]["status_code"] != 200:
            return []

        return self._generations_from_openai_completion(
            OpenAIChatCompletion(**output["response"]["body"])
        )

    def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
        """Gets a batch from the OpenAI Batch API.

        Args:
            batch_id: the ID of the batch to retrieve.

        Returns:
            The batch retrieved from the OpenAI Batch API.

        Raises:
            openai.OpenAIError: if there was an error while retrieving the batch from the
                OpenAI Batch API.
        """
        import openai

        try:
            return self._client.batches.retrieve(batch_id)
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
            )
            raise e

    def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
        """Retrieves the results of a batch from its output file, parsing the JSONL content
        into a list of dictionaries.

        Args:
            batch: the batch to retrieve the results from.

        Returns:
            A list of dictionaries containing the results of the batch.

        Raises:
            AssertionError: if no output file ID was found in the batch.
        """
        import openai

        assert batch.output_file_id, "No output file ID was found in the batch."

        try:
            file_response = self._client.files.content(batch.output_file_id)
            return [orjson.loads(line) for line in file_response.text.splitlines()]
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
            )
            return []

    def _create_jobs(
        self, inputs: List["FormattedInput"], **kwargs: Any
    ) -> Tuple[str, ...]:
        """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            kwargs: the keyword arguments to use for the generation.

        Returns:
            A list of job IDs created in the OpenAI Batch API.
        """
        batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
        jobs = []
        for batch_input_file in batch_input_files:
            if batch := self._create_batch_api_job(batch_input_file):
                jobs.append(batch.id)
        return tuple(jobs)

    def _create_batch_api_job(
        self, batch_input_file: "OpenAIFileObject"
    ) -> Union["OpenAIBatch", None]:
        """Creates a job in the OpenAI Batch API to generate responses for the given input
        file.

        Args:
            batch_input_file: the input file to generate responses for.

        Returns:
            The batch job created in the OpenAI Batch API.
        """
        import openai

        metadata = {"description": "distilabel"}

        if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
            metadata["distilabel_pipeline_name"] = distilabel_pipeline_name

        if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
            metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id

        batch = None
        try:
            batch = self._client.batches.create(
                completion_window="24h",
                endpoint="/v1/chat/completions",
                input_file_id=batch_input_file.id,
                metadata=metadata,
            )
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while creating OpenAI Batch API job for file with ID"
                f" '{batch_input_file.id}': {e}."
            )
            raise e
        return batch

    def _create_batch_files(
        self, inputs: List["FormattedInput"], **kwargs: Any
    ) -> List["OpenAIFileObject"]:
        """Creates the necessary input files for the batch API to generate responses. The
        maximum size of each file so the OpenAI Batch API can process it is 100MB, so we
        need to split the inputs into multiple files if necessary.

        More information: https://platform.openai.com/docs/api-reference/files/create

        Args:
            inputs: a list of inputs in chat format to generate responses for, optionally
                including structured output.
            kwargs: the keyword arguments to use for the generation.

        Returns:
            The list of file objects created for the OpenAI Batch API.

        Raises:
            openai.OpenAIError: if there was an error while creating the batch input file
                in the OpenAI Batch API.
        """
        import openai

        files = []
        for file_no, buffer in enumerate(
            self._create_jsonl_buffers(inputs=inputs, **kwargs)
        ):
            try:
                # TODO: add distilabel pipeline name and id
                batch_input_file = self._client.files.create(
                    file=(self._name_for_openai_files(file_no), buffer),
                    purpose="batch",
                )
                files.append(batch_input_file)
            except openai.OpenAIError as e:
                self._logger.error(  # type: ignore
                    f"Error while creating OpenAI batch input file: {e}"
                )
                raise e
        return files

    def _create_jsonl_buffers(
        self, inputs: List["FormattedInput"], **kwargs: Any
    ) -> Generator[io.BytesIO, None, None]:
        """Creates a generator of buffers containing the JSONL formatted inputs to be
        used by the OpenAI Batch API. The buffers created are of size 100MB or less.

        Args:
            inputs: a list of inputs in chat format to generate responses for, optionally
                including structured output.
            kwargs: the keyword arguments to use for the generation.

        Yields:
            A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
            API.
        """
        buffer = io.BytesIO()
        buffer_current_size = 0
        for i, input in enumerate(inputs):
            # We create the smallest `custom_id` so we don't  increase the size of the file
            # to much, but we can still sort the results with the order of the inputs.
            row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
            row_size = len(row)
            if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
                buffer.seek(0)
                yield buffer
                buffer = io.BytesIO()
                buffer_current_size = 0
            buffer.write(row)
            buffer_current_size += row_size

        if buffer_current_size > 0:
            buffer.seek(0)
            yield buffer

    def _create_jsonl_row(
        self, input: "FormattedInput", custom_id: str, **kwargs: Any
    ) -> bytes:
        """Creates a JSONL formatted row to be used by the OpenAI Batch API.

        Args:
            input: a list of inputs in chat format to generate responses for, optionally
                including structured output.
            custom_id: a custom ID to use for the row.
            kwargs: the keyword arguments to use for the generation.

        Returns:
            A JSONL formatted row to be used by the OpenAI Batch API.
        """
        # TODO: depending on the format of the input, add `response_format` to the kwargs
        row = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {"messages": input, **kwargs},
        }
        json_row = orjson.dumps(row)
        return json_row + b"\n"

    def _name_for_openai_files(self, file_no: int) -> str:
        if (
            envs.DISTILABEL_PIPELINE_NAME is None
            or envs.DISTILABEL_PIPELINE_CACHE_ID is None
        ):
            return f"distilabel-pipeline-fileno-{file_no}.jsonl"

        return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl"

    @staticmethod
    def _get_llm_statistics(
        completion: Union["OpenAIChatCompletion", "OpenAICompletion"],
    ) -> "LLMStatistics":
        return {
            "output_tokens": [
                completion.usage.completion_tokens if completion.usage else 0
            ],
            "input_tokens": [completion.usage.prompt_tokens if completion.usage else 0],
        }
agenerate(input, num_generations=1, max_new_tokens=128, logprobs=False, top_logprobs=None, echo=False, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, stop=None, response_format=None, extra_body=None) 异步

为给定输入生成 num_generations 个响应,使用 OpenAI 异步客户端。

参数

名称 类型 描述 默认值
input FormattedInput

聊天格式的单个输入,用于为其生成响应。

必需
num_generations int

每个输入要生成的响应数量。默认为 `1`。

1
max_new_tokens NonNegativeInt

模型将生成的最大新 token 数。默认为 128

128
logprobs bool

是否返回对数概率。默认为 `False`。

False
top_logprobs Optional[PositiveInt]

每个生成的输出 token 要返回的 top 对数概率数。默认为 `None`。

None
echo bool

是否在响应中回显输入。仅当 `input` 参数为 `str` 时才使用。默认为 `False`。

False
frequency_penalty float

用于生成的重复惩罚。默认为 0.0

0.0
presence_penalty float

用于生成的存在惩罚(presence penalty)。默认为 `0.0`。

0.0
temperature float

用于生成的温度。默认为 `1.0`。

1.0
top_p float

用于生成的 top-p 值。默认为 1.0

1.0
stop Optional[Union[str, List[str]]]

用作生成停止序列的字符串或字符串列表。默认为 `None`。

None
response_format Optional[Dict[str, str]]

要返回的响应格式。必须是 “text” 或 “json” 之一。阅读 此处 的文档,以获取有关如何使用 OpenAI 的 JSON 模式的更多信息。默认为 `None`,即返回文本。要返回 JSON,请使用 {“type”: “json_object”}。

None
extra_body Optional[Dict[str, Any]]

一个可选字典,其中包含将发送到 OpenAI API 端点的额外 body 参数。默认为 `None`。

None

返回

类型 描述
GenerateOutput

包含每个输入的生成响应的字符串列表的列表。
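
一个最小示意,展示如何通过 generate_outputs 把上表中的采样参数透传给 agenerate(假设 generate_outputs 会把额外关键字参数传给每次 agenerate 调用);所列参数取值仅为假设示例。

```python
from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")

llm.load()

# 透传 max_new_tokens、temperature 等参数,并通过 response_format 要求返回 JSON
output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Describe a user as JSON."}]],
    max_new_tokens=256,
    temperature=0.7,
    stop=["\n\n"],
    response_format={"type": "json_object"},
)
```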

源代码位于 src/distilabel/models/llms/openai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    num_generations: int = 1,
    max_new_tokens: NonNegativeInt = 128,
    logprobs: bool = False,
    top_logprobs: Optional[PositiveInt] = None,
    echo: bool = False,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the OpenAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        logprobs: whether to return the log probabilities or not. Defaults to `False`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. Defaults to `None`.
        echo: whether to echo the input in the response or not. It's only used if the
            `input` argument is an `str`. Defaults to `False`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: a string or a list of strings to use as a stop sequence for the generation.
            Defaults to `None`.
        response_format: the format of the response to return. Must be one of
            "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
            for more information on how to use the JSON model from OpenAI. Defaults to None
            which returns text. To return JSON, use {"type": "json_object"}.
        extra_body: an optional dictionary containing extra body parameters that will
            be sent to the OpenAI API endpoint. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """

    if isinstance(input, str):
        return await self._generate_completion(
            input=input,
            num_generations=num_generations,
            max_new_tokens=max_new_tokens,
            echo=echo,
            top_logprobs=top_logprobs,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            extra_body=extra_body,
        )

    return await self._generate_chat_completion(
        input=input,
        num_generations=num_generations,
        max_new_tokens=max_new_tokens,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        temperature=temperature,
        top_p=top_p,
        stop=stop,
        response_format=response_format,
        extra_body=extra_body,
    )
_generations_from_openai_completion(completion)

从 OpenAI Chat Completion 对象获取生成结果。

参数

名称 类型 描述 默认值
completion ChatCompletion

从中获取生成结果的 completion 对象。

必需

返回

类型 描述
GenerateOutput

包含输入生成的响应的字符串列表。

源代码位于 src/distilabel/models/llms/openai.py
def _generations_from_openai_completion(
    self, completion: "OpenAIChatCompletion"
) -> "GenerateOutput":
    """Get the generations from the OpenAI Chat Completion object.

    Args:
        completion: the completion object to get the generations from.

    Returns:
        A list of strings containing the generated responses for the input.
    """
    generations = []
    logprobs = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using OpenAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
        if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
            choice
        ):
            logprobs.append(choice_logprobs)

    statistics = self._get_llm_statistics(completion)
    return prepare_output(
        generations=generations,
        input_tokens=statistics["input_tokens"],
        output_tokens=statistics["output_tokens"],
        logprobs=logprobs,
    )
offline_batch_generate(inputs=None, num_generations=1, max_new_tokens=128, logprobs=False, top_logprobs=None, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, stop=None, response_format=None, **kwargs)

使用 OpenAI 批量 API 为给定输入生成 num_generations 个响应。

参数

名称 类型 描述 默认值
inputs Union[List[FormattedInput], None]

聊天格式的输入列表,用于生成响应。

None
num_generations int

每个输入要生成的响应数量。默认为 `1`。

1
max_new_tokens int

模型将生成的最大新 token 数。默认为 128

128
logprobs bool

是否返回对数概率。默认为 `False`。

False
top_logprobs Optional[PositiveInt]

每个生成的输出 token 要返回的 top 对数概率数。默认为 `None`。

None
frequency_penalty float

用于生成的重复惩罚。默认为 0.0

0.0
presence_penalty float

用于生成的存在惩罚(presence penalty)。默认为 `0.0`。

0.0
temperature float

用于生成的温度。默认为 `1.0`。

1.0
top_p float

用于生成的 top-p 值。默认为 1.0

1.0
stop Optional[Union[str, List[str]]]

用作生成停止序列的字符串或字符串列表。默认为 `None`。

None
response_format Optional[str]

要返回的响应格式。必须是 “text” 或 “json” 之一。阅读 此处 的文档,以获取有关如何使用 OpenAI 的 JSON 模式的更多信息。默认为 `text`。

None

返回

类型 描述
List[GenerateOutput]

包含 inputs 中每个输入的生成响应的字符串列表的列表。

引发

类型 描述
DistilabelOfflineBatchGenerationNotFinishedException

如果批量生成尚未完成。

ValueError

如果未找到任何作业 ID 来检索结果。
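
一个最小示意,展示非阻塞式离线批量生成的典型流程:首次调用提交批量作业并抛出 DistilabelOfflineBatchGenerationNotFinishedException,之后重新调用即可轮询作业状态并取回结果。其中异常的导入路径与轮询间隔仅为假设示例。

```python
import time

# 假设的导入路径;该异常在不同版本中所处的模块可能不同
from distilabel.exceptions import (
    DistilabelOfflineBatchGenerationNotFinishedException,
)
from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(model="gpt-3.5-turbo", use_offline_batch_generation=True)
llm.load()

inputs = [[{"role": "user", "content": "Hello world!"}]]

while True:
    try:
        # 第一次调用会创建批量作业并抛出“尚未完成”异常;
        # 之后的调用只检查作业状态,完成后返回结果
        outputs = llm.offline_batch_generate(inputs=inputs)
        break
    except DistilabelOfflineBatchGenerationNotFinishedException:
        print(f"Batch not finished yet, jobs: {llm.jobs_ids}")
        time.sleep(60)  # 轮询间隔仅为假设示例

print(outputs)
```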

源代码位于 src/distilabel/models/llms/openai.py
def offline_batch_generate(
    self,
    inputs: Union[List["FormattedInput"], None] = None,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    logprobs: bool = False,
    top_logprobs: Optional[PositiveInt] = None,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[str] = None,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Uses the OpenAI batch API to generate `num_generations` responses for the given
    inputs.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        logprobs: whether to return the log probabilities or not. Defaults to `False`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. Defaults to `None`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: a string or a list of strings to use as a stop sequence for the generation.
            Defaults to `None`.
        response_format: the format of the response to return. Must be one of
            "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
            for more information on how to use the JSON model from OpenAI. Defaults to `text`.

    Returns:
        A list of lists of strings containing the generated responses for each input
        in `inputs`.

    Raises:
        DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
            is not finished yet.
        ValueError: if no job IDs were found to retrieve the results from.
    """
    if self.jobs_ids:
        return self._check_and_get_batch_results()

    if inputs:
        self.jobs_ids = self._create_jobs(
            inputs=inputs,
            **{
                "model": self.model,
                "logprobs": logprobs,
                "top_logprobs": top_logprobs,
                "max_tokens": max_new_tokens,
                "n": num_generations,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "temperature": temperature,
                "top_p": top_p,
                "stop": stop,
                "response_format": response_format,
            },
        )
        raise DistilabelOfflineBatchGenerationNotFinishedException(
            jobs_ids=self.jobs_ids
        )

    raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
_check_and_get_batch_results()

检查批量作业的状态并从 OpenAI 批量 API 检索结果。

返回

类型 描述
List[GenerateOutput]

包含每个输入的生成响应的字符串列表的列表。

引发

类型 描述
ValueError

如果未找到任何作业 ID 来检索结果。

DistilabelOfflineBatchGenerationNotFinishedException

如果批量生成尚未完成。

RuntimeError

如果找到的唯一批量作业失败。

源代码位于 src/distilabel/models/llms/openai.py
def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
    """Checks the status of the batch jobs and retrieves the results from the OpenAI
    Batch API.

    Returns:
        A list of lists of strings containing the generated responses for each input.

    Raises:
        ValueError: if no job IDs were found to retrieve the results from.
        DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
            is not finished yet.
        RuntimeError: if the only batch job found failed.
    """
    if not self.jobs_ids:
        raise ValueError("No job IDs were found to retrieve the results from.")

    outputs = []
    for batch_id in self.jobs_ids:
        batch = self._get_openai_batch(batch_id)

        if batch.status in ("validating", "in_progress", "finalizing"):
            raise DistilabelOfflineBatchGenerationNotFinishedException(
                jobs_ids=self.jobs_ids
            )

        if batch.status in ("failed", "expired", "cancelled", "cancelling"):
            self._logger.error(  # type: ignore
                f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
            )
            if len(self.jobs_ids) == 1:
                self.jobs_ids = None
                raise RuntimeError(
                    f"The only OpenAI API Batch that was created with ID '{batch_id}'"
                    f" failed with status '{batch.status}'."
                )

            continue

        outputs.extend(self._retrieve_batch_results(batch))

    # sort by `custom_id` to return the results in the same order as the inputs
    outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
    return [self._parse_output(output) for output in outputs]
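The polling behaviour described above is driven from user code by calling the offline batch generation entry point repeatedly until the jobs finish. A minimal sketch, assuming the `offline_batch_generate` method, the `use_offline_batch_generation` flag and the exception's import path from `distilabel.exceptions` (all of which should be checked against your installed version):

```python
import time

from distilabel.exceptions import (  # assumed import path
    DistilabelOfflineBatchGenerationNotFinishedException,
)
from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(model="gpt-4o-mini", use_offline_batch_generation=True)
llm.load()

inputs = [[{"role": "user", "content": "Hello world!"}]]

while True:
    try:
        # The first call creates the batch jobs and raises the exception; later
        # calls check `jobs_ids` and return the results once every job finished.
        outputs = llm.offline_batch_generate(inputs=inputs)
        break
    except DistilabelOfflineBatchGenerationNotFinishedException:
        print("OpenAI Batch API jobs not finished yet, waiting before retrying...")
        time.sleep(60)

print(outputs)
```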
_parse_output(output)

Parses the output from the OpenAI Batch API into a list of strings.

Parameters:

- `output` (`Dict[str, Any]`), required: the output to parse.

Returns:

- `GenerateOutput`: a list of strings containing the generated responses for the input.

Source code in src/distilabel/models/llms/openai.py
def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
    """Parses the output from the OpenAI Batch API into a list of strings.

    Args:
        output: the output to parse.

    Returns:
        A list of strings containing the generated responses for the input.
    """
    from openai.types.chat import ChatCompletion as OpenAIChatCompletion

    if "response" not in output:
        return []

    if output["response"]["status_code"] != 200:
        return []

    return self._generations_from_openai_completion(
        OpenAIChatCompletion(**output["response"]["body"])
    )
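For reference, each dictionary handed to `_parse_output` (one parsed JSONL line of the batch output file) has roughly the shape below; only the keys the method actually reads are meaningful here, the rest of the payload is illustrative and trimmed:

```python
# Illustrative, trimmed shape of one parsed line from the batch output file.
output_line = {
    "custom_id": "0",            # used to restore the original input order
    "response": {
        "status_code": 200,      # anything other than 200 yields an empty result
        "body": {                # an OpenAI chat completion payload
            "id": "chatcmpl-...",
            "object": "chat.completion",
            "model": "gpt-4o-mini",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "Hello!"},
                    "finish_reason": "stop",
                },
            ],
        },
    },
}
```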
_get_openai_batch(batch_id)

Gets a batch from the OpenAI Batch API.

Parameters:

- `batch_id` (`str`), required: the ID of the batch to retrieve.

Returns:

- `Batch`: the batch retrieved from the OpenAI Batch API.

Raises:

- `OpenAIError`: if there was an error while retrieving the batch from the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
    """Gets a batch from the OpenAI Batch API.

    Args:
        batch_id: the ID of the batch to retrieve.

    Returns:
        The batch retrieved from the OpenAI Batch API.

    Raises:
        openai.OpenAIError: if there was an error while retrieving the batch from the
            OpenAI Batch API.
    """
    import openai

    try:
        return self._client.batches.retrieve(batch_id)
    except openai.OpenAIError as e:
        self._logger.error(  # type: ignore
            f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
        )
        raise e
_retrieve_batch_results(batch)

Retrieves the results of a batch from its output file, parsing the JSONL content into a list of dictionaries.

Parameters:

- `batch` (`Batch`), required: the batch to retrieve the results from.

Returns:

- `List[Dict[str, Any]]`: a list of dictionaries containing the results of the batch.

Raises:

- `AssertionError`: if no output file ID was found in the batch.

Source code in src/distilabel/models/llms/openai.py
def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
    """Retrieves the results of a batch from its output file, parsing the JSONL content
    into a list of dictionaries.

    Args:
        batch: the batch to retrieve the results from.

    Returns:
        A list of dictionaries containing the results of the batch.

    Raises:
        AssertionError: if no output file ID was found in the batch.
    """
    import openai

    assert batch.output_file_id, "No output file ID was found in the batch."

    try:
        file_response = self._client.files.content(batch.output_file_id)
        return [orjson.loads(line) for line in file_response.text.splitlines()]
    except openai.OpenAIError as e:
        self._logger.error(  # type: ignore
            f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
        )
        return []
_create_jobs(inputs, **kwargs)

Creates jobs in the OpenAI Batch API to generate responses for the given inputs.

Parameters:

- `inputs` (`List[FormattedInput]`), required: a list of inputs in chat format to generate responses for.
- `kwargs` (`Any`), default `{}`: the keyword arguments to use for the generation.

Returns:

- `Tuple[str, ...]`: a list of job IDs created in the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_jobs(
    self, inputs: List["FormattedInput"], **kwargs: Any
) -> Tuple[str, ...]:
    """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        kwargs: the keyword arguments to use for the generation.

    Returns:
        A list of job IDs created in the OpenAI Batch API.
    """
    batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
    jobs = []
    for batch_input_file in batch_input_files:
        if batch := self._create_batch_api_job(batch_input_file):
            jobs.append(batch.id)
    return tuple(jobs)
_create_batch_api_job(batch_input_file)

Creates a job in the OpenAI Batch API to generate responses for the given input file.

Parameters:

- `batch_input_file` (`FileObject`), required: the input file to generate responses for.

Returns:

- `Union[Batch, None]`: the batch job created in the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_batch_api_job(
    self, batch_input_file: "OpenAIFileObject"
) -> Union["OpenAIBatch", None]:
    """Creates a job in the OpenAI Batch API to generate responses for the given input
    file.

    Args:
        batch_input_file: the input file to generate responses for.

    Returns:
        The batch job created in the OpenAI Batch API.
    """
    import openai

    metadata = {"description": "distilabel"}

    if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
        metadata["distilabel_pipeline_name"] = distilabel_pipeline_name

    if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
        metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id

    batch = None
    try:
        batch = self._client.batches.create(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id=batch_input_file.id,
            metadata=metadata,
        )
    except openai.OpenAIError as e:
        self._logger.error(  # type: ignore
            f"Error while creating OpenAI Batch API job for file with ID"
            f" '{batch_input_file.id}': {e}."
        )
        raise e
    return batch
_create_batch_files(inputs, **kwargs)

Creates the necessary input files for the Batch API to generate responses. The maximum file size the OpenAI Batch API can process is 100MB, so the inputs are split into multiple files if necessary.

More information: https://platform.openai.com/docs/api-reference/files/create

Parameters:

- `inputs` (`List[FormattedInput]`), required: a list of inputs in chat format to generate responses for, optionally including structured output.
- `kwargs` (`Any`), default `{}`: the keyword arguments to use for the generation.

Returns:

- `List[FileObject]`: the list of file objects created for the OpenAI Batch API.

Raises:

- `OpenAIError`: if there was an error while creating the batch input file in the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_batch_files(
    self, inputs: List["FormattedInput"], **kwargs: Any
) -> List["OpenAIFileObject"]:
    """Creates the necessary input files for the batch API to generate responses. The
    maximum size of each file so the OpenAI Batch API can process it is 100MB, so we
    need to split the inputs into multiple files if necessary.

    More information: https://platform.openai.com/docs/api-reference/files/create

    Args:
        inputs: a list of inputs in chat format to generate responses for, optionally
            including structured output.
        kwargs: the keyword arguments to use for the generation.

    Returns:
        The list of file objects created for the OpenAI Batch API.

    Raises:
        openai.OpenAIError: if there was an error while creating the batch input file
            in the OpenAI Batch API.
    """
    import openai

    files = []
    for file_no, buffer in enumerate(
        self._create_jsonl_buffers(inputs=inputs, **kwargs)
    ):
        try:
            # TODO: add distilabel pipeline name and id
            batch_input_file = self._client.files.create(
                file=(self._name_for_openai_files(file_no), buffer),
                purpose="batch",
            )
            files.append(batch_input_file)
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while creating OpenAI batch input file: {e}"
            )
            raise e
    return files
_create_jsonl_buffers(inputs, **kwargs)

Creates a generator of buffers containing the JSONL formatted inputs to be used by the OpenAI Batch API. The buffers created are 100MB or smaller.

Parameters:

- `inputs` (`List[FormattedInput]`), required: a list of inputs in chat format to generate responses for, optionally including structured output.
- `kwargs` (`Any`), default `{}`: the keyword arguments to use for the generation.

Yields:

- `BytesIO`: a buffer containing the JSONL formatted inputs to be used by the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_jsonl_buffers(
    self, inputs: List["FormattedInput"], **kwargs: Any
) -> Generator[io.BytesIO, None, None]:
    """Creates a generator of buffers containing the JSONL formatted inputs to be
    used by the OpenAI Batch API. The buffers created are of size 100MB or less.

    Args:
        inputs: a list of inputs in chat format to generate responses for, optionally
            including structured output.
        kwargs: the keyword arguments to use for the generation.

    Yields:
        A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
        API.
    """
    buffer = io.BytesIO()
    buffer_current_size = 0
    for i, input in enumerate(inputs):
        # We create the smallest `custom_id` so we don't increase the size of the file
        # too much, but we can still sort the results back into the order of the inputs.
        row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
        row_size = len(row)
        if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
            buffer.seek(0)
            yield buffer
            buffer = io.BytesIO()
            buffer_current_size = 0
        buffer.write(row)
        buffer_current_size += row_size

    if buffer_current_size > 0:
        buffer.seek(0)
        yield buffer
_create_jsonl_row(input, custom_id, **kwargs)

Creates a JSONL formatted row to be used by the OpenAI Batch API.

Parameters:

- `input` (`FormattedInput`), required: an input in chat format to generate a response for, optionally including structured output.
- `custom_id` (`str`), required: a custom ID to use for the row.
- `kwargs` (`Any`), default `{}`: the keyword arguments to use for the generation.

Returns:

- `bytes`: a JSONL formatted row to be used by the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_jsonl_row(
    self, input: "FormattedInput", custom_id: str, **kwargs: Any
) -> bytes:
    """Creates a JSONL formatted row to be used by the OpenAI Batch API.

    Args:
        input: a list of inputs in chat format to generate responses for, optionally
            including structured output.
        custom_id: a custom ID to use for the row.
        kwargs: the keyword arguments to use for the generation.

    Returns:
        A JSONL formatted row to be used by the OpenAI Batch API.
    """
    # TODO: depending on the format of the input, add `response_format` to the kwargs
    row = {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"messages": input, **kwargs},
    }
    json_row = orjson.dumps(row)
    return json_row + b"\n"
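Putting the pieces together, each request line written to a batch input file has the shape produced by `_create_jsonl_row` above; the generation kwargs inside `body` are illustrative:

```python
import orjson

# Illustrative request line for the OpenAI Batch API, as built by `_create_jsonl_row`.
row = {
    "custom_id": "0",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "messages": [{"role": "user", "content": "Hello world!"}],
        "model": "gpt-4o-mini",   # example generation kwargs
        "max_tokens": 128,
        "temperature": 1.0,
    },
}
line = orjson.dumps(row) + b"\n"
```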

TogetherLLM

Bases: OpenAILLM

TogetherLLM LLM implementation running the async API client of OpenAI.

Attributes:

- `model` (`str`): the model name to use for the LLM, e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1". Supported models can be found here: https://api.together.xyz/models.
- `base_url` (`Optional[RuntimeParameter[str]]`): the base URL to use for the Together API; it can be set with `TOGETHER_BASE_URL`. Defaults to `None`, which means the value of the `TOGETHER_BASE_URL` environment variable will be used, or "https://api.together.xyz/v1" if not set.
- `api_key` (`Optional[RuntimeParameter[SecretStr]]`): the API key to authenticate the requests to the Together API. Defaults to `None`, which means the value of the `TOGETHER_API_KEY` environment variable will be used, or `None` if not set.
- `_api_key_env_var` (`str`): the name of the environment variable to use for the API key. It is meant to be used internally.

Examples:

Generate text:

from distilabel.models.llms import TogetherLLM

llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/together.py
class TogetherLLM(OpenAILLM):
    """TogetherLLM LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1".
            Supported models can be found [here](https://api.together.xyz/models).
        base_url: the base URL to use for the Together API can be set with `TOGETHER_BASE_URL`.
            Defaults to `None` which means that the value set for the environment variable
            `TOGETHER_BASE_URL` will be used, or "https://api.together.xyz/v1" if not set.
        api_key: the API key to authenticate the requests to the Together API. Defaults to `None`
            which means that the value set for the environment variable `TOGETHER_API_KEY` will be
            used, or `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import TogetherLLM

        llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "TOGETHER_BASE_URL", "https://api.together.xyz/v1"
        ),
        description="The base URL to use for the Together API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_TOGETHER_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Together API.",
    )

    _api_key_env_var: str = PrivateAttr(_TOGETHER_API_KEY_ENV_VAR_NAME)

VertexAILLM

Bases: AsyncLLM

VertexAI LLM implementation running the async API client for Gemini.

- Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

To use `VertexAILLM`, Google Cloud authentication must be configured using one of these methods:

- Setting the `GOOGLE_CLOUD_CREDENTIALS` environment variable
- Using the `gcloud auth application-default login` command
- Using the `vertexai.init` function from the `google-cloud-aiplatform` library

Attributes:

- `model` (`str`): the model name to use for the LLM, e.g. "gemini-1.0-pro". Supported models: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
- `_aclient` (`Optional[GenerativeModel]`): the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant to be used internally. Set in the `load` method.

Icon: `:simple-googlecloud:`

Examples:

Generate text:

from distilabel.models.llms import VertexAILLM

llm = VertexAILLM(model="gemini-1.5-pro")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/vertexai.py
class VertexAILLM(AsyncLLM):
    """VertexAI LLM implementation running the async API clients for Gemini.

    - Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

    To use the `VertexAILLM` is necessary to have configured the Google Cloud authentication
    using one of these methods:

    - Setting `GOOGLE_CLOUD_CREDENTIALS` environment variable
    - Using `gcloud auth application-default login` command
    - Using `vertexai.init` function from the `google-cloud-aiplatform` library

    Attributes:
        model: the model name to use for the LLM e.g. "gemini-1.0-pro". [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models).
        _aclient: the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant
            to be used internally. Set in the `load` method.

    Icon:
        `:simple-googlecloud:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import VertexAILLM

        llm = VertexAILLM(model="gemini-1.5-pro")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str

    _num_generations_param_supported = False

    _aclient: Optional["GenerativeModel"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
        super().load()

        try:
            from vertexai.generative_models import GenerationConfig, GenerativeModel

            self._generation_config_class = GenerationConfig
        except ImportError as e:
            raise ImportError(
                "vertexai is not installed. Please install it using"
                " `pip install 'distilabel[vertexai]'`."
            ) from e

        if _is_gemini_model(self.model):
            self._aclient = GenerativeModel(model_name=self.model)
        else:
            raise NotImplementedError(
                "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
            )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
        """Converts a chat type to a list of content items expected by the API.

        Args:
            input: the chat type to be converted.

        Returns:
            List[str]: a list of content items expected by the API.
        """
        from vertexai.generative_models import Content, Part

        contents = []
        for message in input:
            if message["role"] not in ["user", "model"]:
                raise ValueError(
                    "`VertexAILLM only supports the roles 'user' or 'model'."
                )
            contents.append(
                Content(
                    role=message["role"], parts=[Part.from_text(message["content"])]
                )
            )
        return contents

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: VertexChatType,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_output_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        safety_settings: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

        Args:
            input: a single input in chat format to generate responses for.
            temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
            top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
            top_k: If specified, top-k sampling will be used. Defaults to `None`.
            max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
            stop_sequences: A list of stop sequences. Defaults to `None`.
            safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
            tools: A potential list of tools that can be used by the API. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from vertexai.generative_models import GenerationConfig

        content: "GenerationResponse" = await self._aclient.generate_content_async(  # type: ignore
            contents=self._chattype_to_content(input),
            generation_config=GenerationConfig(
                candidate_count=1,  # only one candidate allowed per call
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_output_tokens=max_output_tokens,
                stop_sequences=stop_sequences,
            ),
            safety_settings=safety_settings,  # type: ignore
            tools=tools,  # type: ignore
            stream=False,
        )

        text = None
        try:
            text = content.candidates[0].text
        except ValueError:
            self._logger.warning(  # type: ignore
                f"Received no response using VertexAI client (model: '{self.model}')."
                f" Finish reason was: '{content.candidates[0].finish_reason}'."
            )
        return prepare_output([text], **self._get_llm_statistics(content))

    @staticmethod
    def _get_llm_statistics(content: "GenerationResponse") -> "LLMStatistics":
        return {
            "input_tokens": [content.usage_metadata.prompt_token_count],
            "output_tokens": [content.usage_metadata.candidates_token_count],
        }
model_name property

Returns the model name used for the LLM.

load()

Loads the `GenerativeModel` class, which has access to `generate_content_async` to benefit from async requests.

Source code in src/distilabel/models/llms/vertexai.py
def load(self) -> None:
    """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
    super().load()

    try:
        from vertexai.generative_models import GenerationConfig, GenerativeModel

        self._generation_config_class = GenerationConfig
    except ImportError as e:
        raise ImportError(
            "vertexai is not installed. Please install it using"
            " `pip install 'distilabel[vertexai]'`."
        ) from e

    if _is_gemini_model(self.model):
        self._aclient = GenerativeModel(model_name=self.model)
    else:
        raise NotImplementedError(
            "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
        )
_chattype_to_content(input)

Converts a chat type to a list of content items expected by the API.

Parameters:

- `input` (`StandardInput`), required: the chat type to be converted.

Returns:

- `List[Content]`: a list of content items expected by the API.

Source code in src/distilabel/models/llms/vertexai.py
def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
    """Converts a chat type to a list of content items expected by the API.

    Args:
        input: the chat type to be converted.

    Returns:
        List[str]: a list of content items expected by the API.
    """
    from vertexai.generative_models import Content, Part

    contents = []
    for message in input:
        if message["role"] not in ["user", "model"]:
            raise ValueError(
                "`VertexAILLM only supports the roles 'user' or 'model'."
            )
        contents.append(
            Content(
                role=message["role"], parts=[Part.from_text(message["content"])]
            )
        )
    return contents
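Because `_chattype_to_content` only accepts the roles "user" and "model", conversations written with the OpenAI-style "assistant" role need to be remapped before being passed to `VertexAILLM`. A small, hypothetical helper for that (not part of distilabel):

```python
from typing import Dict, List


def to_vertexai_roles(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Remaps OpenAI-style roles to the ones accepted by `VertexAILLM`."""
    remapped = []
    for message in messages:
        role = "model" if message["role"] == "assistant" else message["role"]
        if role not in ("user", "model"):
            raise ValueError(f"Unsupported role for VertexAILLM: {role!r}")
        remapped.append({"role": role, "content": message["content"]})
    return remapped


conversation = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]
vertexai_input = to_vertexai_roles(conversation)
```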
agenerate(input, temperature=None, top_p=None, top_k=None, max_output_tokens=None, stop_sequences=None, safety_settings=None, tools=None) async

Generates `num_generations` responses for the given input using the VertexAI async client definition.

Parameters:

- `input` (`VertexChatType`), required: a single input in chat format to generate responses for.
- `temperature` (`Optional[float]`), default `None`: controls the randomness of predictions. Range: [0.0, 1.0].
- `top_p` (`Optional[float]`), default `None`: if specified, nucleus sampling will be used. Range: (0.0, 1.0].
- `top_k` (`Optional[int]`), default `None`: if specified, top-k sampling will be used.
- `max_output_tokens` (`Optional[int]`), default `None`: the maximum number of output tokens to generate per message.
- `stop_sequences` (`Optional[List[str]]`), default `None`: a list of stop sequences.
- `safety_settings` (`Optional[Dict[str, Any]]`), default `None`: safety configuration for the content returned by the API.
- `tools` (`Optional[List[Dict[str, Any]]]`), default `None`: a potential list of tools that can be used by the API.

Returns:

- `GenerateOutput`: a list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/vertexai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: VertexChatType,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    safety_settings: Optional[Dict[str, Any]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

    Args:
        input: a single input in chat format to generate responses for.
        temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
        top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
        top_k: If specified, top-k sampling will be used. Defaults to `None`.
        max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
        stop_sequences: A list of stop sequences. Defaults to `None`.
        safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
        tools: A potential list of tools that can be used by the API. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from vertexai.generative_models import GenerationConfig

    content: "GenerationResponse" = await self._aclient.generate_content_async(  # type: ignore
        contents=self._chattype_to_content(input),
        generation_config=GenerationConfig(
            candidate_count=1,  # only one candidate allowed per call
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_output_tokens=max_output_tokens,
            stop_sequences=stop_sequences,
        ),
        safety_settings=safety_settings,  # type: ignore
        tools=tools,  # type: ignore
        stream=False,
    )

    text = None
    try:
        text = content.candidates[0].text
    except ValueError:
        self._logger.warning(  # type: ignore
            f"Received no response using VertexAI client (model: '{self.model}')."
            f" Finish reason was: '{content.candidates[0].finish_reason}'."
        )
    return prepare_output([text], **self._get_llm_statistics(content))

ClientvLLM

Bases: OpenAILLM, MagpieChatTemplateMixin

A client for a `vLLM` server implementing the OpenAI API specification.

Attributes:

- `base_url` (`Optional[RuntimeParameter[str]]`): the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
- `max_retries` (`RuntimeParameter[int]`): the maximum number of times to retry the request to the API before failing. Defaults to `6`.
- `timeout` (`RuntimeParameter[int]`): the maximum time in seconds to wait for a response from the API. Defaults to `120`.
- `httpx_client_kwargs` (`RuntimeParameter[int]`): extra kwargs that will be passed to the `httpx.AsyncClient` created to communicate with the `vLLM` server. Defaults to `None`.
- `tokenizer` (`Optional[str]`): the Hugging Face Hub repo id or path of the tokenizer that will be used to apply the chat template and tokenize the inputs before sending them to the server. Defaults to `None`.
- `tokenizer_revision` (`Optional[str]`): the revision of the tokenizer to load. Defaults to `None`.
- `_aclient` (`AsyncOpenAI`): the `httpx.AsyncClient` used to communicate with the `vLLM` server. Defaults to `None`.

Runtime parameters:

- `base_url`: the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
- `max_retries`: the maximum number of times to retry the request to the API before failing. Defaults to `6`.
- `timeout`: the maximum time in seconds to wait for a response from the API. Defaults to `120`.
- `httpx_client_kwargs`: extra kwargs that will be passed to the `httpx.AsyncClient` created to communicate with the `vLLM` server. Defaults to `None`.

Examples:

Generate text:

from distilabel.models.llms import ClientvLLM

llm = ClientvLLM(
    base_url="http://localhost:8000/v1",
    tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
)

llm.load()

results = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=256,
)
# [
#     [
#         "I'm functioning properly, thank you for asking. How can I assist you today?",
#         "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
#         "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
#     ]
# ]
Source code in src/distilabel/models/llms/vllm.py
class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
    """A client for the `vLLM` server implementing the OpenAI API specification.

    Attributes:
        base_url: the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        httpx_client_kwargs: extra kwargs that will be passed to the `httpx.AsyncClient`
            created to communicate with the `vLLM` server. Defaults to `None`.
        tokenizer: the Hugging Face Hub repo id or path of the tokenizer that will be used
            to apply the chat template and tokenize the inputs before sending it to the
            server. Defaults to `None`.
        tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
        _aclient: the `httpx.AsyncClient` used to communicate with the `vLLM` server. Defaults
            to `None`.

    Runtime parameters:
        - `base_url`: the base url of the `vLLM` server. Defaults to `"http://localhost:8000"`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `httpx_client_kwargs`: extra kwargs that will be passed to the `httpx.AsyncClient`
            created to communicate with the `vLLM` server. Defaults to `None`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import ClientvLLM

        llm = ClientvLLM(
            base_url="http://localhost:8000/v1",
            tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
        )

        llm.load()

        results = llm.generate_outputs(
            inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
            temperature=0.7,
            top_p=1.0,
            max_new_tokens=256,
        )
        # [
        #     [
        #         "I'm functioning properly, thank you for asking. How can I assist you today?",
        #         "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
        #         "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
        #     ]
        # ]
        ```
    """

    model: str = ""  # Default value so it's not needed to `ClientvLLM(model="...")`
    tokenizer: Optional[str] = None
    tokenizer_revision: Optional[str] = None

    # We need the sync client to get the list of models
    _client: "OpenAI" = PrivateAttr(None)
    _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)

    def load(self) -> None:
        """Creates an `httpx.AsyncClient` to connect to the vLLM server and a tokenizer
        optionally."""

        self.api_key = SecretStr("EMPTY")

        # We need to first create the sync client to get the model name that will be used
        # in the `super().load()` when creating the logger.
        try:
            from openai import OpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[openai]'`."
            ) from ie

        self._client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),  # type: ignore
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        super().load()

        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "To use `ClientvLLM` you need to install `transformers`."
                "Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer, revision=self.tokenizer_revision
        )

    @cached_property
    def model_name(self) -> str:  # type: ignore
        """Returns the name of the model served with vLLM server."""
        models = self._client.models.list()
        return models.data[0].id

    def _prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,  # type: ignore
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        logit_bias: Optional[Dict[str, int]] = None,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for each input.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `0.0`.
            logit_bias: modify the likelihood of specified tokens appearing in the completion.
                Defaults to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: nucleus sampling. The value refers to the top-p tokens that should be
                considered for sampling. Defaults to `1.0`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """

        completion = await self._aclient.completions.create(
            model=self.model_name,
            prompt=self._prepare_input(input),  # type: ignore
            n=num_generations,
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
        )

        generations = []
        for choice in completion.choices:
            text = choice.text
            if text == "":
                self._logger.warning(  # type: ignore
                    f"Received no response from vLLM server (model: '{self.model_name}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(text)

        return prepare_output(generations, **self._get_llm_statistics(completion))
model_name cached property

Returns the name of the model served by the vLLM server.

load()

Creates an `httpx.AsyncClient` to connect to the vLLM server and, optionally, a tokenizer.

Source code in src/distilabel/models/llms/vllm.py
def load(self) -> None:
    """Creates an `httpx.AsyncClient` to connect to the vLLM server and a tokenizer
    optionally."""

    self.api_key = SecretStr("EMPTY")

    # We need to first create the sync client to get the model name that will be used
    # in the `super().load()` when creating the logger.
    try:
        from openai import OpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install 'distilabel[openai]'`."
        ) from ie

    self._client = OpenAI(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),  # type: ignore
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    super().load()

    try:
        from transformers import AutoTokenizer
    except ImportError as ie:
        raise ImportError(
            "To use `ClientvLLM` you need to install `transformers`."
            "Please install it using `pip install 'distilabel[hf-transformers]'`."
        ) from ie

    self._tokenizer = AutoTokenizer.from_pretrained(
        self.tokenizer, revision=self.tokenizer_revision
    )
_prepare_input(input)

Prepares the input (applying the chat template and tokenization) for the provided input.

Parameters:

- `input` (`StandardInput`), required: the input list containing chat items.

Returns:

- `str`: the prompt to send to the LLM.

Source code in src/distilabel/models/llms/vllm.py
def _prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,  # type: ignore
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
agenerate(input, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, logit_bias=None, presence_penalty=0.0, temperature=1.0, top_p=1.0) async

Generates `num_generations` responses for each input.

Parameters:

- `input` (`FormattedInput`), required: a single input in chat format to generate responses for.
- `num_generations` (`int`), default `1`: the number of generations to create per input.
- `max_new_tokens` (`int`), default `128`: the maximum number of new tokens that the model will generate.
- `frequency_penalty` (`float`), default `0.0`: the frequency penalty to use for the generation.
- `logit_bias` (`Optional[Dict[str, int]]`), default `None`: modify the likelihood of specified tokens appearing in the completion.
- `presence_penalty` (`float`), default `0.0`: the presence penalty to use for the generation.
- `temperature` (`float`), default `1.0`: the temperature to use for the generation.
- `top_p` (`float`), default `1.0`: nucleus sampling. The value refers to the top-p tokens that should be considered for sampling.

Returns:

- `GenerateOutput`: a list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/vllm.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    logit_bias: Optional[Dict[str, int]] = None,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> GenerateOutput:
    """Generates `num_generations` responses for each input.

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `0.0`.
        logit_bias: modify the likelihood of specified tokens appearing in the completion.
            Defaults to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: nucleus sampling. The value refers to the top-p tokens that should be
            considered for sampling. Defaults to `1.0`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """

    completion = await self._aclient.completions.create(
        model=self.model_name,
        prompt=self._prepare_input(input),  # type: ignore
        n=num_generations,
        max_tokens=max_new_tokens,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        presence_penalty=presence_penalty,
        temperature=temperature,
        top_p=top_p,
    )

    generations = []
    for choice in completion.choices:
        text = choice.text
        if text == "":
            self._logger.warning(  # type: ignore
                f"Received no response from vLLM server (model: '{self.model_name}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(text)

    return prepare_output(generations, **self._get_llm_statistics(completion))
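`ClientvLLM` assumes an OpenAI-compatible vLLM server is already running at `base_url`. A minimal sketch of pointing the client at a locally started server; the exact server command and flags depend on the installed vLLM version:

```python
# Assumption: the server was started separately with vLLM's OpenAI-compatible
# entrypoint, e.g. `vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --port 8000`.
from distilabel.models.llms import ClientvLLM

llm = ClientvLLM(
    base_url="http://localhost:8000/v1",
    tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct",
)
llm.load()

# `model_name` is resolved from the server's `/v1/models` endpoint.
print(llm.model_name)
```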

vLLM

Bases: LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin

`vLLM` library LLM implementation.

Attributes:

- `model` (`str`): the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files.
- `dtype` (`str`): the data type to use for the model. Defaults to `auto`.
- `trust_remote_code` (`bool`): whether to trust the remote code when loading the model. Defaults to `False`.
- `quantization` (`Optional[str]`): the quantization mode to use for the model. Defaults to `None`.
- `revision` (`Optional[str]`): the revision of the model to load. Defaults to `None`.
- `tokenizer` (`Optional[str]`): the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer files. If not provided, the tokenizer will be loaded from the model directory. Defaults to `None`.
- `tokenizer_mode` (`Literal['auto', 'slow']`): the mode to use for the tokenizer. Defaults to `auto`.
- `tokenizer_revision` (`Optional[str]`): the revision of the tokenizer to load. Defaults to `None`.
- `skip_tokenizer_init` (`bool`): whether to skip the initialization of the tokenizer. Defaults to `False`.
- `chat_template` (`Optional[str]`): a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, the ChatML template will be used. Defaults to `None`.
- `structured_output` (`Optional[RuntimeParameter[OutlinesStructuredOutputType]]`): a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to `None`.
- `seed` (`int`): the seed to use for the random number generator. Defaults to `0`.
- `extra_kwargs` (`Optional[RuntimeParameter[Dict[str, Any]]]`): additional dictionary of keyword arguments that will be passed to the `LLM` class of the `vllm` library. Defaults to `{}`.
- `_model` (`LLM`): the `vLLM` model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the `load` method.
- `_tokenizer` (`PreTrainedTokenizer`): the tokenizer instance used to format the prompt before passing it to the `LLM`. This attribute is meant to be used internally and should not be accessed directly. It will be set in the `load` method.
- `use_magpie_template` (`bool`): a flag used to enable/disable applying the Magpie pre-query template. Defaults to `False`.
- `magpie_pre_query_template` (`Union[MagpieAvailablePreQueryTemplates, str, None]`): the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2" or another pre-query template provided. Defaults to `None`.

References:

- https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

Runtime parameters:

- `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the `LLM` class of the `vllm` library.

Examples:

Generate text:

from distilabel.models.llms import vLLM

# You can pass a custom chat_template to the model
llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    chat_template="[INST] {{ messages[0]"content" }}\n{{ messages[1]"content" }}[/INST]",
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel

from distilabel.models.llms import vLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/vllm.py
class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
    """`vLLM` library LLM implementation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        dtype: the data type to use for the model. Defaults to `auto`.
        trust_remote_code: whether to trust the remote code when loading the model. Defaults
            to `False`.
        quantization: the quantization mode to use for the model. Defaults to `None`.
        revision: the revision of the model to load. Defaults to `None`.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer files. If not provided, the tokenizer will be loaded from the
            model directory. Defaults to `None`.
        tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`.
        tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
        skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults
            to `False`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        seed: the seed to use for the random number generator. Defaults to `0`.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `LLM` class of `vllm` library. Defaults to `{}`.
        _model: the `vLLM` model instance. This attribute is meant to be used internally
            and should not be accessed directly. It will be set in the `load` method.
        _tokenizer: the tokenizer instance used to format the prompt before passing it to
            the `LLM`. This attribute is meant to be used internally and should not be
            accessed directly. It will be set in the `load` method.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.

    References:
        - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

    Runtime parameters:
        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to
            the `LLM` class of `vllm` library.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import vLLM

        # You can pass a custom chat_template to the model
        llm = vLLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import vLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = vLLM(
            model="prometheus-eval/prometheus-7b-v2.0"
            structured_output={"format": "json", "schema": Character},
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    dtype: str = "auto"
    trust_remote_code: bool = False
    quantization: Optional[str] = None
    revision: Optional[str] = None

    tokenizer: Optional[str] = None
    tokenizer_mode: Literal["auto", "slow"] = "auto"
    tokenizer_revision: Optional[str] = None
    skip_tokenizer_init: bool = False
    chat_template: Optional[str] = None

    seed: int = 0

    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `vLLM` class of `vllm` library. See all the supported arguments at: "
        "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _model: "_vLLM" = PrivateAttr(None)
    _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)
    _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None)

    def load(self) -> None:
        """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
        Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
        parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
        default value is ChatML format, unless explicitly provided.
        """
        super().load()

        CudaDevicePlacementMixin.load(self)

        try:
            from vllm import LLM as _vLLM
        except ImportError as ie:
            raise ImportError(
                "vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`."
            ) from ie

        self._model = _vLLM(
            self.model,
            dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            quantization=self.quantization,
            revision=self.revision,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            seed=self.seed,
            **self.extra_kwargs,  # type: ignore
        )

        self._tokenizer = self._model.get_tokenizer()  # type: ignore
        if self.chat_template is not None:
            self._tokenizer.chat_template = self.chat_template  # type: ignore

        if self.structured_output:
            self._structured_output_logits_processor = self._prepare_structured_output(
                self.structured_output
            )

    def unload(self) -> None:
        """Unloads the `vLLM` model."""
        self._cleanup_vllm_model()
        self._model = None  # type: ignore
        self._tokenizer = None  # type: ignore
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    def _cleanup_vllm_model(self) -> None:
        if self._model is None:
            return

        import torch  # noqa
        from vllm.distributed.parallel_state import (
            destroy_distributed_environment,
            destroy_model_parallel,
        )

        destroy_model_parallel()
        destroy_distributed_environment()
        del self._model.llm_engine.model_executor
        del self._model
        with contextlib.suppress(AssertionError):
            torch.distributed.destroy_process_group()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: Union["StandardInput", str]) -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if isinstance(input, str):
            return input

        prompt: str = (
            self._tokenizer.apply_chat_template(
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,  # type: ignore
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _prepare_batches(
        self, inputs: List["StructuredInput"]
    ) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]:
        """Prepares the inputs by grouping them by the structured output.

        When we generate structured outputs with schemas obtained from a dataset, we need to
        prepare the data to try to send batches of inputs instead of single inputs to the model
        to take advantage of the engine. So we group the inputs by the structured output to be
        passed in the `generate` method.

        Args:
            inputs: The batch of inputs passed to the generate method. As we expect to be generating
                structured outputs, each element will be a tuple containing the instruction and the
                structured output.

        Returns:
            The prepared batches (sub-batches let's say) to be passed to the `generate` method.
            Each new tuple will contain, instead of a single instruction, a list of instructions.
        """
        instruction_order = {}
        batches: Dict[str, List[str]] = {}
        for i, (instruction, structured_output) in enumerate(inputs):
            instruction = self.prepare_input(instruction)
            instruction_order[instruction] = i

            structured_output = json.dumps(structured_output)
            if structured_output not in batches:
                batches[structured_output] = [instruction]
            else:
                batches[structured_output].append(instruction)

        # Build a list with instructions sorted by structured output
        flat_instructions = [
            instruction for _, group in batches.items() for instruction in group
        ]

        # Generate the list of indices based on the original order
        sorted_indices = [
            instruction_order[instruction] for instruction in flat_instructions
        ]

        return [
            (batch, json.loads(schema)) for schema, batch in batches.items()
        ], sorted_indices

    @validate_call
    def generate(  # noqa: C901 # type: ignore
        self,
        inputs: List[FormattedInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        logprobs: Optional[PositiveInt] = None,
        stop: Optional[List[str]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        skip_special_tokens: bool = True,
        logits_processors: Optional[LogitsProcessors] = None,
        extra_sampling_params: Optional[Dict[str, Any]] = None,
        echo: bool = False,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `0.0`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults to
                `1.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `-1`.
            min_p: the minimum probability to use for the generation. Defaults to `0.0`.
            logprobs: number of log probabilities to return per output token. If `None`,
                no log probabilities will be returned. Defaults to `None`.
            stop: a list of strings that will be used to stop the generation when found.
                Defaults to `None`.
            stop_token_ids: a list of token ids that will be used to stop the generation
                when found. Defaults to `None`.
            include_stop_str_in_output: whether to include the stop string in the output.
                Defaults to `False`.
            skip_special_tokens: whether to exclude special tokens from the output. Defaults
                to `True`.
            logits_processors: a list of functions to process the logits before sampling.
                Defaults to `None`.
            extra_sampling_params: dictionary with additional arguments to be passed to
                the `SamplingParams` class from `vllm`.
            echo: whether to include the prompt in the response or not. Defaults
                to `False`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from vllm import SamplingParams

        if not logits_processors:
            logits_processors = []

        if extra_sampling_params is None:
            extra_sampling_params = {}

        structured_output = None

        if isinstance(inputs[0], tuple):
            # Prepare the batches for structured generation
            prepared_batches, sorted_indices = self._prepare_batches(inputs)  # type: ignore
        else:
            # Simulate a batch without the structured output content
            prepared_batches = [([self.prepare_input(input) for input in inputs], None)]  # type: ignore
            sorted_indices = None

        # Case in which we have a single structured output for the dataset
        if self._structured_output_logits_processor:
            logits_processors.append(self._structured_output_logits_processor)

        batched_outputs: List["LLMOutput"] = []
        generations = []

        for prepared_inputs, structured_output in prepared_batches:
            if self.structured_output is not None and structured_output is not None:
                self._logger.warning(
                    "An `structured_output` was provided in the model configuration, but"
                    " one was also provided in the input. The input structured output will"
                    " be used."
                )

            if structured_output is not None:
                logits_processors.append(
                    self._prepare_structured_output(structured_output)  # type: ignore
                )

            sampling_params = SamplingParams(  # type: ignore
                n=num_generations,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                min_p=min_p,
                max_tokens=max_new_tokens,
                prompt_logprobs=logprobs if echo else None,
                logprobs=logprobs,
                stop=stop,
                stop_token_ids=stop_token_ids,
                include_stop_str_in_output=include_stop_str_in_output,
                skip_special_tokens=skip_special_tokens,
                logits_processors=logits_processors,
                **extra_sampling_params,
            )

            batch_outputs: List["RequestOutput"] = self._model.generate(
                prompts=prepared_inputs,
                sampling_params=sampling_params,
                use_tqdm=False,
            )

            # Remove structured output logit processor to avoid stacking structured output
            # logits processors that leads to non-sense generations
            if structured_output is not None:
                logits_processors.pop(-1)

            for input, outputs in zip(prepared_inputs, batch_outputs):
                processed_prompt_logprobs = []
                if outputs.prompt_logprobs is not None:
                    processed_prompt_logprobs = self._get_llm_logprobs(
                        outputs.prompt_logprobs
                    )
                texts, statistics, outputs_logprobs = self._process_outputs(
                    input=input,
                    outputs=outputs,
                    echo=echo,
                    prompt_logprobs=processed_prompt_logprobs,
                )
                batched_outputs.append(texts)
                generation = prepare_output(
                    generations=texts,
                    input_tokens=statistics["input_tokens"],
                    output_tokens=statistics["output_tokens"],
                    logprobs=outputs_logprobs,
                )

                generations.append(generation)

        if sorted_indices is not None:
            pairs = list(enumerate(sorted_indices))
            pairs.sort(key=lambda x: x[1])
            generations = [generations[original_idx] for original_idx, _ in pairs]

        return generations

    def _process_outputs(
        self,
        input: str,
        outputs: "RequestOutput",
        prompt_logprobs: List[List["Logprob"]],
        echo: bool = False,
    ) -> Tuple["LLMOutput", "LLMStatistics", "LLMLogprobs"]:
        texts = []
        outputs_logprobs = []
        statistics = {
            "input_tokens": [compute_tokens(input, self._tokenizer.encode)]
            * len(outputs.outputs),
            "output_tokens": [],
        }
        for output in outputs.outputs:
            text = output.text
            if echo:
                text = input + text
            texts.append(text)
            statistics["output_tokens"].append(len(output.token_ids))
            if output.logprobs is not None:
                processed_output_logprobs = self._get_llm_logprobs(output.logprobs)
                outputs_logprobs.append(prompt_logprobs + processed_output_logprobs)
        return texts, statistics, outputs_logprobs

    def _prepare_structured_output(  # type: ignore
        self, structured_output: "OutlinesStructuredOutputType"
    ) -> Union[Callable, None]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        assert structured_output is not None, "`structured_output` cannot be `None`"

        result = prepare_guided_output(structured_output, "vllm", self._model)
        if (schema := result.get("schema")) and self.structured_output:
            self.structured_output["schema"] = schema
        return result["processor"]

    def _get_llm_logprobs(
        self, logprobs: Union["PromptLogprobs", "SampleLogprobs"]
    ) -> List[List["Logprob"]]:
        processed_logprobs = []
        for token_logprob in logprobs:  # type: ignore
            token_logprobs = []
            if token_logprob is None:
                processed_logprobs.append(None)
                continue
            for logprob in token_logprob.values():
                token_logprobs.append(
                    {"token": logprob.decoded_token, "logprob": logprob.logprob}
                )
            processed_logprobs.append(token_logprobs)
        return processed_logprobs
model_name property

Returns the model name used for the LLM.

load()

Loads the vLLM model using either the path or the Hugging Face Hub repository id. Additionally, this method sets the chat_template for the tokenizer so that the list of OpenAI-formatted inputs is properly parsed using the format expected by the model; otherwise, the ChatML format is used as the default, unless explicitly provided.

Source code in src/distilabel/models/llms/vllm.py
def load(self) -> None:
    """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
    Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
    parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
    default value is ChatML format, unless explicitly provided.
    """
    super().load()

    CudaDevicePlacementMixin.load(self)

    try:
        from vllm import LLM as _vLLM
    except ImportError as ie:
        raise ImportError(
            "vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`."
        ) from ie

    self._model = _vLLM(
        self.model,
        dtype=self.dtype,
        trust_remote_code=self.trust_remote_code,
        quantization=self.quantization,
        revision=self.revision,
        tokenizer=self.tokenizer,
        tokenizer_mode=self.tokenizer_mode,
        tokenizer_revision=self.tokenizer_revision,
        skip_tokenizer_init=self.skip_tokenizer_init,
        seed=self.seed,
        **self.extra_kwargs,  # type: ignore
    )

    self._tokenizer = self._model.get_tokenizer()  # type: ignore
    if self.chat_template is not None:
        self._tokenizer.chat_template = self.chat_template  # type: ignore

    if self.structured_output:
        self._structured_output_logits_processor = self._prepare_structured_output(
            self.structured_output
        )
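A minimal usage sketch, included for orientation only; the model id below is a placeholder, not a value prescribed by distilabel:

from distilabel.models.llms import vLLM

llm = vLLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical Hugging Face Hub repo id or local path
    chat_template=None,  # leave as None to use the tokenizer's own template (or ChatML as a fallback)
)

llm.load()    # instantiates the vLLM engine and configures the tokenizer/chat template
llm.unload()  # frees the engine and releases the reserved CUDA devices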
unload()

Unloads the vLLM model.

Source code in src/distilabel/models/llms/vllm.py
def unload(self) -> None:
    """Unloads the `vLLM` model."""
    self._cleanup_vllm_model()
    self._model = None  # type: ignore
    self._tokenizer = None  # type: ignore
    CudaDevicePlacementMixin.unload(self)
    super().unload()
prepare_input(input)

Prepares the provided input (applying the chat template and tokenization).

Parameters

- input (Union[StandardInput, str], required): the input list containing chat items.

Returns

- str: the prompt to send to the LLM.

Source code in src/distilabel/models/llms/vllm.py
def prepare_input(self, input: Union["StandardInput", str]) -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if isinstance(input, str):
        return input

    prompt: str = (
        self._tokenizer.apply_chat_template(
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,  # type: ignore
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
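A quick illustrative sketch, assuming the model has already been loaded via load(); the chat content is hypothetical:

messages = [{"role": "user", "content": "What is the capital of France?"}]
prompt = llm.prepare_input(messages)
# `prompt` is the string produced by the tokenizer's chat template, with the
# generation prompt appended. Plain strings are returned unchanged:
assert llm.prepare_input("already a prompt") == "already a prompt"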
_prepare_batches(inputs)

Prepares the inputs by grouping them by structured output.

When generating structured outputs with schemas obtained from a dataset, the data is prepared so that batches of inputs, rather than single inputs, are sent to the model in order to take advantage of the engine. The inputs are therefore grouped by structured output before being passed to the generate method, as shown in the sketch after the source code below.

Parameters

- inputs (List[StructuredInput], required): the batch of inputs passed to the generate method. As structured outputs are expected to be generated, each element is a tuple containing the instruction and the structured output.

Returns

- List[Tuple[List[str], OutlinesStructuredOutputType]]: the prepared batches (sub-batches) to be passed to the generate method. Each tuple contains a list of instructions instead of a single instruction.
- List[int]: the indices used to restore the original order of the inputs.

Source code in src/distilabel/models/llms/vllm.py
def _prepare_batches(
    self, inputs: List["StructuredInput"]
) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]:
    """Prepares the inputs by grouping them by the structured output.

    When we generate structured outputs with schemas obtained from a dataset, we need to
    prepare the data to try to send batches of inputs instead of single inputs to the model
    to take advantage of the engine. So we group the inputs by the structured output to be
    passed in the `generate` method.

    Args:
        inputs: The batch of inputs passed to the generate method. As we expect to be generating
            structured outputs, each element will be a tuple containing the instruction and the
            structured output.

    Returns:
        The prepared batches (sub-batches) to be passed to the `generate` method, together
        with the indices used to restore the original order of the inputs. Each tuple
        contains a list of instructions instead of a single instruction.
    """
    instruction_order = {}
    batches: Dict[str, List[str]] = {}
    for i, (instruction, structured_output) in enumerate(inputs):
        instruction = self.prepare_input(instruction)
        instruction_order[instruction] = i

        structured_output = json.dumps(structured_output)
        if structured_output not in batches:
            batches[structured_output] = [instruction]
        else:
            batches[structured_output].append(instruction)

    # Build a list with the instructions grouped by structured output
    flat_instructions = [
        instruction for _, group in batches.items() for instruction in group
    ]

    # Generate the list of indices based on the original order
    sorted_indices = [
        instruction_order[instruction] for instruction in flat_instructions
    ]

    return [
        (batch, json.loads(schema)) for schema, batch in batches.items()
    ], sorted_indices
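The grouping logic can be illustrated standalone. The following is a simplified sketch that mirrors the method above without the chat-template step; the instructions and schemas are hypothetical:

import json

inputs = [
    ("instruction A", {"format": "json", "schema": {"type": "string"}}),
    ("instruction B", {"format": "regex", "schema": "\\d+"}),
    ("instruction C", {"format": "json", "schema": {"type": "string"}}),
]

batches = {}
for instruction, structured_output in inputs:
    # The serialized structured output acts as the grouping key.
    batches.setdefault(json.dumps(structured_output), []).append(instruction)

# Two sub-batches: ["instruction A", "instruction C"] and ["instruction B"],
# each paired with its deserialized structured output configuration.
prepared = [(group, json.loads(schema)) for schema, group in batches.items()]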
generate(inputs, num_generations=1, max_new_tokens=128, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, logprobs=None, stop=None, stop_token_ids=None, include_stop_str_in_output=False, skip_special_tokens=True, logits_processors=None, extra_sampling_params=None, echo=False)

Generates num_generations responses for each input.

Parameters

- inputs (List[FormattedInput], required): a list of inputs in chat format to generate responses for.
- num_generations (int, default 1): the number of generations to create per input.
- max_new_tokens (int, default 128): the maximum number of new tokens that the model will generate.
- presence_penalty (float, default 0.0): the presence penalty to use for the generation.
- frequency_penalty (float, default 0.0): the frequency penalty to use for the generation.
- repetition_penalty (float, default 1.0): the repetition penalty to use for the generation.
- temperature (float, default 1.0): the temperature to use for the generation.
- top_p (float, default 1.0): the top-p value to use for the generation.
- top_k (int, default -1): the top-k value to use for the generation.
- min_p (float, default 0.0): the minimum probability to use for the generation.
- logprobs (Optional[PositiveInt], default None): the number of log probabilities to return per output token. If None, no log probabilities are returned.
- stop (Optional[List[str]], default None): a list of strings that will stop the generation when found.
- stop_token_ids (Optional[List[int]], default None): a list of token ids that will stop the generation when found.
- include_stop_str_in_output (bool, default False): whether to include the stop string in the output.
- skip_special_tokens (bool, default True): whether to exclude special tokens from the output.
- logits_processors (Optional[LogitsProcessors], default None): a list of functions to process the logits before sampling.
- extra_sampling_params (Optional[Dict[str, Any]], default None): a dictionary with additional arguments to be passed to the SamplingParams class from vllm.
- echo (bool, default False): whether to include the prompt in the response.

Returns

- List[GenerateOutput]: a list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/vllm.py
@validate_call
def generate(  # noqa: C901 # type: ignore
    self,
    inputs: List[FormattedInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
    min_p: float = 0.0,
    logprobs: Optional[PositiveInt] = None,
    stop: Optional[List[str]] = None,
    stop_token_ids: Optional[List[int]] = None,
    include_stop_str_in_output: bool = False,
    skip_special_tokens: bool = True,
    logits_processors: Optional[LogitsProcessors] = None,
    extra_sampling_params: Optional[Dict[str, Any]] = None,
    echo: bool = False,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `0.0`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults to
            `1.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `-1`.
        min_p: the minimum probability to use for the generation. Defaults to `0.0`.
        logprobs: number of log probabilities to return per output token. If `None`,
            then no log probabilities will be returned. Defaults to `None`.
        stop: a list of strings that will be used to stop the generation when found.
            Defaults to `None`.
        stop_token_ids: a list of token ids that will be used to stop the generation
            when found. Defaults to `None`.
        include_stop_str_in_output: whether to include the stop string in the output.
            Defaults to `False`.
        skip_special_tokens: whether to exclude special tokens from the output. Defaults
            to `True`.
        logits_processors: a list of functions to process the logits before sampling.
            Defaults to `None`.
        extra_sampling_params: dictionary with additional arguments to be passed to
            the `SamplingParams` class from `vllm`.
        echo: whether to include the prompt in the response or not. Defaults
            to `False`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from vllm import SamplingParams

    if not logits_processors:
        logits_processors = []

    if extra_sampling_params is None:
        extra_sampling_params = {}

    structured_output = None

    if isinstance(inputs[0], tuple):
        # Prepare the batches for structured generation
        prepared_batches, sorted_indices = self._prepare_batches(inputs)  # type: ignore
    else:
        # Simulate a batch without the structured output content
        prepared_batches = [([self.prepare_input(input) for input in inputs], None)]  # type: ignore
        sorted_indices = None

    # Case in which we have a single structured output for the dataset
    if self._structured_output_logits_processor:
        logits_processors.append(self._structured_output_logits_processor)

    batched_outputs: List["LLMOutput"] = []
    generations = []

    for prepared_inputs, structured_output in prepared_batches:
        if self.structured_output is not None and structured_output is not None:
            self._logger.warning(
                "An `structured_output` was provided in the model configuration, but"
                " one was also provided in the input. The input structured output will"
                " be used."
            )

        if structured_output is not None:
            logits_processors.append(
                self._prepare_structured_output(structured_output)  # type: ignore
            )

        sampling_params = SamplingParams(  # type: ignore
            n=num_generations,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            max_tokens=max_new_tokens,
            prompt_logprobs=logprobs if echo else None,
            logprobs=logprobs,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_stop_str_in_output,
            skip_special_tokens=skip_special_tokens,
            logits_processors=logits_processors,
            **extra_sampling_params,
        )

        batch_outputs: List["RequestOutput"] = self._model.generate(
            prompts=prepared_inputs,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        # Remove structured output logit processor to avoid stacking structured output
        # logits processors that leads to non-sense generations
        if structured_output is not None:
            logits_processors.pop(-1)

        for input, outputs in zip(prepared_inputs, batch_outputs):
            processed_prompt_logprobs = []
            if outputs.prompt_logprobs is not None:
                processed_prompt_logprobs = self._get_llm_logprobs(
                    outputs.prompt_logprobs
                )
            texts, statistics, outputs_logprobs = self._process_outputs(
                input=input,
                outputs=outputs,
                echo=echo,
                prompt_logprobs=processed_prompt_logprobs,
            )
            batched_outputs.append(texts)
            generation = prepare_output(
                generations=texts,
                input_tokens=statistics["input_tokens"],
                output_tokens=statistics["output_tokens"],
                logprobs=outputs_logprobs,
            )

            generations.append(generation)

    if sorted_indices is not None:
        pairs = list(enumerate(sorted_indices))
        pairs.sort(key=lambda x: x[1])
        generations = [generations[original_idx] for original_idx, _ in pairs]

    return generations
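A hedged usage sketch; the model id is a placeholder, and although generate is normally invoked for you by a Task, it can be called directly after load():

from distilabel.models.llms import vLLM

llm = vLLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical repo id
llm.load()

outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    num_generations=2,
    max_new_tokens=64,
    temperature=0.7,
)
# `outputs` contains one entry per input, each holding the generated texts plus
# token statistics (and logprobs when requested).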
_prepare_structured_output(structured_output)

Creates the appropriate function to filter tokens in order to generate structured outputs.

Parameters

- structured_output (OutlinesStructuredOutputType, required): the configuration dict used to prepare the structured output.

Returns

- Union[Callable, None]: the callable that will be used to guide the generation of the model.

Source code in src/distilabel/models/llms/vllm.py
def _prepare_structured_output(  # type: ignore
    self, structured_output: "OutlinesStructuredOutputType"
) -> Union[Callable, None]:
    """Creates the appropriate function to filter tokens to generate structured outputs.

    Args:
        structured_output: the configuration dict to prepare the structured output.

    Returns:
        The callable that will be used to guide the generation of the model.
    """
    from distilabel.steps.tasks.structured_outputs.outlines import (
        prepare_guided_output,
    )

    assert structured_output is not None, "`structured_output` cannot be `None`"

    result = prepare_guided_output(structured_output, "vllm", self._model)
    if (schema := result.get("schema")) and self.structured_output:
        self.structured_output["schema"] = schema
    return result["processor"]
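For reference, a sketch of the kind of configuration dict this method expects, following the outlines integration; the schema itself is hypothetical:

structured_output = {
    "format": "json",
    "schema": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    },
}
# Passed either through the `structured_output` attribute of the LLM or per input,
# it is turned into a logits processor that constrains generation to the schema.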

CudaDevicePlacementMixin

Bases: BaseModel

Mixin class to assign CUDA devices to the LLM based on the cuda_devices attribute and the device placement information provided in _device_llm_placement_map. Providing the device placement information is optional, but if it is provided, it will be used to assign CUDA devices to the LLMs, trying to avoid using the same device for different LLMs.

Attributes

- cuda_devices (RuntimeParameter[Union[List[int], Literal['auto']]]): a list with the IDs of the CUDA devices to be used by the LLM. If set to "auto", the devices will be automatically assigned based on the device placement information provided in _device_llm_placement_map. If set to a list of devices, it will be checked whether those devices are available to the LLM; if not, a warning will be logged.
- disable_cuda_device_placement (RuntimeParameter[bool]): whether to disable the CUDA device placement logic. Defaults to False.
- _llm_identifier (Union[str, None]): the identifier of the LLM, used as the key in _device_llm_placement_map.
- _device_llm_placement_map (Generator[Dict[str, List[int]], None, None]): a dictionary with the device placement information for each LLM.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
class CudaDevicePlacementMixin(BaseModel):
    """Mixin class to assign CUDA devices to the `LLM` based on the `cuda_devices` attribute
    and the device placement information provided in `_device_llm_placement_map`. Providing
    the device placement information is optional, but if it is provided, it will be used to
    assign CUDA devices to the `LLM`s, trying to avoid using the same device for different
    `LLM`s.

    Attributes:
        cuda_devices: a list with the ID of the CUDA devices to be used by the `LLM`. If set
            to "auto", the devices will be automatically assigned based on the device
            placement information provided in `_device_llm_placement_map`. If set to a list
            of devices, it will be checked if the devices are available to be used by the
            `LLM`. If not, a warning will be logged.
        disable_cuda_device_placement: Whether to disable the CUDA device placement logic
            or not. Defaults to `False`.
        _llm_identifier: the identifier of the `LLM` to be used as key in `_device_llm_placement_map`.
        _device_llm_placement_map: a dictionary with the device placement information for each
            `LLM`.
    """

    cuda_devices: RuntimeParameter[Union[List[int], Literal["auto"]]] = Field(
        default="auto", description="A list with the ID of the CUDA devices to be used."
    )
    disable_cuda_device_placement: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to disable the CUDA device placement logic or not.",
    )

    _llm_identifier: Union[str, None] = PrivateAttr(default=None)
    _desired_num_gpus: PositiveInt = PrivateAttr(default=1)
    _available_cuda_devices: List[int] = PrivateAttr(default_factory=list)
    _can_check_cuda_devices: bool = PrivateAttr(default=False)

    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Assign CUDA devices to the LLM based on the device placement information provided
        in `_device_llm_placement_map`."""

        if self.disable_cuda_device_placement:
            return

        try:
            import pynvml

            pynvml.nvmlInit()
            device_count = pynvml.nvmlDeviceGetCount()
            self._available_cuda_devices = list(range(device_count))
            self._can_check_cuda_devices = True
        except ImportError as ie:
            if self.cuda_devices == "auto":
                raise ImportError(
                    "The 'pynvml' library is not installed. It is required to automatically"
                    " assign CUDA devices to the `LLM`s. Please, install it and try again."
                ) from ie

            if self.cuda_devices:
                self._logger.warning(  # type: ignore
                    "The 'pynvml' library is not installed. It is recommended to install it"
                    " to check if the CUDA devices assigned to the LLM are available."
                )

        self._assign_cuda_devices()

    def unload(self) -> None:
        """Unloads the LLM and removes the CUDA devices assigned to it from the device
        placement information provided in `_device_llm_placement_map`."""
        if self.disable_cuda_device_placement:
            return

        with self._device_llm_placement_map() as device_map:
            if self._llm_identifier in device_map:
                self._logger.debug(  # type: ignore
                    f"Removing '{self._llm_identifier}' from the CUDA device map file"
                    f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'."
                )
                del device_map[self._llm_identifier]

    @contextmanager
    def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]:
        """Reads the content of the device placement file of the node with a lock, yields
        the content, and writes the content back to the file after the context manager is
        closed. If the file doesn't exist, an empty dictionary will be yielded.

        Yields:
            The content of the device placement file.
        """
        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.parent.mkdir(parents=True, exist_ok=True)
        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch()
        with portalocker.Lock(
            _CUDA_DEVICE_PLACEMENT_MIXIN_FILE,
            "r+",
            flags=portalocker.LockFlags.EXCLUSIVE,
        ) as f:
            try:
                content = json.load(f)
            except json.JSONDecodeError:
                content = {}
            yield content
            f.seek(0)
            f.truncate()
            f.write(json.dumps(content))

    def _assign_cuda_devices(self) -> None:
        """Assigns CUDA devices to the LLM based on the device placement information provided
        in `_device_llm_placement_map`. If the `cuda_devices` attribute is set to "auto", it
        will be set to the first available CUDA device that is not going to be used by any
        other LLM. If the `cuda_devices` attribute is set to a list of devices, it will be
        checked if the devices are available to be used by the LLM. If not, a warning will be
        logged."""

        # Take the lock and read the device placement information for each LLM.
        with self._device_llm_placement_map() as device_map:
            if self.cuda_devices == "auto":
                self.cuda_devices = []
                for _ in range(self._desired_num_gpus):
                    if (device_id := self._get_cuda_device(device_map)) is not None:
                        self.cuda_devices.append(device_id)
                        device_map[self._llm_identifier] = self.cuda_devices  # type: ignore
                if len(self.cuda_devices) != self._desired_num_gpus:
                    self._logger.warning(  # type: ignore
                        f"Could not assign the desired number of GPUs {self._desired_num_gpus}"
                        f" for LLM with identifier '{self._llm_identifier}'."
                    )
            else:
                self._check_cuda_devices(device_map)

            device_map[self._llm_identifier] = self.cuda_devices  # type: ignore

        # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices`
        # attribute. In this case, the `cuda_devices` attribute will be set to an empty list.
        if self.cuda_devices == "auto":
            self.cuda_devices = []

        self._set_cuda_visible_devices()

    def _check_cuda_devices(self, device_map: Dict[str, List[int]]) -> None:
        """Checks if the CUDA devices assigned to the LLM are also assigned to other LLMs.

        Args:
            device_map: a dictionary with the device placement information for each LLM.
        """
        for device in self.cuda_devices:  # type: ignore
            for llm, devices in device_map.items():
                if device in devices:
                    self._logger.warning(  # type: ignore
                        f"LLM with identifier '{llm}' is also going to use CUDA device "
                        f"'{device}'. This may lead to performance issues or running out"
                        " of memory depending on the device capabilities and the loaded"
                        " models."
                    )

    def _get_cuda_device(self, device_map: Dict[str, List[int]]) -> Union[int, None]:
        """Returns the first available CUDA device to be used by the LLM that is not going
        to be used by any other LLM.

        Args:
            device_map: a dictionary with the device placement information for each LLM.

        Returns:
            The first available CUDA device to be used by the LLM.

        Raises:
            RuntimeError: if there is no available CUDA device to be used by the LLM.
        """
        for device in self._available_cuda_devices:
            if all(device not in devices for devices in device_map.values()):
                return device

        return None

    def _set_cuda_visible_devices(self) -> None:
        """Sets the `CUDA_VISIBLE_DEVICES` environment variable to the list of CUDA devices
        to be used by the LLM.
        """
        if not self.cuda_devices:
            return

        if self._can_check_cuda_devices and not all(
            device in self._available_cuda_devices for device in self.cuda_devices
        ):
            raise RuntimeError(
                f"Invalid CUDA devices for LLM '{self._llm_identifier}': {self.cuda_devices}."
                f" The available devices are: {self._available_cuda_devices}. Please, review"
                " the 'cuda_devices' attribute and try again."
            )

        cuda_devices = ",".join([str(device) for device in self.cuda_devices])
        self._logger.info(  # type: ignore
            f"🎮 LLM '{self._llm_identifier}' is going to use the following CUDA devices:"
            f" {self.cuda_devices}."
        )
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices
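Since the vLLM class inherits this mixin, device placement can be controlled through the cuda_devices attribute. A hedged sketch; the model id and device ids are placeholders:

from distilabel.models.llms import vLLM

# Pin the LLM to specific GPUs instead of relying on automatic placement.
llm = vLLM(model="Qwen/Qwen2.5-0.5B-Instruct", cuda_devices=[0, 1])

# Or let the mixin pick free devices based on the shared placement map.
llm_auto = vLLM(model="Qwen/Qwen2.5-0.5B-Instruct", cuda_devices="auto")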
load()

Assigns CUDA devices to the LLM based on the device placement information provided in _device_llm_placement_map.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def load(self) -> None:
    """Assign CUDA devices to the LLM based on the device placement information provided
    in `_device_llm_placement_map`."""

    if self.disable_cuda_device_placement:
        return

    try:
        import pynvml

        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        self._available_cuda_devices = list(range(device_count))
        self._can_check_cuda_devices = True
    except ImportError as ie:
        if self.cuda_devices == "auto":
            raise ImportError(
                "The 'pynvml' library is not installed. It is required to automatically"
                " assign CUDA devices to the `LLM`s. Please, install it and try again."
            ) from ie

        if self.cuda_devices:
            self._logger.warning(  # type: ignore
                "The 'pynvml' library is not installed. It is recommended to install it"
                " to check if the CUDA devices assigned to the LLM are available."
            )

    self._assign_cuda_devices()
unload()

Unloads the LLM and removes the CUDA devices assigned to it from the device placement information provided in _device_llm_placement_map.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def unload(self) -> None:
    """Unloads the LLM and removes the CUDA devices assigned to it from the device
    placement information provided in `_device_llm_placement_map`."""
    if self.disable_cuda_device_placement:
        return

    with self._device_llm_placement_map() as device_map:
        if self._llm_identifier in device_map:
            self._logger.debug(  # type: ignore
                f"Removing '{self._llm_identifier}' from the CUDA device map file"
                f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'."
            )
            del device_map[self._llm_identifier]
_device_llm_placement_map()

Reads the content of the node's device placement file while holding a lock, yields the content, and writes it back to the file after the context manager is closed. If the file does not exist, an empty dictionary is yielded.

Yields

- Dict[str, List[int]]: the content of the device placement file.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
@contextmanager
def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]:
    """Reads the content of the device placement file of the node with a lock, yields
    the content, and writes the content back to the file after the context manager is
    closed. If the file doesn't exist, an empty dictionary will be yielded.

    Yields:
        The content of the device placement file.
    """
    _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.parent.mkdir(parents=True, exist_ok=True)
    _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch()
    with portalocker.Lock(
        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE,
        "r+",
        flags=portalocker.LockFlags.EXCLUSIVE,
    ) as f:
        try:
            content = json.load(f)
        except json.JSONDecodeError:
            content = {}
        yield content
        f.seek(0)
        f.truncate()
        f.write(json.dumps(content))
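For illustration, the placement file is plain JSON mapping each LLM identifier to the CUDA devices it has reserved; the identifiers and device ids below are hypothetical:

# Illustrative content of the device placement file, shown as a Python dict:
device_map = {
    "text_generation_0-replica-0": [0, 1],  # this LLM reserved GPUs 0 and 1
    "ultrafeedback_0-replica-0": [2],       # this LLM reserved GPU 2
}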
_assign_cuda_devices()

Assigns CUDA devices to the LLM based on the device placement information provided in _device_llm_placement_map. If the cuda_devices attribute is set to "auto", it will be set to the first available CUDA devices not already used by any other LLM. If the cuda_devices attribute is set to a list of devices, it will be checked whether those devices are available to the LLM; if not, a warning will be logged.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _assign_cuda_devices(self) -> None:
    """Assigns CUDA devices to the LLM based on the device placement information provided
    in `_device_llm_placement_map`. If the `cuda_devices` attribute is set to "auto", it
    will be set to the first available CUDA device that is not going to be used by any
    other LLM. If the `cuda_devices` attribute is set to a list of devices, it will be
    checked if the devices are available to be used by the LLM. If not, a warning will be
    logged."""

    # Take the lock and read the device placement information for each LLM.
    with self._device_llm_placement_map() as device_map:
        if self.cuda_devices == "auto":
            self.cuda_devices = []
            for _ in range(self._desired_num_gpus):
                if (device_id := self._get_cuda_device(device_map)) is not None:
                    self.cuda_devices.append(device_id)
                    device_map[self._llm_identifier] = self.cuda_devices  # type: ignore
            if len(self.cuda_devices) != self._desired_num_gpus:
                self._logger.warning(  # type: ignore
                    f"Could not assign the desired number of GPUs {self._desired_num_gpus}"
                    f" for LLM with identifier '{self._llm_identifier}'."
                )
        else:
            self._check_cuda_devices(device_map)

        device_map[self._llm_identifier] = self.cuda_devices  # type: ignore

    # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices`
    # attribute. In this case, the `cuda_devices` attribute will be set to an empty list.
    if self.cuda_devices == "auto":
        self.cuda_devices = []

    self._set_cuda_visible_devices()
_check_cuda_devices(device_map)

Checks whether the CUDA devices assigned to the LLM are also assigned to other LLMs.

Parameters

- device_map (Dict[str, List[int]], required): a dictionary with the device placement information for each LLM.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _check_cuda_devices(self, device_map: Dict[str, List[int]]) -> None:
    """Checks if the CUDA devices assigned to the LLM are also assigned to other LLMs.

    Args:
        device_map: a dictionary with the device placement information for each LLM.
    """
    for device in self.cuda_devices:  # type: ignore
        for llm, devices in device_map.items():
            if device in devices:
                self._logger.warning(  # type: ignore
                    f"LLM with identifier '{llm}' is also going to use CUDA device "
                    f"'{device}'. This may lead to performance issues or running out"
                    " of memory depending on the device capabilities and the loaded"
                    " models."
                )
_get_cuda_device(device_map)

Returns the first available CUDA device for the LLM that is not going to be used by any other LLM.

Parameters

- device_map (Dict[str, List[int]], required): a dictionary with the device placement information for each LLM.

Returns

- Union[int, None]: the first available CUDA device to be used by the LLM.

Raises

- RuntimeError: if there is no available CUDA device to be used by the LLM.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _get_cuda_device(self, device_map: Dict[str, List[int]]) -> Union[int, None]:
    """Returns the first available CUDA device to be used by the LLM that is not going
    to be used by any other LLM.

    Args:
        device_map: a dictionary with the device placement information for each LLM.

    Returns:
        The first available CUDA device to be used by the LLM.

    Raises:
        RuntimeError: if there is no available CUDA device to be used by the LLM.
    """
    for device in self._available_cuda_devices:
        if all(device not in devices for devices in device_map.values()):
            return device

    return None
_set_cuda_visible_devices()

CUDA_VISIBLE_DEVICES 环境变量设置为 LLM 要使用的 CUDA 设备列表。

源代码位于 src/distilabel/models/mixins/cuda_device_placement.py
def _set_cuda_visible_devices(self) -> None:
    """Sets the `CUDA_VISIBLE_DEVICES` environment variable to the list of CUDA devices
    to be used by the LLM.
    """
    if not self.cuda_devices:
        return

    if self._can_check_cuda_devices and not all(
        device in self._available_cuda_devices for device in self.cuda_devices
    ):
        raise RuntimeError(
            f"Invalid CUDA devices for LLM '{self._llm_identifier}': {self.cuda_devices}."
            f" The available devices are: {self._available_cuda_devices}. Please, review"
            " the 'cuda_devices' attribute and try again."
        )

    cuda_devices = ",".join([str(device) for device in self.cuda_devices])
    self._logger.info(  # type: ignore
        f"🎮 LLM '{self._llm_identifier}' is going to use the following CUDA devices:"
        f" {self.cuda_devices}."
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices