跳到内容

SentenceTransformerEmbeddings

用于嵌入生成的 sentence-transformers 库实现。

属性

  • model: 模型 Hugging Face Hub 仓库 ID 或包含模型权重和配置文件的目录路径。

  • device: 用于加载模型的设备名称,例如 "cuda"、"mps" 等。默认为 None

  • prompts: 包含模型要使用的提示的字典。默认为 None

  • default_prompt_name: 将应用于输入的默认提示(在 prompts 中)。如果未提供,则不使用任何提示。默认为 None

  • trust_remote_code: 是否允许获取和执行从 Hub 仓库获取的远程代码。默认为 False

  • revision: 如果 model 指的是 Hugging Face Hub 仓库,则要使用的修订版本(例如分支名称或提交 ID)。默认为 "main"

  • token: 将用于向 Hugging Face Hub 进行身份验证的 Hugging Face Hub 令牌。如果未提供,将使用 HF_TOKEN 环境变量或 huggingface_hub 包本地配置。默认为 None

  • truncate_dim: 截断句子嵌入的维度。默认为 None

  • model_kwargs: 将传递给 Hugging Face transformers 模型类的额外 kwargs。默认为 None

  • tokenizer_kwargs: 将传递给 Hugging Face transformers 分词器类的额外 kwargs。默认为 None

  • config_kwargs: 将传递给 Hugging Face transformers 配置类的额外 kwargs。默认为 None

  • precision: 结果嵌入将具有的 dtype。默认为 "float32"

  • normalize_embeddings: 是否标准化嵌入,使其长度为 1。默认为 None

示例

生成句子嵌入

from distilabel.models import SentenceTransformerEmbeddings

embeddings = SentenceTransformerEmbeddings(model="mixedbread-ai/mxbai-embed-large-v1")

embeddings.load()

results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
# [
#   [-0.05447685346007347, -0.01623094454407692, ...],
#   [4.4889533455716446e-05, 0.044016145169734955, ...],
# ]