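Bases: `RuntimeParametersMixin`, `BaseModel`, `_Serializable`, `ABC`

Base class for `Embeddings` models.

To implement an `Embeddings` subclass, you need to subclass this class and implement:

- the `load` method, to load the `Embeddings` model. Don't forget to call `super().load()` so that the `_logger` attribute is initialized.
- the `model_name` property, to return the model name used for the `Embeddings`.
- the `encode` method, to generate the sentence embeddings.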
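Attributes

| Name | Type | Description |
| --- | --- | --- |
| `_logger` | `Logger` | The logger used for the `Embeddings` model. It is initialized when the `load` method is called. |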
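The three requirements listed above can be sketched as follows. To stay runnable without distilabel installed, this example uses a stdlib-only stand-in base class, `EmbeddingsLike`, that mirrors the contract described here. It is not the real `Embeddings` class and omits the pydantic and runtime-parameter machinery. `HashingEmbeddings` and its `dim` parameter are likewise hypothetical names chosen for illustration.

```python
import hashlib
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union


class EmbeddingsLike(ABC):
    """Stand-in for the `Embeddings` base class (assumption: stdlib only)."""

    def __init__(self) -> None:
        self._logger: Optional[logging.Logger] = None

    def load(self) -> None:
        # Mirrors the base `load`: only the logger is set up here.
        self._logger = logging.getLogger(
            f"distilabel.models.embeddings.{self.model_name}"
        )

    @property
    @abstractmethod
    def model_name(self) -> str: ...

    @abstractmethod
    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]: ...


class HashingEmbeddings(EmbeddingsLike):
    """Hypothetical toy model: hashes each text into a fixed-size vector."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.dim = dim  # must not exceed 32, the SHA-256 digest size in bytes
        self._loaded = False

    def load(self) -> None:
        super().load()  # initializes `_logger`
        self._loaded = True  # a real subclass would load model weights here

    @property
    def model_name(self) -> str:
        return "toy-hashing"

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        vectors = []
        for text in inputs:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            # Scale the first `dim` bytes of the digest into [0, 1].
            vectors.append([byte / 255 for byte in digest[: self.dim]])
        return vectors


model = HashingEmbeddings(dim=4)
model.load()
vectors = model.encode(["hello", "world"])  # two vectors of length 4
```

With the real base class, the subclass would inherit from `Embeddings` instead, and any constructor arguments would be declared as pydantic fields rather than set in `__init__`.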
Source code in src/distilabel/models/embeddings/base.py

    class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
        """Base class for `Embeddings` models.

        To implement an `Embeddings` subclass, you need to subclass this class and implement:
            - `load` method to load the `Embeddings` model. Don't forget to call `super().load()`,
                so the `_logger` attribute is initialized.
            - `model_name` property to return the model name used for the `Embeddings`.
            - `encode` method to generate the sentence embeddings.

        Attributes:
            _logger: the logger to be used for the `Embeddings` model. It will be initialized
                when the `load` method is called.
        """

        model_config = ConfigDict(
            arbitrary_types_allowed=True,
            protected_namespaces=(),
            validate_default=True,
            validate_assignment=True,
            extra="forbid",
        )
        _logger: "Logger" = PrivateAttr(None)

        def load(self) -> None:
            """Method to be called to initialize the `Embeddings`"""
            self._logger = logging.getLogger(
                f"distilabel.models.embeddings.{self.model_name}"
            )

        def unload(self) -> None:
            """Method to be called to unload the `Embeddings` and release any resources."""
            pass

        @property
        @abstractmethod
        def model_name(self) -> str:
            """Returns the model name used for the `Embeddings`."""
            pass

        @abstractmethod
        def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
            """Generates embeddings for the provided inputs.

            Args:
                inputs: a list of texts for which an embedding has to be generated.

            Returns:
                The generated embeddings.
            """
            pass
model_name `abstractmethod` `property`

Returns the model name used for the `Embeddings`.
load()
Method to be called to initialize the `Embeddings`.
Source code in src/distilabel/models/embeddings/base.py

    def load(self) -> None:
        """Method to be called to initialize the `Embeddings`"""
        self._logger = logging.getLogger(
            f"distilabel.models.embeddings.{self.model_name}"
        )
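Because `load` names the logger `distilabel.models.embeddings.<model_name>`, standard `logging` hierarchy rules apply to it. The sketch below uses only the standard library and a hypothetical model name, `toy-hashing`, to show that a level set on the shared ancestor logger is picked up by a model logger that has no level of its own.

```python
import logging

# Hypothetical model name, used only for illustration.
model_name = "toy-hashing"

# Same naming scheme as the one used in `load`.
model_logger = logging.getLogger(f"distilabel.models.embeddings.{model_name}")

# Setting a level on a shared ancestor applies to every logger below it
# that has no explicit level of its own.
logging.getLogger("distilabel.models.embeddings").setLevel(logging.DEBUG)

effective_level = model_logger.getEffectiveLevel()
```

Note that `logging` treats dots as hierarchy separators, so a model name that itself contains dots adds extra levels to the logger hierarchy.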
unload()
Method to be called to unload the `Embeddings` and release any resources.
Source code in src/distilabel/models/embeddings/base.py

    def unload(self) -> None:
        """Method to be called to unload the `Embeddings` and release any resources."""
        pass
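The base `unload` does nothing, so releasing resources is left to subclasses and to the caller. One possible way to pair `load` with `unload` is a small context manager, sketched below. Both `ToyEmbeddings` and the `loaded` helper are hypothetical and not part of distilabel; the stand-in class only exposes the same three method names.

```python
from contextlib import contextmanager
from typing import Iterator, List


class ToyEmbeddings:
    """Hypothetical stand-in exposing the same `load`/`unload`/`encode` methods."""

    def __init__(self) -> None:
        self.loaded = False

    def load(self) -> None:
        self.loaded = True  # a real model would allocate resources here

    def unload(self) -> None:
        self.loaded = False  # and release them here

    def encode(self, inputs: List[str]) -> List[List[float]]:
        # Toy encoding: a one-dimensional vector holding the text length.
        return [[float(len(text))] for text in inputs]


@contextmanager
def loaded(model: ToyEmbeddings) -> Iterator[ToyEmbeddings]:
    """Calls `load` on entry and guarantees `unload` on exit."""
    model.load()
    try:
        yield model
    finally:
        model.unload()


model = ToyEmbeddings()
with loaded(model) as m:
    vectors = m.encode(["hello", "hi"])
```

The `try`/`finally` ensures `unload` runs even if `encode` raises inside the `with` block.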
encode(inputs) `abstractmethod`

Generates embeddings for the provided inputs.
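Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `List[str]` | A list of texts for which an embedding has to be generated. | required |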
Returns

| Type | Description |
| --- | --- |
| `List[List[Union[int, float]]]` | The generated embeddings. |
Source code in src/distilabel/models/embeddings/base.py

    @abstractmethod
    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings.
        """
        pass
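Since `encode` returns one vector per input as a plain list of numbers, its output can be consumed without any extra dependencies. As a minimal sketch, the hypothetical helper below computes cosine similarity between two such vectors. The vectors are hand-written stand-ins for `encode` output, not results from a real model.

```python
import math
from typing import List, Sequence, Union

Number = Union[int, float]


def cosine_similarity(a: Sequence[Number], b: Sequence[Number]) -> float:
    """Cosine similarity between two vectors of equal length."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_a * norm_b)


# Hand-written vectors standing in for the output of `encode`.
embeddings: List[List[Number]] = [[1, 0], [0.5, 0.5], [0, 2]]
similarity_first_third = cosine_similarity(embeddings[0], embeddings[2])
```

For large batches, a vectorized library would usually be a better fit than this pure-Python loop.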