
ArgillaLabeller

Annotate Argilla records based on input fields, example records and question settings.

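This task streamlines the annotation of Argilla records by using a pre-trained LLM. A system prompt guides the LLM through the input fields, the question type and the question settings. The task then formats the input data and generates a response for the question. The response is validated against the question's value model, and the final suggestion is prepared for annotation.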
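The validation step can be pictured with a small standard-library sketch. It assumes the LLM replies with a JSON object holding a `value` key, and that question settings are plain dicts with `type`, `labels` or `values` keys. `validate_response` and these shapes are hypothetical simplifications, not the distilabel or Argilla implementation, which uses the question's value model instead.

```python
import json
from typing import Any, Dict


def validate_response(raw: str, question: Dict[str, Any]) -> Any:
    """Parse a raw LLM reply and check it against simplified question settings.

    Hypothetical helper: the question dict shape ("type", "labels", "values")
    is an assumption, not Argilla's serialized settings format.
    """
    value = json.loads(raw)["value"]
    qtype = question["type"]
    if qtype == "label_selection":
        # Exactly one of the allowed labels.
        if value not in question["labels"]:
            raise ValueError(f"{value!r} is not one of {question['labels']}")
    elif qtype == "multi_label_selection":
        # A list (possibly empty) whose items are all allowed labels.
        if not isinstance(value, list) or not set(value) <= set(question["labels"]):
            raise ValueError(f"{value!r} is not a list of allowed labels")
    elif qtype == "rating":
        # One of the allowed rating values.
        if value not in question["values"]:
            raise ValueError(f"{value!r} is not an allowed rating")
    elif qtype == "text":
        # Free text only has to be a string.
        if not isinstance(value, str):
            raise ValueError("a text answer must be a string")
    else:
        raise ValueError(f"unsupported question type: {qtype}")
    return value


question = {"name": "label", "type": "label_selection", "labels": ["positive", "negative"]}
value = validate_response('{"value": "positive"}', question)
# Assumed suggestion layout built from the validated value.
suggestion = {"value": value, "question_name": question["name"], "type": "model"}
print(suggestion)
```

Rejecting invalid answers early means a malformed reply never reaches Argilla as a suggestion.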

Attributes

  • _template: The Jinja2 template used to format the input for the LLM.
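To show what template formatting involves, the sketch below renders a prompt from a record, the selected fields, a question and the guidelines. The real task uses Jinja2; this sketch swaps in the standard library's `string.Template` so it runs without extra packages, and both the prompt layout and the `render_prompt` helper are hypothetical.

```python
from string import Template
from typing import Any, Dict, List

# Hypothetical prompt layout; the actual ArgillaLabeller template differs.
PROMPT = Template(
    "Guidelines: $guidelines\n"
    "Fields:\n$fields\n"
    "Question ($qtype): $question\n"
    "Options: $options\n"
    "Answer with JSON: {\"value\": ...}"
)


def render_prompt(record: Dict[str, str], fields: List[str],
                  question: Dict[str, Any], guidelines: str) -> str:
    """Fill the template; only the configured fields are shown, one per line."""
    field_lines = "\n".join(f"- {name}: {record[name]}" for name in fields)
    options = ", ".join(map(str, question.get("labels", []))) or "free text"
    return PROMPT.substitute(
        guidelines=guidelines,
        fields=field_lines,
        qtype=question["type"],
        question=question["title"],
        options=options,
    )


prompt = render_prompt(
    record={"text": "I loved this film!", "id": "42"},
    fields=["text"],
    question={"title": "What is the sentiment?", "type": "label_selection",
              "labels": ["positive", "negative"]},
    guidelines="Label the overall sentiment.",
)
print(prompt)
```

Note that the record's `id` never reaches the prompt because it is not listed in `fields`.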

Input & Output Columns

graph TD
    subgraph Dataset
        subgraph Columns
            ICOL0[record]
            ICOL1[fields]
            ICOL2[question]
            ICOL3[example_records]
            ICOL4[guidelines]
        end
        subgraph New columns
            OCOL0[suggestion]
        end
    end

    subgraph ArgillaLabeller
        StepInput[Input Columns: record, fields, question, example_records, guidelines]
        StepOutput[Output Columns: suggestion]
    end

    ICOL0 --> StepInput
    ICOL1 --> StepInput
    ICOL2 --> StepInput
    ICOL3 --> StepInput
    ICOL4 --> StepInput
    StepOutput --> OCOL0
    StepInput --> StepOutput

Input columns

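  • record (argilla.Record): The record to be annotated.

  • fields (Optional[List[Dict[str, Any]]]): The list of field settings for the input fields.

  • question (Optional[Dict[str, Any]]): The question settings for the question to be answered.

  • example_records (Optional[List[Dict[str, Any]]]): The few-shot example records, with responses, used to answer the question.

  • guidelines (Optional[str]): The guidelines for the annotation task.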
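The first example below passes only `record` in each row and sets the other values on the task itself, which suggests that a column missing from a row falls back to the task-level value. The sketch assumes exactly that behavior; `resolve_row` and the simplified dict shapes are hypothetical and do not mirror Argilla's serialized settings.

```python
from typing import Any, Dict

REQUIRED = ("record",)
OPTIONAL = ("fields", "question", "example_records", "guidelines")


def resolve_row(row: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Merge one input row with task-level defaults (hypothetical helper).

    Assumes a column missing from the row falls back to the value set when
    the task was constructed; a value present in the row wins.
    """
    missing = [name for name in REQUIRED if name not in row]
    if missing:
        raise KeyError(f"missing required column(s): {missing}")
    resolved: Dict[str, Any] = {"record": row["record"]}
    for name in OPTIONAL:
        resolved[name] = row.get(name, defaults.get(name))
    # A prompt cannot be built without fields and a question.
    if resolved["fields"] is None or resolved["question"] is None:
        raise ValueError("fields and question must be set on the row or the task")
    return resolved


# Simplified, hypothetical shapes for the serialized settings.
defaults = {
    "fields": [{"name": "text", "type": "text"}],
    "question": {"name": "label", "type": "label_selection", "labels": ["positive", "negative"]},
    "guidelines": "Label the overall sentiment.",
}
row = resolve_row({"record": {"text": "Great movie!"}}, defaults)
print(row["question"]["name"], row["example_records"])
```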

Output columns

  • suggestion (Dict[str, Any]): The final suggestion for annotation.
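In the examples, this dict is unpacked straight into `rg.Suggestion(**suggestion["suggestion"])`. The sketch below mimics that unpacking with a hypothetical dataclass stand-in so it runs without Argilla; the keys `value`, `question_name`, `type` and `agent` are an assumption about the dict's contents.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SuggestionStub:
    """Hypothetical stand-in for argilla.Suggestion, for illustration only."""
    question_name: str
    value: Any
    type: Optional[str] = None
    agent: Optional[str] = None


# Assumed shape of one output row produced by the task.
row = {
    "suggestion": {
        "value": "positive",
        "question_name": "label",
        "type": "model",
        "agent": "mistralai/Mistral-7B-Instruct-v0.2",
    }
}

# Keyword unpacking fails loudly if the dict carries an unexpected key.
suggestion = SuggestionStub(**row["suggestion"])
print(suggestion.question_name, suggestion.value)
```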

Examples

Annotate records with the same dataset and question

import argilla as rg
from argilla import Suggestion
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM

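# Connect to Argilla; api_url and api_key default to the
# ARGILLA_API_URL and ARGILLA_API_KEY environment variables
client = rg.Argilla()

# Get information from the Argilla dataset definition
dataset = client.datasets(name="my_dataset")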
pending_records_filter = rg.Filter(("status", "==", "pending"))
completed_records_filter = rg.Filter(("status", "==", "completed"))
pending_records = list(
    dataset.records(
        query=rg.Query(filter=pending_records_filter),
        limit=5,
    )
)
example_records = list(
    dataset.records(
        query=rg.Query(filter=completed_records_filter),
        limit=5,
    )
)
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]

# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    ),
    fields=[field],
    question=question,
    example_records=example_records,
    guidelines=dataset.guidelines
)
labeller.load()

# Process the pending records
result = next(
    labeller.process(
        [
            {
                "record": record
            } for record in pending_records
        ]
    )
)

# Add the suggestions to the records
for record, suggestion in zip(pending_records, result):
    record.suggestions.add(Suggestion(**suggestion["suggestion"]))

# Log the updated records
dataset.records.log(pending_records)
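One caveat in the loop above: `zip` stops at the shorter sequence, so a length mismatch between records and outputs would silently leave records without suggestions. The sketch below, with plain dicts standing in for `argilla.Record` objects and task outputs, pairs them strictly and skips outputs whose suggestion is `None`. Whether the task can actually emit a `None` suggestion (for example after a failed generation) is an assumption, and `pair_suggestions` is a hypothetical helper.

```python
from typing import Any, Dict, List


def pair_suggestions(records: List[Dict[str, Any]],
                     outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Attach each output's suggestion to its record, in order (hypothetical helper).

    Unlike a bare zip(), a length mismatch raises instead of silently
    dropping records; outputs without a suggestion are skipped.
    """
    if len(records) != len(outputs):
        raise ValueError(f"{len(records)} records but {len(outputs)} outputs")
    updated = []
    for record, output in zip(records, outputs):
        suggestion = output.get("suggestion")
        if suggestion is None:
            continue  # leave this record without a suggestion
        record.setdefault("suggestions", []).append(suggestion)
        updated.append(record)
    return updated


# Plain dicts stand in for argilla.Record objects and task outputs.
records = [{"id": "a"}, {"id": "b"}]
outputs = [
    {"suggestion": {"value": "positive", "question_name": "label", "type": "model"}},
    {"suggestion": None},
]
updated = pair_suggestions(records, outputs)
print([r["id"] for r in updated])
```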

Annotate a record with alternating datasets and questions

import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM

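# Connect to Argilla; api_url and api_key default to the
# ARGILLA_API_URL and ARGILLA_API_KEY environment variables
client = rg.Argilla()

# Get information from the Argilla dataset definition
dataset = client.datasets(name="my_dataset")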
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]
question2 = dataset.settings.questions["label2"]

# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    )
)
labeller.load()

# Process the first record of the dataset
record = next(iter(dataset.records()))
result = next(
    labeller.process(
        [
            {
                "record": record,
                "fields": [field],
                "question": question,
            },
            {
                "record": record,
                "fields": [field],
                "question": question2,
            }
        ]
    )
)

# Add the suggestions to the record
for suggestion in result:
    record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))

# Log the updated record
dataset.records.log([record])
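The two hand-written rows above generalize to any number of questions on the same record. This sketch builds one row per question and indexes the resulting suggestions by `question_name`; the helpers are hypothetical, and the outputs are stubs standing in for what `labeller.process` would return.

```python
from typing import Any, Dict, List


def rows_for_questions(record: Any, fields: List[Any],
                       questions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one input row per question for the same record (hypothetical helper)."""
    return [{"record": record, "fields": fields, "question": q} for q in questions]


def suggestions_by_question(outputs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Index output suggestions by the question they answer."""
    return {out["suggestion"]["question_name"]: out["suggestion"] for out in outputs}


questions = [
    {"name": "label", "type": "label_selection", "labels": ["positive", "negative"]},
    {"name": "label2", "type": "text"},
]
rows = rows_for_questions({"text": "Great movie!"}, [{"name": "text"}], questions)

# Stub outputs in place of labeller.process(rows); real values come from the LLM.
outputs = [
    {"suggestion": {"question_name": "label", "value": "positive", "type": "model"}},
    {"suggestion": {"question_name": "label2", "value": "An upbeat review.", "type": "model"}},
]
by_question = suggestions_by_question(outputs)
print(sorted(by_question))
```

Keying by `question_name` keeps each suggestion tied to its question even if the outputs arrive in a different order than the rows.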

Overwrite default prompts and instructions

import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM

# Overwrite default prompts and instructions
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    ),
    system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
    question_to_label_instruction={
        "label_selection": "Select the appropriate label from the list of provided labels.",
        "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
        "text": "Provide a text response to the question.",
        "rating": "Provide a rating for the question.",
    },
)
labeller.load()
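`question_to_label_instruction` maps an Argilla question type to the instruction given to the LLM. How the task merges a partial override with its built-in instructions is not stated here, so this sketch assumes a plain dict merge in which user entries win; `instruction_for` is hypothetical and `DEFAULT_INSTRUCTIONS` holds placeholder text, not distilabel's defaults.

```python
from typing import Dict, Optional

# Placeholder defaults for illustration only; distilabel ships its own wording.
DEFAULT_INSTRUCTIONS: Dict[str, str] = {
    "label_selection": "Pick one label.",
    "multi_label_selection": "Pick any number of labels.",
    "text": "Write a text answer.",
    "rating": "Give a rating.",
}


def instruction_for(question_type: str,
                    overrides: Optional[Dict[str, str]] = None) -> str:
    """Return the instruction for a question type; overrides win over defaults."""
    merged = {**DEFAULT_INSTRUCTIONS, **(overrides or {})}
    if question_type not in merged:
        raise ValueError(f"no instruction for question type {question_type!r}")
    return merged[question_type]


custom = {"rating": "Provide a rating for the question."}
print(instruction_for("rating", custom))
print(instruction_for("text", custom))
```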
