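ArgillaLabeller
Annotate Argilla records based on input fields, example records and question settings.
This task is designed to facilitate the annotation of Argilla records by leveraging a pre-trained LLM. It uses a system prompt that guides the LLM to understand the input fields, the question type and the question settings. The task then formats the input data and generates a response based on the question. Finally, the response is validated against the question's value model, and the final suggestion is prepared for annotation.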
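The sketch below illustrates the last two steps (validation and suggestion building) for a single label_selection question, using only the standard library. It is not distilabel's implementation. Several parts are assumptions:

- the helper build_suggestion is hypothetical;
- the {"label": ...} JSON answer format is assumed;
- the question dictionary layout (name, settings.type, settings.options[].value) is assumed;
- a plain membership check stands in for the question's value model.

```python
import json
from typing import Any, Dict, List


def build_suggestion(raw_output: str, question: Dict[str, Any]) -> Dict[str, Any]:
    """Validate a JSON answer for a label_selection question (hypothetical helper)."""
    if question["settings"]["type"] != "label_selection":
        raise NotImplementedError("this sketch only covers label_selection")
    # Allowed labels are read from the assumed question settings layout.
    allowed: List[str] = [opt["value"] for opt in question["settings"]["options"]]
    parsed = json.loads(raw_output)  # assumes the LLM answered with a JSON object
    label = parsed.get("label")
    if label not in allowed:
        raise ValueError(f"{label!r} is not one of {allowed}")
    # Keys mirror keyword arguments accepted by argilla.Suggestion.
    return {"value": label, "question_name": question["name"]}


question = {
    "name": "label",
    "settings": {
        "type": "label_selection",
        "options": [{"value": "positive"}, {"value": "negative"}],
    },
}
print(build_suggestion('{"label": "positive"}', question))
# {'value': 'positive', 'question_name': 'label'}
```

Rejecting an out-of-range label, rather than silently keeping it, means an invalid answer never becomes a suggestion shown to annotators.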
Attributes
- _template: a Jinja2 template used to format the input for the LLM.
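The exact template is internal to the step. The sketch below only illustrates the idea of rendering a record, its fields, a question and optional guidelines into one prompt string. Python's string.Template stands in for Jinja2 here. The prompt wording, the render_prompt helper and the plain-dict record are all hypothetical.

```python
from string import Template
from typing import Any, Dict, List

# Hypothetical prompt layout; the step's real Jinja2 template differs.
PROMPT = Template(
    "$guidelines\n"
    "Input fields:\n"
    "$fields\n"
    "Question: $question\n"
)


def render_prompt(
    record: Dict[str, str],
    fields: List[Dict[str, Any]],
    question: Dict[str, Any],
    guidelines: str = "",
) -> str:
    # One line per configured field, in the order the fields were given.
    field_lines = "\n".join(f"- {f['title']}: {record[f['name']]}" for f in fields)
    return PROMPT.substitute(
        guidelines=guidelines,
        fields=field_lines,
        question=question["title"],
    )


print(
    render_prompt(
        record={"text": "Great product!"},
        fields=[{"name": "text", "title": "Text"}],
        question={"title": "What is the sentiment?"},
        guidelines="Label the sentiment of the review.",
    )
)
```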
Input & Output Columns
```mermaid
graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[record]
            ICOL1[fields]
            ICOL2[question]
            ICOL3[example_records]
            ICOL4[guidelines]
        end
        subgraph New columns
            OCOL0[suggestion]
        end
    end

    subgraph ArgillaLabeller
        StepInput[Input Columns: record, fields, question, example_records, guidelines]
        StepOutput[Output Columns: suggestion]
    end

    ICOL0 --> StepInput
    ICOL1 --> StepInput
    ICOL2 --> StepInput
    ICOL3 --> StepInput
    ICOL4 --> StepInput
    StepOutput --> OCOL0
    StepInput --> StepOutput
```
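Inputs
- record (argilla.Record): The record to be annotated.
- fields (Optional[List[Dict[str, Any]]]): The list of field settings for the input fields.
- question (Optional[Dict[str, Any]]): The question settings for the question to be answered.
- example_records (Optional[List[Dict[str, Any]]]): Few-shot example records, including their responses, used to help answer the question.
- guidelines (Optional[str]): The guidelines for the annotation task.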
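The examples on this page pass fields and question in one of two ways: when the step is initialised, or inside each input row. The sketch below shows how that resolution could work. It rests on two assumptions:

- a non-empty per-row value takes precedence over the one given at initialisation;
- fields and question must be available from one of the two sources.

resolve_inputs is a hypothetical helper, not part of distilabel.

```python
from typing import Any, Dict

OPTIONAL_COLUMNS = ("fields", "question", "example_records", "guidelines")


def resolve_inputs(row: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Combine one input row with step-level defaults (hypothetical helper)."""
    if "record" not in row:
        raise KeyError("every row must provide a 'record'")
    resolved: Dict[str, Any] = {"record": row["record"]}
    for column in OPTIONAL_COLUMNS:
        # A non-None per-row value wins; otherwise fall back to the step default.
        value = row.get(column)
        resolved[column] = value if value is not None else defaults.get(column)
    if resolved["fields"] is None or resolved["question"] is None:
        raise ValueError("'fields' and 'question' must come from the row or the step")
    return resolved


defaults = {"fields": [{"name": "text"}], "question": {"name": "label"}}
row = {"record": {"text": "Great product!"}, "question": {"name": "label2"}}
print(resolve_inputs(row, defaults)["question"])  # {'name': 'label2'}
```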
Outputs
- suggestion (Dict[str, Any]): The final suggestion for annotation.
Examples
Annotate records with the same dataset and question
```python
import argilla as rg
from argilla import Suggestion
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM

# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
pending_records_filter = rg.Filter(("status", "==", "pending"))
completed_records_filter = rg.Filter(("status", "==", "completed"))
pending_records = list(
    dataset.records(
        query=rg.Query(filter=pending_records_filter),
        limit=5,
    )
)
example_records = list(
    dataset.records(
        query=rg.Query(filter=completed_records_filter),
        limit=5,
    )
)
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]

# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    ),
    fields=[field],
    question=question,
    example_records=example_records,
    guidelines=dataset.guidelines,
)
labeller.load()

# Process the pending records; process() yields batches, next() takes the first one
result = next(
    labeller.process(
        [
            {
                "record": record
            } for record in pending_records
        ]
    )
)

# Add the suggestions to the records
for record, suggestion in zip(pending_records, result):
    record.suggestions.add(Suggestion(**suggestion["suggestion"]))

# Log the updated records
dataset.records.log(pending_records)
```
Annotate a record with alternating datasets and questions
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM

# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]
question2 = dataset.settings.questions["label2"]

# Initialize the labeller with the model only; fields and question are passed per row
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    )
)
labeller.load()

# Process the same record once per question
record = next(dataset.records())
result = next(
    labeller.process(
        [
            {
                "record": record,
                "fields": [field],
                "question": question,
            },
            {
                "record": record,
                "fields": [field],
                "question": question2,
            },
        ]
    )
)

# Add one suggestion per question to the record
for suggestion in result:
    record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))

# Log the updated record
dataset.records.log([record])
```
Overwrite default prompts and instructions
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM

# Overwrite default prompts and instructions
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    ),
    system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
    question_to_label_instruction={
        "label_selection": "Select the appropriate label from the list of provided labels.",
        "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
        "text": "Provide a text response to the question.",
        "rating": "Provide a rating for the question.",
    },
)
labeller.load()
```
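question_to_label_instruction maps each question type to the instruction the LLM receives. The sketch below shows one way such a lookup could behave. DEFAULT_INSTRUCTIONS and instruction_for are hypothetical, and so is the choice to merge partial overrides with the defaults. Check the distilabel source before relying on passing only some question types.

```python
from typing import Dict, Optional

# Hypothetical default wording; the step ships its own defaults.
DEFAULT_INSTRUCTIONS: Dict[str, str] = {
    "label_selection": "Select the appropriate label.",
    "multi_label_selection": "Select none, one or multiple labels.",
    "text": "Provide a text response.",
    "rating": "Provide a rating.",
}


def instruction_for(question_type: str, overrides: Optional[Dict[str, str]] = None) -> str:
    # Overrides replace defaults key by key; types not overridden keep the default.
    merged = {**DEFAULT_INSTRUCTIONS, **(overrides or {})}
    if question_type not in merged:
        raise ValueError(f"unsupported question type: {question_type!r}")
    return merged[question_type]


custom = {"rating": "Rate the text from 1 to 5."}
print(instruction_for("rating", custom))  # Rate the text from 1 to 5.
print(instruction_for("text", custom))  # Provide a text response.
```

Merging keeps a partial override dictionary usable without restating every question type. The trade-off is that a typo in a key silently falls back to the default wording.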