使用 LLM 执行任务¶
使用 LLM 工作¶
LLM 子类被设计为在 Task 中使用,但它们也可以独立使用。
from distilabel.models import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct"
)
llm.load()
llm.generate_outputs(
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# [
# {
# "generations": [
# "The capital of Spain is Madrid."
# ],
# "statistics": {
# "input_tokens": [
# 43
# ],
# "output_tokens": [
# 8
# ]
# }
# }
# ]
注意
当独立使用 LLM 或作为 Task
的一部分使用 LLM 时,始终调用 LLM.load
或 Task.load
方法。如果使用 Pipeline
,这会在 Pipeline.run()
中自动完成。
1.5.0 版本新增功能
自 1.5.0
版本起,LLM 输出是一个字典列表(inputs
中的每个条目一个字典),每个字典都包含 generations
,其中报告了 LLM
返回的文本,以及一个 statistics
字段,用于存储与 LLM
生成相关的统计信息。最初,这将包括 input_tokens
和 output_tokens
(如果可用),这些信息将通过 API 获取(如果可用),或者如果使用的模型有 tokenizer,则使用该模型的 tokenizer 获取。此数据将在 pipeline 处理期间由相应的 Task
移动,并移动到 distilabel_metadata
,以便我们可以在需要时对此数据进行操作,例如计算每个数据集的 token 数量。
要访问之前的结果,只需访问结果字典中的 generations:result[0]["generations"]
。
离线批量生成¶
默认情况下,所有 LLM
都将以同步方式生成文本,即使用 generate_outputs
方法发送输入,该方法将被阻塞,直到生成输出。有些 LLM
(例如 OpenAILLM)实现了我们称之为离线批量生成的功能,该功能允许将输入发送到 LLM 即服务平台,该平台将异步生成输出,并为我们提供一个 job id,我们可以稍后使用该 id 来检查状态并在输出准备就绪时检索生成的输出。LLM 即服务平台提供此功能是为了节省成本,以换取等待输出生成的时间。
要在 distilabel
中使用此功能,我们唯一需要做的就是在创建 LLM
实例时将 use_offline_batch_generation
属性设置为 True
from distilabel.models import OpenAILLM
llm = OpenAILLM(
model="gpt-4o",
use_offline_batch_generation=True,
)
llm.load()
llm.jobs_ids # (1)
# None
llm.generate_outputs( # (2)
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# DistilabelOfflineBatchGenerationNotFinishedException: Batch generation with jobs_ids=('batch_OGB4VjKpu2ay9nz3iiFJxt5H',) is not finished
llm.jobs_ids # (3)
# ('batch_OGB4VjKpu2ay9nz3iiFJxt5H',)
llm.generate_outputs( # (4)
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# [{'generations': ['The capital of Spain is Madrid.'],
# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}]
- 最初,
jobs_ids
属性为None
。 - 首次调用
generate_outputs
会将输入发送到 LLM 即服务平台,并返回DistilabelOfflineBatchGenerationNotFinishedException
,因为输出尚未准备就绪。 - 首次调用
generate_outputs
后,jobs_ids
属性将包含为生成输出而创建的 job id。 - 第二次调用或后续调用
generate_outputs
将在输出准备就绪时返回输出,否则如果输出尚未准备就绪,则会引发DistilabelOfflineBatchGenerationNotFinishedException
。
offline_batch_generation_block_until_done
属性可用于阻止 generate_outputs
方法,直到输出准备就绪并轮询平台指定的秒数。
from distilabel.models import OpenAILLM
llm = OpenAILLM(
model="gpt-4o",
use_offline_batch_generation=True,
offline_batch_generation_block_until_done=5, # poll for results every 5 seconds
)
llm.load()
llm.generate_outputs(
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# [{'generations': ['The capital of Spain is Madrid.'],
# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}]
在 Task 中¶
将 LLM 作为参数传递给 Task
,task 将处理其余部分。
from distilabel.models import OpenAILLM
from distilabel.steps.tasks import TextGeneration
llm = OpenAILLM(model="gpt-4o-mini")
task = TextGeneration(name="text_generation", llm=llm)
task.load()
next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}]))
# [{'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
# 'distilabel_metadata': {'raw_output_text_generation': 'The capital of Spain is Madrid.',
# 'raw_input_text_generation': [{'role': 'user',
# 'content': "What's the capital of Spain?"}],
# 'statistics_text_generation': {'input_tokens': 13, 'output_tokens': 7}},
# 'model_name': 'gpt-4o-mini'}]
注意
正如使用 LLM 工作部分中提到的,LLM 的生成会自动移动到 distilabel_metadata
,以避免干扰常见的工作流程,因此添加 statistics
是为用户提供的额外组件,但定义的 pipeline 中无需进行任何更改。
运行时参数¶
LLM 可以具有运行时参数,例如 generation_kwargs
,通过使用 params
参数的 Pipeline.run()
方法提供。
注意
运行时参数在不同的 LLM 子类之间可能有所不同,这是由于 LLM 提供商提供的不同功能造成的。
from distilabel.pipeline import Pipeline
from distilabel.models import OpenAILLM
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
with Pipeline(name="text-generation-pipeline") as pipeline:
load_dataset = LoadDataFromDicts(
name="load_dataset",
data=[{"instruction": "Write a short story about a dragon that saves a princess from a tower."}],
)
text_generation = TextGeneration(
name="text_generation",
llm=OpenAILLM(model="gpt-4o-mini"),
)
load_dataset >> text_generation
if __name__ == "__main__":
pipeline.run(
parameters={
text_generation.name: {"llm": {"generation_kwargs": {"temperature": 0.3}}},
},
)
创建自定义 LLM¶
要创建自定义 LLM,请为同步 LLM 子类化 LLM
,或为异步 LLM 子类化 AsyncLLM
。实现以下方法
-
model_name
:包含模型名称的属性。 -
generate
:一个接受提示列表并返回生成文本的方法。 -
agenerate
:一个接受单个提示并返回生成文本的方法。此方法在AsyncLLM
类的generate
方法中使用。 -
(可选)
get_last_hidden_state
:是一个接受提示列表并返回隐藏状态列表的方法。此方法是可选的,将由某些 task 使用,例如GenerateEmbeddings
task。
from typing import Any
from pydantic import validate_call
from distilabel.models import LLM
from distilabel.typing import GenerateOutput, HiddenState
from distilabel.typing import ChatType
class CustomLLM(LLM):
@property
def model_name(self) -> str:
return "my-model"
@validate_call
def generate(self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any) -> List[GenerateOutput]:
for _ in range(num_generations):
...
def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]:
...
from typing import Any
from pydantic import validate_call
from distilabel.models import AsyncLLM
from distilabel.typing import GenerateOutput, HiddenState
from distilabel.typing import ChatType
class CustomAsyncLLM(AsyncLLM):
@property
def model_name(self) -> str:
return "my-model"
@validate_call
async def agenerate(self, input: ChatType, num_generations: int = 1, **kwargs: Any) -> GenerateOutput:
for _ in range(num_generations):
...
def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]:
...
generate
和 agenerate
关键字参数(但 input
和 num_generations
除外)被视为 RuntimeParameter
,因此可以通过 Pipeline.run
方法的 parameters
参数将值传递给它们。
注意
为了使 generate
和 agenerate
的参数被强制转换为预期类型,使用了 validate_call
装饰器,它将自动将参数强制转换为预期类型,并在类型不正确时引发错误。当从 CLI 为 generate
或 agenerate
的参数提供值时,这尤其有用,因为 CLI 始终会将参数作为字符串提供。
警告
在 distilabel
中创建的其他 LLM 将必须考虑如何生成 statistics
,以便正确地将其包含在 LLM 输出中。
可用的 LLM¶
我们的 LLM 库 显示了可在 distilabel
库中使用的可用 LLM 列表。