FaissNearestNeighbour¶
创建 faiss
索引以获取最近邻。
FaissNearestNeighbour
是一个 GlobalStep
,它使用 Hugging Face datasets
库集成创建一个 faiss
索引,然后获取每个输入行的最近邻和最近邻的分数或距离。
属性¶
-
device:要使用的 CUDA 设备 ID 或 ID 列表。如果为负整数,将使用所有可用的 GPU。默认为
None
。 -
string_factory:用于构建
faiss
索引的工厂名称。可用的字符串工厂可以在这里查看:https://github.com/facebookresearch/faiss/wiki/Faiss-indexes。默认为None
。 -
metric_type:用于测量点之间距离的度量类型。它是一个整数,推荐的传递方式是导入
faiss
,然后传递faiss.METRIC_x
变量之一。默认为None
。 -
k:为每个输入行搜索的最近邻数量。默认为
1
。 -
search_batch_size:搜索批次中包含的行数。可以调整该值以最大化资源使用或避免 OOM 问题。默认为
50
。 -
train_size:如果索引需要训练步骤,则指定将使用多少向量来训练索引。
运行时参数¶
-
device:要使用的 CUDA 设备 ID 或 ID 列表。如果为负整数,将使用所有可用的 GPU。默认为
None
。 -
string_factory:用于构建
faiss
索引的工厂名称。可用的字符串工厂可以在这里查看:https://github.com/facebookresearch/faiss/wiki/Faiss-indexes。默认为None
。 -
metric_type:用于测量点之间距离的度量类型。它是一个整数,推荐的传递方式是导入
faiss
,然后传递faiss.METRIC_x
变量之一。默认为None
。 -
k:为每个输入行搜索的最近邻数量。默认为
1
。 -
search_batch_size:搜索批次中包含的行数。可以调整该值以最大化资源使用或避免 OOM 问题。默认为
50
。 -
train_size:如果索引需要训练步骤,则指定将使用多少向量来训练索引。
输入和输出列¶
graph TD
subgraph Dataset
subgraph Columns
ICOL0[embedding]
end
subgraph New columns
OCOL0[nn_indices]
OCOL1[nn_scores]
end
end
subgraph FaissNearestNeighbour
StepInput[Input Columns: embedding]
StepOutput[Output Columns: nn_indices, nn_scores]
end
ICOL0 --> StepInput
StepOutput --> OCOL0
StepOutput --> OCOL1
StepInput --> StepOutput
输入¶
- embedding (
List[Union[float, int]]
): 句子嵌入。
输出¶
-
nn_indices (
List[int]
): 包含该行输入中k
个最近邻索引的列表。 -
nn_scores (
List[float]
): 包含到输入中每个k
个最近邻的分数或距离的列表。
示例¶
生成嵌入并获取最近邻¶
from distilabel.models import SentenceTransformerEmbeddings
from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingGeneration, FaissNearestNeighbour, LoadDataFromHub
with Pipeline(name="hello") as pipeline:
load_data = LoadDataFromHub(output_mappings={"prompt": "text"})
embeddings = EmbeddingGeneration(
embeddings=SentenceTransformerEmbeddings(
model="mixedbread-ai/mxbai-embed-large-v1"
)
)
nearest_neighbours = FaissNearestNeighbour()
load_data >> embeddings >> nearest_neighbours
if __name__ == "__main__":
distiset = pipeline.run(
parameters={
load_data.name: {
"repo_id": "distilabel-internal-testing/instruction-dataset-mini",
"split": "test",
},
},
use_cache=False,
)