Embedding

This section contains the API reference for the distilabel embeddings.

More information on how Embeddings work, along with some examples, is available in the distilabel documentation.

base

Embeddings

Bases: RuntimeParametersMixin, BaseModel, _Serializable, ABC

Base class for Embeddings models.

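To implement an Embeddings subclass, you need to subclass this class and implement:

- the load method, to load the Embeddings model. Don't forget to call super().load() so that the _logger attribute is initialized.
- the model_name property, to return the model name used for the Embeddings.
- the encode method, to generate the sentence embeddings.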
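The subclassing contract above can be sketched as follows. This is a minimal illustration, not distilabel's own code: `EmbeddingsSketch` is a hypothetical, stripped-down stand-in for the real base class (no pydantic, no mixins) so that the example runs without distilabel installed, and `CharBucketEmbeddings` with its `dimensions` parameter is an invented toy model. With the real library you would inherit from `Embeddings` and declare attributes as pydantic fields instead of writing `__init__`.

```python
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union


class EmbeddingsSketch(ABC):
    """Hypothetical stand-in mirroring the `Embeddings` interface."""

    _logger: Optional[logging.Logger] = None

    def load(self) -> None:
        # Same logger naming scheme as the base class source on this page.
        self._logger = logging.getLogger(
            f"distilabel.models.embeddings.{self.model_name}"
        )

    def unload(self) -> None:
        pass

    @property
    @abstractmethod
    def model_name(self) -> str: ...

    @abstractmethod
    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]: ...


class CharBucketEmbeddings(EmbeddingsSketch):
    """Toy model (assumption): counts characters into a fixed number of buckets."""

    def __init__(self, dimensions: int = 8) -> None:
        self.dimensions = dimensions
        self._loaded = False

    def load(self) -> None:
        super().load()  # initializes self._logger
        self._loaded = True
        self._logger.debug("loaded %s", self.model_name)

    def unload(self) -> None:
        # A real model would release GPU memory or close clients here.
        self._loaded = False

    @property
    def model_name(self) -> str:
        return f"char-bucket-{self.dimensions}"

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        vectors: List[List[Union[int, float]]] = []
        for text in inputs:
            vector: List[Union[int, float]] = [0] * self.dimensions
            for char in text:
                vector[ord(char) % self.dimensions] += 1
            vectors.append(vector)
        return vectors


embeddings = CharBucketEmbeddings(dimensions=4)
embeddings.load()
vectors = embeddings.encode(["abc", ""])
embeddings.unload()
```

Each input text yields one vector of length `dimensions`, so `vectors` has the same length as the input list.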

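Attributes

| Name | Type | Description |
| --- | --- | --- |
| _logger | Logger | The logger to be used for the Embeddings model. It is initialized when the load method is called. |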

Source code in src/distilabel/models/embeddings/base.py
class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
    """Base class for `Embeddings` models.

    To implement an `Embeddings` subclass, you need to subclass this class and implement:
        - `load` method to load the `Embeddings` model. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the `Embeddings`.
        - `encode` method to generate the sentence embeddings.

    Attributes:
        _logger: the logger to be used for the `Embeddings` model. It will be initialized
            when the `load` method is called.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )
    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Method to be called to initialize the `Embeddings`"""
        self._logger = logging.getLogger(
            f"distilabel.models.embeddings.{self.model_name}"
        )

    def unload(self) -> None:
        """Method to be called to unload the `Embeddings` and release any resources."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the `Embeddings`."""
        pass

    @abstractmethod
    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings.
        """
        pass
model_name abstractmethod property

Returns the model name used for the Embeddings.

load()

Method to be called to initialize the Embeddings.

Source code in src/distilabel/models/embeddings/base.py
def load(self) -> None:
    """Method to be called to initialize the `Embeddings`"""
    self._logger = logging.getLogger(
        f"distilabel.models.embeddings.{self.model_name}"
    )
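Because load() names the logger `distilabel.models.embeddings.<model_name>`, it sits under the `distilabel` logger in Python's standard logging hierarchy. The sketch below uses only the standard library to show what that implies; the value `my-model` is a hypothetical model name.

```python
import logging

model_name = "my-model"  # hypothetical value returned by `model_name`
logger = logging.getLogger(f"distilabel.models.embeddings.{model_name}")

# A level set on an ancestor logger is inherited by the per-model logger,
# as long as the per-model logger has no level of its own.
logging.getLogger("distilabel").setLevel(logging.DEBUG)

effective_level = logger.getEffectiveLevel()
```

Note that a model name containing dots would introduce extra levels in the logger hierarchy, since `logging` treats dots as separators.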
unload()

Method to be called to unload the Embeddings and release any resources.

Source code in src/distilabel/models/embeddings/base.py
def unload(self) -> None:
    """Method to be called to unload the `Embeddings` and release any resources."""
    pass
encode(inputs) abstractmethod

Generates embeddings for the provided inputs.

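Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | List[str] | A list of texts for which an embedding has to be generated. | required |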

Returns

| Type | Description |
| --- | --- |
| List[List[Union[int, float]]] | The generated embeddings. |

Source code in src/distilabel/models/embeddings/base.py
@abstractmethod
def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
    """Generates embeddings for the provided inputs.

    Args:
        inputs: a list of texts for which an embedding has to be generated.

    Returns:
        The generated embeddings.
    """
    pass
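Since encode returns plain nested lists of numbers (one vector per input text), the result can be consumed without any array library. As a rough illustration, the sketch below compares vectors with cosine similarity; the `cosine_similarity` helper and the sample vectors are assumptions made for the example and are not part of distilabel.

```python
import math
from typing import List, Sequence, Union

Number = Union[int, float]


def cosine_similarity(a: Sequence[Number], b: Sequence[Number]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_a * norm_b)


# Hypothetical output of `encode` for three input texts.
encoded: List[List[Number]] = [[1, 0, 1], [2.0, 0.0, 2.0], [0, 3, 0]]

same_direction = cosine_similarity(encoded[0], encoded[1])
orthogonal = cosine_similarity(encoded[0], encoded[2])
```

Vectors pointing in the same direction score close to 1 regardless of their magnitude, while orthogonal vectors score 0.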