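DeitaFiltering
Filter dataset rows using the DEITA filtering strategy.
Filters the dataset based on the DEITA score and the cosine distance between the embeddings. This is an implementation of the filtering step from the paper 'What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.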
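Attributes
- data_budget: The desired size of the dataset after filtering.
- diversity_threshold: If a row's cosine distance to its nearest neighbor is greater than this value, the row is included in the filtered dataset. Defaults to `0.9`.
- normalize_embeddings: Whether to normalize the embeddings before computing the cosine distance. Defaults to `True`.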
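As an illustration of how `diversity_threshold` and `normalize_embeddings` interact, here is a minimal NumPy sketch. It assumes that normalization means scaling each embedding to unit length, so that the dot product equals the cosine similarity, and that the nearest neighbor is the closest other row in the same batch. The function name `nearest_neighbor_distances` and the toy vectors are hypothetical and not part of distilabel.

```python
import numpy as np


def nearest_neighbor_distances(embeddings):
    """Cosine distance from each row to its closest *other* row (illustrative)."""
    x = np.asarray(embeddings, dtype=float)
    # Assumption: normalizing means scaling every row to unit length.
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    # For unit-length rows, 1 - dot product is the cosine distance.
    distance = 1.0 - x @ x.T
    # A row must not count as its own neighbor.
    np.fill_diagonal(distance, np.inf)
    return distance.min(axis=1)


# Toy batch: the first two vectors point almost the same way,
# the third points in the opposite direction.
toy = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]
distances = nearest_neighbor_distances(toy)

# Rows whose nearest neighbor is farther away than the threshold pass.
keep = distances > 0.9
```

In this toy batch only the third row passes the default threshold of `0.9`, because the first two rows are near duplicates of each other.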
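Runtime Parameters
- data_budget: The desired size of the dataset after filtering.
- diversity_threshold: If a row's cosine distance to its nearest neighbor is greater than this value, the row is included in the filtered dataset.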
Input & Output Columns
graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[evol_instruction_score]
            ICOL1[evol_response_score]
            ICOL2[embedding]
        end
        subgraph New columns
            OCOL0[deita_score]
            OCOL1[deita_score_computed_with]
            OCOL2[nearest_neighbor_distance]
        end
    end
    subgraph DeitaFiltering
        StepInput[Input Columns: evol_instruction_score, evol_response_score, embedding]
        StepOutput[Output Columns: deita_score, deita_score_computed_with, nearest_neighbor_distance]
    end
    ICOL0 --> StepInput
    ICOL1 --> StepInput
    ICOL2 --> StepInput
    StepOutput --> OCOL0
    StepOutput --> OCOL1
    StepOutput --> OCOL2
    StepInput --> StepOutput
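Inputs
- evol_instruction_score (`float`): The score of the instruction, generated by the `ComplexityScorer` step.
- evol_response_score (`float`): The score of the response, generated by the `QualityScorer` step.
- embedding (`List[float]`): The embedding generated for the conversation of the instruction-response pair using the `GenerateEmbeddings` step.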
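Outputs
- deita_score (`float`): The DEITA score for the instruction-response pair.
- deita_score_computed_with (`List[str]`): The names of the scores used to compute the DEITA score.
- nearest_neighbor_distance (`float`): The cosine distance between the embedding of the instruction-response pair and the embedding of its nearest neighbor.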
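The relationship between the two input scores and the `deita_score` / `deita_score_computed_with` outputs can be sketched as follows. The product rule is consistent with the example on this page (0.5 × 0.5 = 0.25). The fallback to a single score when the other is missing is an assumption suggested by the `deita_score_computed_with` column, not confirmed behavior, and the helper `deita_score` is hypothetical.

```python
SCORE_COLUMNS = ["evol_instruction_score", "evol_response_score"]


def deita_score(row):
    """Return (score, computed_with) for one row (hypothetical helper)."""
    # Assumption: only scores that are present take part in the product.
    used = [name for name in SCORE_COLUMNS if row.get(name) is not None]
    if not used:
        return None, []
    score = 1.0
    for name in used:
        score *= row[name]
    return score, used


both = deita_score({"evol_instruction_score": 0.5, "evol_response_score": 0.5})
only_one = deita_score({"evol_instruction_score": 0.6, "evol_response_score": None})
```

Here `both` carries the product of the two scores together with both column names, while `only_one` falls back to the single available score.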
Examples
Filter the dataset based on the DEITA score and the cosine distance between the embeddings
from distilabel.steps import DeitaFiltering

# data_budget=1: the desired size of the filtered dataset is one row.
deita_filtering = DeitaFiltering(data_budget=1)
deita_filtering.load()

# process() is a generator; next() takes its first batch of results.
result = next(
    deita_filtering.process(
        [
            {
                "evol_instruction_score": 0.5,
                "evol_response_score": 0.5,
                "embedding": [-8.12729941, -5.24642847, -6.34003029],
            },
            {
                "evol_instruction_score": 0.6,
                "evol_response_score": 0.6,
                "embedding": [2.99329242, 0.7800932, 0.7799726],
            },
            {
                "evol_instruction_score": 0.7,
                "evol_response_score": 0.7,
                "embedding": [10.29041806, 14.33088073, 13.00557506],
            },
        ],
    )
)
# >>> result
# [{'evol_instruction_score': 0.5, 'evol_response_score': 0.5, 'embedding': [-8.12729941, -5.24642847, -6.34003029], 'deita_score': 0.25, 'deita_score_computed_with': ['evol_instruction_score', 'evol_response_score'], 'nearest_neighbor_distance': 1.9042812683723933}]
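The values in the output above can be checked by hand with plain NumPy, without distilabel. This is only a sanity check under two assumptions: the DEITA score is the product of the two input scores, and the nearest neighbor is the closest other row of the same batch by cosine distance. The helper `cosine_distance` is hypothetical.

```python
import numpy as np

# Embeddings copied from the three input rows of the example.
first = [-8.12729941, -5.24642847, -6.34003029]
second = [2.99329242, 0.7800932, 0.7799726]
third = [10.29041806, 14.33088073, 13.00557506]


def cosine_distance(a, b):
    """1 - cosine similarity of two vectors (illustrative helper)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return 1.0 - float(a @ b) / float(np.linalg.norm(a) * np.linalg.norm(b))


# Product of evol_instruction_score and evol_response_score for the first row.
first_score = 0.5 * 0.5

# Distance from the first row to each of the other rows; the smaller one
# is its nearest-neighbor distance.
to_second = cosine_distance(first, second)  # about 1.9043
to_third = cosine_distance(first, third)    # about 1.9517
first_nearest = min(to_second, to_third)
```

Under these assumptions `first_score` is 0.25 and `first_nearest` agrees with the reported `nearest_neighbor_distance` of about 1.9043, so the second row is the first row's nearest neighbor.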