
Step

This section contains the API reference for the distilabel steps, including the _Step base class and the Step class.

For more information and examples on how to use existing steps or create custom steps, see Tutorial - Step.

base

StepInput = Annotated[List[Dict[str, Any]], _STEP_INPUT_ANNOTATION] module-attribute

StepInput is just an Annotated alias of the type List[Dict[str, Any]] with extra metadata that allows distilabel to perform validations over the process step method defined in each Step.

_Step

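Bases: RuntimeParametersMixin, RequirementsMixin, SignatureMixin, BaseModel, _Serializable, ABC

Base class for the steps that can be included in a Pipeline.

A Step is a class defining some processing logic. The input and outputs for this processing logic are lists of dictionaries with the same keys: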

```python
[
    {"column1": "value1", "column2": "value2", ...},
    {"column1": "value1", "column2": "value2", ...},
    {"column1": "value1", "column2": "value2", ...},
]
```

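The processing logic is defined in the process method, which, depending on the number of previous steps, can receive more than one list of dictionaries, each with the output of the previous steps. In order to let distilabel know where the outputs from the previous steps are, the process function of each Step must have an argument or positional argument annotated with StepInput.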

```python
class StepWithOnePreviousStep(Step):
    def process(self, inputs: StepInput) -> StepOutput:
        yield [...]

class StepWithSeveralPreviousStep(Step):
    # mind the * to indicate that the argument is a list of StepInput
    def process(self, *inputs: StepInput) -> StepOutput:
        yield [...]
```
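
The difference between the two signatures can be sketched with plain generator functions, without distilabel. The column names (`text`, `length`) and the row-by-row merge in `process_several` are illustrative assumptions, not what distilabel itself does with several predecessors.

```python
from typing import Any, Dict, Iterator, List

Rows = List[Dict[str, Any]]


def process_one(inputs: Rows) -> Iterator[Rows]:
    # One previous step: `inputs` is a single batch (a list of row dicts).
    yield [{**row, "length": len(row["text"])} for row in inputs]


def process_several(*inputs: Rows) -> Iterator[Rows]:
    # Several previous steps: `inputs` is a tuple with one batch per step.
    # Assumption for this sketch: the batches are aligned row by row.
    merged: Rows = []
    for rows in zip(*inputs):
        combined: Dict[str, Any] = {}
        for row in rows:
            combined.update(row)
        merged.append(combined)
    yield merged
```

Both functions yield batches instead of returning them, which matches the StepOutput return annotation used in the signatures above.
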

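In order to perform static validations and to check that the chaining of the steps in the Pipeline is valid, a Step must also define the inputs and outputs properties:

  • inputs: a list of strings with the names of the columns that the step needs as input. It can be an empty list if the step is a generator step.
  • outputs: a list of strings with the names of the columns that the step will produce as output.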
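
The sketch below shows how declared inputs and outputs can catch a broken chain before any data flows. `ToyStep` and `check_chain` are hypothetical names, and the sketch assumes a linear chain where every column produced upstream stays available downstream. It is not distilabel's validation code.

```python
from typing import Dict, List


class ToyStep:
    """Hypothetical stand-in for a Step that only declares its columns."""

    def __init__(self, name: str, inputs: List[str], outputs: List[str]) -> None:
        self.name = name
        self._inputs = inputs
        self._outputs = outputs

    @property
    def inputs(self) -> List[str]:
        return self._inputs

    @property
    def outputs(self) -> List[str]:
        return self._outputs


def check_chain(steps: List[ToyStep]) -> Dict[str, List[str]]:
    """Return, per step, the input columns that no earlier step produces."""
    available: set = set()
    problems: Dict[str, List[str]] = {}
    for step in steps:
        missing = [col for col in step.inputs if col not in available]
        if missing:
            problems[step.name] = missing
        available.update(step.outputs)
    return problems


load = ToyStep("load", [], ["instruction"])  # generator step: empty inputs
generate = ToyStep("generate", ["instruction"], ["generation"])
judge = ToyStep("judge", ["instruction", "generation", "reference"], ["score"])
```

Here `check_chain([load, generate, judge])` flags that nothing upstream produces `reference`.
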

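Optionally, a Step can override the load method to perform any initialization logic before the process method is called, for example, to load an LLM, establish a connection to a database, etc.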

Finally, the Step class inherits from pydantic.BaseModel, so attributes can be easily defined, validated, serialized and included in the __init__ method of the step.

Source code in src/distilabel/steps/base.py
class _Step(
    RuntimeParametersMixin,
    RequirementsMixin,
    SignatureMixin,
    BaseModel,
    _Serializable,
    ABC,
):
    """Base class for the steps that can be included in a `Pipeline`.

    A `Step` is a class defining some processing logic. The input and outputs for this
    processing logic are lists of dictionaries with the same keys:

        ```python
        [
            {"column1": "value1", "column2": "value2", ...},
            {"column1": "value1", "column2": "value2", ...},
            {"column1": "value1", "column2": "value2", ...},
        ]
        ```

    The processing logic is defined in the `process` method, which depending on the
    number of previous steps, can receive more than one list of dictionaries, each with
    the output of the previous steps. In order to make `distilabel` know where the outputs
    from the previous steps are, the `process` function from each `Step` must have an argument
    or positional argument annotated with `StepInput`.

        ```python
        class StepWithOnePreviousStep(Step):
            def process(self, inputs: StepInput) -> StepOutput:
                yield [...]

        class StepWithSeveralPreviousStep(Step):
            # mind the * to indicate that the argument is a list of StepInput
            def process(self, *inputs: StepInput) -> StepOutput:
                yield [...]
        ```

    In order to perform static validations and to check that the chaining of the steps
    in the pipeline is valid, a `Step` must also define the `inputs` and `outputs`
    properties:

    - `inputs`: a list of strings with the names of the columns that the step needs as
        input. It can be an empty list if the step is a generator step.
    - `outputs`: a list of strings with the names of the columns that the step will
        produce as output.

    Optionally, a `Step` can override the `load` method to perform any initialization
    logic before the `process` method is called. For example, to load an LLM, establish a
    connection to a database, etc.

    Finally, the `Step` class inherits from `pydantic.BaseModel`, so attributes can be easily
    defined, validated, serialized and included in the `__init__` method of the step.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )

    name: Optional[str] = Field(default=None, pattern=r"^[a-zA-Z0-9_-]+$")
    resources: StepResources = StepResources()
    pipeline: Any = Field(default=None, exclude=True, repr=False)
    input_mappings: Dict[str, str] = {}
    output_mappings: Dict[str, str] = {}
    use_cache: bool = True

    _pipeline_artifacts_path: Path = PrivateAttr(None)
    _built_from_decorator: bool = PrivateAttr(default=False)
    _logger: "Logger" = PrivateAttr(None)

    def model_post_init(self, __context: Any) -> None:
        from distilabel.pipeline.base import _GlobalPipelineManager

        super().model_post_init(__context)

        if self.pipeline is None:
            self.pipeline = _GlobalPipelineManager.get_pipeline()

        if self.pipeline is None:
            _logger = logging.getLogger(f"distilabel.step.{self.name}")
            _logger.warning(
                f"Step '{self.name}' hasn't received a pipeline, and it hasn't been"
                " created within a `Pipeline` context. Please, use"
                " `with Pipeline() as pipeline:` and create the step within the context."
            )

        if not self.name:
            # This must be done before the check for repeated names, but assuming
            # we are passing the pipeline from the _GlobalPipelineManager, should
            # be done after that.
            self.name = _infer_step_name(type(self).__name__, self.pipeline)

        if self.pipeline is not None:
            # If not set an error will be raised in `Pipeline.run` parent
            self.pipeline._add_step(self)

    def connect(
        self,
        *steps: "_Step",
        routing_batch_function: Optional["RoutingBatchFunction"] = None,
    ) -> None:
        """Connects the current step to another step in the pipeline, which means that
        the output of this step will be the input of the other step.

        Args:
            steps: The steps to connect to the current step.
            routing_batch_function: A function that receives a list of steps and returns
                a list of steps to which the output batch generated by this step should be
                routed. It should be used to define the routing logic of the pipeline. If
                not provided, the output batch will be routed to all the connected steps.
                Defaults to `None`.
        """
        assert self.pipeline is not None

        if routing_batch_function:
            self._set_routing_batch_function(routing_batch_function)

        for step in steps:
            self.pipeline._add_edge(from_step=self.name, to_step=step.name)  # type: ignore

    def _set_routing_batch_function(
        self, routing_batch_function: "RoutingBatchFunction"
    ) -> None:
        """Sets a routing batch function for the batches generated by this step, so they
        get routed to specific downstream steps.

        Args:
            routing_batch_function: The routing batch function that will be used to route
                the batches generated by this step.
        """
        self.pipeline._add_routing_batch_function(
            step_name=self.name,  # type: ignore
            routing_batch_function=routing_batch_function,
        )
        routing_batch_function._step = self

    @overload
    def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": ...

    @overload
    def __rshift__(
        self, other: List["DownstreamConnectableSteps"]
    ) -> List["DownstreamConnectableSteps"]: ...

    @overload
    def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": ...

    def __rshift__(
        self,
        other: Union[
            "DownstreamConnectable",
            "RoutingBatchFunction",
            List["DownstreamConnectableSteps"],
        ],
    ) -> Union[
        "DownstreamConnectable",
        "RoutingBatchFunction",
        List["DownstreamConnectableSteps"],
    ]:
        """Allows using the `>>` operator to connect steps in the pipeline.

        Args:
            other: The step to connect, a list of steps to connect to or a routing batch
                function to be set for the step.

        Returns:
            The connected step, the list of connected steps or the routing batch function.

        Example:
            ```python
            step1 >> step2
            # Would be equivalent to:
            step1.connect(step2)

            # It also allows to connect a list of steps
            step1 >> [step2, step3]
            ```
        """
        # Here to avoid circular imports
        from distilabel.pipeline.routing_batch_function import RoutingBatchFunction

        if isinstance(other, list):
            self.connect(*other)
            return other

        if isinstance(other, RoutingBatchFunction):
            self._set_routing_batch_function(other)
            return other

        self.connect(other)
        return other

    def __rrshift__(self, other: List["UpstreamConnectableSteps"]) -> Self:
        """Allows using the [step1, step2] >> step3 operator to connect a list of steps in the pipeline
        to a single step, as the list doesn't have the __rshift__ operator.

        Args:
            other: The step to connect to.

        Returns:
            The connected step

        Example:
            ```python
            [step2, step3] >> step1
            # Would be equivalent to:
            step2.connect(step1)
            step3.connect(step1)
            ```
        """
        for o in other:
            o.connect(self)
        return self

    def load(self) -> None:
        """Method to perform any initialization logic before the `process` method is
        called. For example, to load an LLM, establish a connection to a database, etc.
        """
        self._logger = logging.getLogger(f"distilabel.step.{self.name}")

    def unload(self) -> None:
        """Method to perform any cleanup logic after the `process` method is called. For
        example, to close a connection to a database, etc.
        """
        self._logger.debug("Executing step unload logic.")

    @property
    def is_generator(self) -> bool:
        """Whether the step is a generator step or not.

        Returns:
            `True` if the step is a generator step, `False` otherwise.
        """
        return isinstance(self, GeneratorStep)

    @property
    def is_global(self) -> bool:
        """Whether the step is a global step or not.

        Returns:
            `True` if the step is a global step, `False` otherwise.
        """
        return isinstance(self, GlobalStep)

    @property
    def is_normal(self) -> bool:
        """Whether the step is a normal step or not.

        Returns:
            `True` if the step is a normal step, `False` otherwise.
        """
        return not self.is_generator and not self.is_global

    @property
    def inputs(self) -> "StepColumns":
        """List of strings with the names of the mandatory columns that the step needs as
        input or dictionary in which the keys are the input columns of the step and the
        values are booleans indicating whether the column is optional or not.

        Returns:
            List of strings with the names of the columns that the step needs as input.
        """
        return []

    @property
    def outputs(self) -> "StepColumns":
        """List of strings with the names of the columns that the step will produce as
        output or dictionary in which the keys are the output columns of the step and the
        values are booleans indicating whether the column is optional or not.

        Returns:
            List of strings with the names of the columns that the step will produce as
            output.
        """
        return []

    @cached_property
    def process_parameters(self) -> List[inspect.Parameter]:
        """Returns the parameters of the `process` method of the step.

        Returns:
            The parameters of the `process` method of the step.
        """
        return list(inspect.signature(self.process).parameters.values())  # type: ignore

    def has_multiple_inputs(self) -> bool:
        """Whether the `process` method of the step receives more than one input or not
        i.e. has a `*` argument annotated with `StepInput`.

        Returns:
            `True` if the `process` method of the step receives more than one input,
            `False` otherwise.
        """
        return any(
            param.kind == param.VAR_POSITIONAL for param in self.process_parameters
        )

    def get_process_step_input(self) -> Union[inspect.Parameter, None]:
        """Returns the parameter of the `process` method of the step annotated with
        `StepInput`.

        Returns:
            The parameter of the `process` method of the step annotated with `StepInput`,
            or `None` if there is no parameter annotated with `StepInput`.

        Raises:
            TypeError: If the step has more than one parameter annotated with `StepInput`.
        """
        step_input_parameter = None
        for parameter in self.process_parameters:
            if is_parameter_annotated_with(parameter, _STEP_INPUT_ANNOTATION):
                if step_input_parameter is not None:
                    raise DistilabelTypeError(
                        f"Step '{self.name}' should have only one parameter with type"
                        " hint `StepInput`.",
                        page="sections/how_to_guides/basic/step/#defining-custom-steps",
                    )
                step_input_parameter = parameter
        return step_input_parameter

    def verify_inputs_mappings(self) -> None:
        """Verifies that the `inputs_mappings` of the step are valid i.e. the input
        columns exist in the inputs of the step.

        Raises:
            ValueError: If the `inputs_mappings` of the step are not valid.
        """
        if not self.input_mappings:
            return

        for input in self.input_mappings:
            if input not in self.inputs:
                raise DistilabelUserError(
                    f"The input column '{input}' doesn't exist in the inputs of the"
                    f" step '{self.name}'. Inputs of the step are: {self.inputs}."
                    " Please, review the `inputs_mappings` argument of the step.",
                    page="sections/how_to_guides/basic/step/#arguments",
                )

    def verify_outputs_mappings(self) -> None:
        """Verifies that the `outputs_mappings` of the step are valid i.e. the output
        columns exist in the outputs of the step.

        Raises:
            ValueError: If the `outputs_mappings` of the step are not valid.
        """
        if not self.output_mappings:
            return

        for output in self.output_mappings:
            if output not in self.outputs:
                raise DistilabelUserError(
                    f"The output column '{output}' doesn't exist in the outputs of the"
                    f" step '{self.name}'. Outputs of the step are: {self.outputs}."
                    " Please, review the `outputs_mappings` argument of the step.",
                    page="sections/how_to_guides/basic/step/#arguments",
                )

    def get_inputs(self) -> Dict[str, bool]:
        """Gets the inputs of the step after the `input_mappings`. This method is meant
        to be used to run validations on the inputs of the step.

        Returns:
            The inputs of the step after the `input_mappings` and if they are required or
            not.
        """
        if isinstance(self.inputs, list):
            return {
                self.input_mappings.get(input, input): True for input in self.inputs
            }

        return {
            self.input_mappings.get(input, input): required
            for input, required in self.inputs.items()
        }

    def get_outputs(self) -> Dict[str, bool]:
        """Gets the outputs of the step after the `outputs_mappings`. This method is
        meant to be used to run validations on the outputs of the step.

        Returns:
            The outputs of the step after the `outputs_mappings` and if they are required
            or not.
        """
        if isinstance(self.outputs, list):
            return {
                self.output_mappings.get(output, output): True
                for output in self.outputs
            }

        return {
            self.output_mappings.get(output, output): required
            for output, required in self.outputs.items()
        }

    def set_pipeline_artifacts_path(self, path: Path) -> None:
        """Sets the `_pipeline_artifacts_path` attribute. This method is meant to be used
        by the `Pipeline` once the cache location is known.

        Args:
            path: the path where the artifacts generated by the pipeline steps should be
                saved.
        """
        self._pipeline_artifacts_path = path

    @property
    def artifacts_directory(self) -> Union[Path, None]:
        """Gets the path of the directory where the step should save its generated artifacts.

        Returns:
            The path of the directory where the step should save the generated artifacts,
                or `None` if `_pipeline_artifacts_path` is not set.
        """
        if self._pipeline_artifacts_path is None:
            return None
        return self._pipeline_artifacts_path / self.name  # type: ignore

    def save_artifact(
        self,
        name: str,
        write_function: Callable[[Path], None],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Saves an artifact generated by the `Step`.

        Args:
            name: the name of the artifact.
            write_function: a function that will receive the path where the artifact should
                be saved.
            metadata: the artifact metadata. Defaults to `None`.
        """
        if self.artifacts_directory is None:
            self._logger.warning(
                f"Cannot save artifact with '{name}' as `_pipeline_artifacts_path` is not"
                " set. This is normal if the `Step` is being executed as a standalone component."
            )
            return

        artifact_directory_path = self.artifacts_directory / name
        artifact_directory_path.mkdir(parents=True, exist_ok=True)

        self._logger.info(f"🏺 Storing '{name}' generated artifact...")

        self._logger.debug(
            f"Calling `write_function` to write artifact in '{artifact_directory_path}'..."
        )
        write_function(artifact_directory_path)

        metadata_path = artifact_directory_path / "metadata.json"
        self._logger.debug(
            f"Calling `write_json` to write artifact metadata in '{metadata_path}'..."
        )
        write_json(filename=metadata_path, data=metadata or {})

    def impute_step_outputs(
        self, step_output: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Imputes the output columns of the step that are not present in the step output.
        """
        result = []
        for row in step_output:
            data = row.copy()
            for output in self.get_outputs().keys():
                data[output] = None
            result.append(data)
        return result

    def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
        dump = super()._model_dump(obj, **kwargs)
        dump["runtime_parameters_info"] = self.get_runtime_parameters_info()
        return dump
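
is_generator property

Whether the step is a generator step or not.

Returns

| Type | Description |
| --- | --- |
| bool | True if the step is a generator step, False otherwise. |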

is_global property

Whether the step is a global step or not.

Returns

| Type | Description |
| --- | --- |
| bool | True if the step is a global step, False otherwise. |

is_normal property

Whether the step is a normal step or not.

Returns

| Type | Description |
| --- | --- |
| bool | True if the step is a normal step, False otherwise. |

inputs property

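Either a list of strings with the names of the mandatory columns that the step needs as input, or a dictionary in which the keys are the input columns of the step and the values are booleans indicating whether each column is required (True) or optional (False).

Returns

| Type | Description |
| --- | --- |
| StepColumns | List of strings with the names of the columns that the step needs as input. |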

outputs property

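Either a list of strings with the names of the columns that the step will produce as output, or a dictionary in which the keys are the output columns of the step and the values are booleans indicating whether each column is required (True) or optional (False).

Returns

| Type | Description |
| --- | --- |
| StepColumns | List of strings with the names of the columns that the step will produce as output. |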

process_parameters cached property

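Returns the parameters of the process method of the step.

Returns

| Type | Description |
| --- | --- |
| List[Parameter] | The parameters of the process method of the step. |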

artifacts_directory property

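Gets the path of the directory where the step should save its generated artifacts.

Returns

| Type | Description |
| --- | --- |
| Union[Path, None] | The path of the directory where the step should save the generated artifacts, or None if _pipeline_artifacts_path is not set. |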

connect(*steps, routing_batch_function=None)

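Connects the current step to one or more other steps in the Pipeline, which means that the output of this step will be the input of those steps.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| steps | _Step | The steps to connect to the current step. | () |
| routing_batch_function | Optional[RoutingBatchFunction] | A function that receives a list of steps and returns a list of steps to which the output batch generated by this step should be routed. It should be used to define the routing logic of the Pipeline. If not provided, the output batch will be routed to all the connected steps. Defaults to None. | None |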
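
In terms of bookkeeping, connect registers the routing function (if any) under the step name and then adds one edge per target step. The sketch below models only that bookkeeping. `ToyPipeline`, `ToyStep` and the plain-callable routing function are hypothetical simplifications, since distilabel uses a RoutingBatchFunction object.

```python
from typing import Callable, Dict, List, Optional, Tuple

RouteFn = Callable[[List[str]], List[str]]


class ToyPipeline:
    """Hypothetical pipeline that only records edges and routing functions."""

    def __init__(self) -> None:
        self.edges: List[Tuple[str, str]] = []
        self.routing: Dict[str, RouteFn] = {}


class ToyStep:
    def __init__(self, name: str, pipeline: ToyPipeline) -> None:
        self.name = name
        self.pipeline = pipeline

    def connect(
        self, *steps: "ToyStep", routing_batch_function: Optional[RouteFn] = None
    ) -> None:
        # Register the routing function first, then one edge per target step.
        if routing_batch_function:
            self.pipeline.routing[self.name] = routing_batch_function
        for step in steps:
            self.pipeline.edges.append((self.name, step.name))


pipeline = ToyPipeline()
a, b, c = (ToyStep(n, pipeline) for n in ("a", "b", "c"))
# Hypothetical routing: send every batch only to the first candidate step.
a.connect(b, c, routing_batch_function=lambda names: names[:1])
```
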
__rshift__(other)
__rshift__(other: RoutingBatchFunction) -> RoutingBatchFunction
__rshift__(other: List[DownstreamConnectableSteps]) -> List[DownstreamConnectableSteps]
__rshift__(other: DownstreamConnectable) -> DownstreamConnectable

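Allows using the >> operator to connect steps in the Pipeline.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| other | Union[DownstreamConnectable, RoutingBatchFunction, List[DownstreamConnectableSteps]] | The step to connect, a list of steps to connect to, or a routing batch function to be set for the step. | required |

Returns

| Type | Description |
| --- | --- |
| Union[DownstreamConnectable, RoutingBatchFunction, List[DownstreamConnectableSteps]] | The connected step, the list of connected steps or the routing batch function. |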

Example
step1 >> step2
# Would be equivalent to:
step1.connect(step2)

# It also allows to connect a list of steps
step1 >> [step2, step3]
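
Because __rshift__ returns its right-hand operand, expressions like step1 >> step2 >> [step3, step4] chain naturally. The sketch below demonstrates this with a hypothetical `ToyStep` and a module-level `EDGES` list standing in for the pipeline. It omits the RoutingBatchFunction branch of the real method.

```python
from typing import List, Tuple, Union

EDGES: List[Tuple[str, str]] = []  # hypothetical stand-in for the pipeline graph


class ToyStep:
    def __init__(self, name: str) -> None:
        self.name = name

    def connect(self, *steps: "ToyStep") -> None:
        for step in steps:
            EDGES.append((self.name, step.name))

    def __rshift__(
        self, other: Union["ToyStep", List["ToyStep"]]
    ) -> Union["ToyStep", List["ToyStep"]]:
        # Returning `other` is what lets `a >> b >> c` chain left to right.
        if isinstance(other, list):
            self.connect(*other)
        else:
            self.connect(other)
        return other


step1, step2, step3, step4 = (ToyStep(f"step{i}") for i in range(1, 5))
step1 >> step2 >> [step3, step4]
```
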
__rrshift__(other)

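Allows using the [step1, step2] >> step3 operator to connect a list of steps in the Pipeline to a single step, since a list doesn't have the __rshift__ operator.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| other | List[UpstreamConnectableSteps] | The list of steps to connect to this step. | required |

Returns

| Type | Description |
| --- | --- |
| Self | The connected step, i.e. the step on the right-hand side of >>. |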

Example
[step2, step3] >> step1
# Would be equivalent to:
step2.connect(step1)
step3.connect(step1)
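
The reflected operator works because Python first tries list.__rshift__, which does not exist, and then falls back to the right operand's __rrshift__. Here is a minimal sketch with a hypothetical `ToyStep` and an `EDGES` list standing in for the pipeline.

```python
from typing import List, Tuple

EDGES: List[Tuple[str, str]] = []  # hypothetical stand-in for the pipeline graph


class ToyStep:
    def __init__(self, name: str) -> None:
        self.name = name

    def connect(self, *steps: "ToyStep") -> None:
        for step in steps:
            EDGES.append((self.name, step.name))

    def __rrshift__(self, other: List["ToyStep"]) -> "ToyStep":
        # Called for `[a, b] >> self`, because list has no __rshift__.
        for upstream in other:
            upstream.connect(self)
        return self


step1, step2, step3 = (ToyStep(f"step{i}") for i in range(1, 4))
result = [step2, step3] >> step1
```
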
load()

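Method to perform any initialization logic before the process method is called. For example, to load an LLM, establish a connection to a database, etc.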

unload()

Method to perform any cleanup logic after the process method is called. For example, to close a connection to a database, etc.
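
The load, process, unload lifecycle can be sketched with a standard-library resource: an in-memory SQLite connection that is opened once in load and released in unload. `ToyLookupStep`, the `labels` table and its contents are hypothetical and for illustration only.

```python
import sqlite3
from typing import Any, Dict, Iterator, List, Optional


class ToyLookupStep:
    """Hypothetical step: opens a database in load(), closes it in unload()."""

    def __init__(self) -> None:
        self._conn: Optional[sqlite3.Connection] = None

    def load(self) -> None:
        # Expensive setup happens once, before any batch is processed.
        self._conn = sqlite3.connect(":memory:")
        self._conn.execute("CREATE TABLE labels (id INTEGER, label TEXT)")
        self._conn.execute("INSERT INTO labels VALUES (1, 'positive')")

    def process(self, inputs: List[Dict[str, Any]]) -> Iterator[List[Dict[str, Any]]]:
        assert self._conn is not None, "load() must be called first"
        out = []
        for row in inputs:
            hit = self._conn.execute(
                "SELECT label FROM labels WHERE id = ?", (row["id"],)
            ).fetchone()
            out.append({**row, "label": hit[0] if hit else None})
        yield out

    def unload(self) -> None:
        # Cleanup after processing: release the connection.
        if self._conn is not None:
            self._conn.close()
            self._conn = None


step = ToyLookupStep()
step.load()
batches = list(step.process([{"id": 1}, {"id": 2}]))
step.unload()
```

In a real Step subclass, an overridden load usually calls super().load() first. The base implementation sets self._logger, and the base unload relies on it.
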

has_multiple_inputs()

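Whether the process method of the step receives more than one input or not, i.e. whether it has a * argument annotated with StepInput.

Returns

| Type | Description |
| --- | --- |
| bool | True if the process method of the step receives more than one input, False otherwise. |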
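
The check in the source code looks only at the parameter kind (VAR_POSITIONAL, i.e. *args), not at the annotation. The standalone sketch below applies the same test to plain functions. The function names are hypothetical.

```python
import inspect
from typing import Any, Callable, Dict, List

Rows = List[Dict[str, Any]]


def has_multiple_inputs(process: Callable[..., Any]) -> bool:
    # Same test as the method: is there any *args-style parameter?
    return any(
        p.kind == p.VAR_POSITIONAL
        for p in inspect.signature(process).parameters.values()
    )


def single(inputs: Rows): ...
def several(*inputs: Rows): ...
```
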

get_process_step_input()

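Returns the parameter of the step's process method that is annotated with StepInput.

Returns

| Type | Description |
| --- | --- |
| Union[Parameter, None] | The parameter of the process method of the step annotated with StepInput, or None if there is no parameter annotated with StepInput. |

Raises

| Type | Description |
| --- | --- |
| TypeError | If the step has more than one parameter annotated with StepInput. |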
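
The lookup scans the parameters, remembers the first one carrying the StepInput marker, and fails on a second match. In this standard-library sketch, `_MARKER` is a hypothetical stand-in for `_STEP_INPUT_ANNOTATION`, and a plain TypeError replaces distilabel's DistilabelTypeError.

```python
import inspect
from typing import Annotated, Any, Callable, Dict, List, Optional, get_args, get_origin

_MARKER = "step_input"  # hypothetical stand-in for _STEP_INPUT_ANNOTATION
StepInputLike = Annotated[List[Dict[str, Any]], _MARKER]


def find_step_input(process: Callable[..., Any]) -> Optional[inspect.Parameter]:
    found = None
    for param in inspect.signature(process).parameters.values():
        ann = param.annotation
        # Annotated[...] carries its metadata after the base type in get_args.
        if get_origin(ann) is Annotated and _MARKER in get_args(ann)[1:]:
            if found is not None:
                raise TypeError("only one parameter may be annotated with StepInput")
            found = param
    return found


def ok(inputs: StepInputLike, n: int = 1): ...
def none_annotated(inputs: list): ...
def too_many(a: StepInputLike, b: StepInputLike): ...
```
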

verify_inputs_mappings()

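Verifies that the input_mappings of the step are valid, i.e. that the input columns exist in the inputs of the step.

Raises

| Type | Description |
| --- | --- |
| ValueError | If the input_mappings of the step are not valid. |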
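
Each key of input_mappings must name a column the step declares as input, and the value is the column name in the incoming data. This is a minimal standalone sketch of the check. The function name and the column names are hypothetical.

```python
from typing import Dict, List


def verify_input_mappings(inputs: List[str], input_mappings: Dict[str, str]) -> None:
    # Keys are the step's own input names; values are the data's column names.
    for column in input_mappings:
        if column not in inputs:
            raise ValueError(
                f"The input column '{column}' doesn't exist in the inputs: {inputs}."
            )


verify_input_mappings(["instruction"], {"instruction": "prompt"})  # valid
```

A reversed mapping such as {"prompt": "instruction"} is the typical mistake this check catches.
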

verify_outputs_mappings()

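Verifies that the output_mappings of the step are valid, i.e. that the output columns exist in the outputs of the step.

Raises

| Type | Description |
| --- | --- |
| ValueError | If the output_mappings of the step are not valid. |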

get_inputs()

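Gets the inputs of the step after applying the input_mappings. This method is meant to be used to run validations on the inputs of the step.

Returns

| Type | Description |
| --- | --- |
| Dict[str, bool] | The inputs of the step after applying the input_mappings, and whether each of them is required or not. |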
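
The normalization handles both StepColumns forms: a list marks every column as required, and a dict keeps the per-column flag. In both cases, mapped names replace the original ones. The standalone sketch below follows the source logic. The function and the column names are illustrative.

```python
from typing import Dict, List, Union

StepColumnsLike = Union[List[str], Dict[str, bool]]


def get_inputs(inputs: StepColumnsLike, input_mappings: Dict[str, str]) -> Dict[str, bool]:
    # List form: every column is required.
    if isinstance(inputs, list):
        return {input_mappings.get(col, col): True for col in inputs}
    # Dict form: keep each column's required/optional flag.
    return {input_mappings.get(col, col): req for col, req in inputs.items()}
```

get_outputs follows the same pattern with outputs and output_mappings.
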

get_outputs()

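Gets the outputs of the step after applying the output_mappings. This method is meant to be used to run validations on the outputs of the step.

Returns

| Type | Description |
| --- | --- |
| Dict[str, bool] | The outputs of the step after applying the output_mappings, and whether each of them is required or not. |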

源代码在 src/distilabel/steps/base.py
def get_outputs(self) -> Dict[str, bool]:
    """Gets the outputs of the step after the `outputs_mappings`. This method is
    meant to be used to run validations on the outputs of the step.

    Returns:
        The outputs of the step after the `outputs_mappings` and if they are required
        or not.
    """
    if isinstance(self.outputs, list):
        return {
            self.output_mappings.get(output, output): True
            for output in self.outputs
        }

    return {
        self.output_mappings.get(output, output): required
        for output, required in self.outputs.items()
    }
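Because these methods are meant for validation, one illustrative check compares the required inputs of a step with the columns produced by the step before it. The helper missing_required_inputs and the column names are hypothetical. This is not distilabel's own validation code, which may check more conditions.

```python
from typing import Dict, Iterable, List


def missing_required_inputs(
    step_inputs: Dict[str, bool], available_columns: Iterable[str]
) -> List[str]:
    """Hypothetical check: required inputs that no available column provides."""
    available = set(available_columns)
    return [
        column
        for column, required in step_inputs.items()
        if required and column not in available
    ]


# Shaped like get_outputs() of a previous step (after its output_mappings).
previous_outputs = {"generation": True, "model_name": True}
# Shaped like get_inputs() of the next step (after its input_mappings).
next_inputs = {"generation": True, "instruction": True, "system_prompt": False}

print(missing_required_inputs(next_inputs, previous_outputs.keys()))
# ['instruction']
```

The optional system_prompt column is not reported, because its flag is False.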
set_pipeline_artifacts_path(path)

Sets the _pipeline_artifacts_path attribute. This method is meant to be used by the Pipeline once the cache location is known.

Parameters

Name Type Description Default
path Path

The path where the artifacts generated by the pipeline steps should be saved.

required
Source code in src/distilabel/steps/base.py
def set_pipeline_artifacts_path(self, path: Path) -> None:
    """Sets the `_pipeline_artifacts_path` attribute. This method is meant to be used
    by the `Pipeline` once the cache location is known.

    Args:
        path: the path where the artifacts generated by the pipeline steps should be
            saved.
    """
    self._pipeline_artifacts_path = path
save_artifact(name, write_function, metadata=None)

Saves an artifact generated by the Step.

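Parameters

Name Type Description Default
name str

The name of the artifact.

required
write_function Callable[[Path], None]

A function that receives the path where the artifact should be saved.

required
metadata Optional[Dict[str, Any]]

The artifact metadata. Defaults to None.

None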
Source code in src/distilabel/steps/base.py
def save_artifact(
    self,
    name: str,
    write_function: Callable[[Path], None],
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Saves an artifact generated by the `Step`.

    Args:
        name: the name of the artifact.
        write_function: a function that will receive the path where the artifact should
            be saved.
        metadata: the artifact metadata. Defaults to `None`.
    """
    if self.artifacts_directory is None:
        self._logger.warning(
            f"Cannot save artifact with '{name}' as `_pipeline_artifacts_path` is not"
            " set. This is normal if the `Step` is being executed as a standalone component."
        )
        return

    artifact_directory_path = self.artifacts_directory / name
    artifact_directory_path.mkdir(parents=True, exist_ok=True)

    self._logger.info(f"🏺 Storing '{name}' generated artifact...")

    self._logger.debug(
        f"Calling `write_function` to write artifact in '{artifact_directory_path}'..."
    )
    write_function(artifact_directory_path)

    metadata_path = artifact_directory_path / "metadata.json"
    self._logger.debug(
        f"Calling `write_json` to write artifact metadata in '{metadata_path}'..."
    )
    write_json(filename=metadata_path, data=metadata or {})
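The control flow of save_artifact can be mirrored with the standard library only. This sketch assumes the artifacts directory is passed in directly instead of being read from self.artifacts_directory, and it uses json.dumps in place of distilabel's write_json helper. The write_counts function and the word-counts artifact are hypothetical examples of a write_function.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional


def save_artifact_sketch(
    artifacts_directory: Optional[Path],
    name: str,
    write_function: Callable[[Path], None],
    metadata: Optional[Dict[str, Any]] = None,
) -> Optional[Path]:
    # Mirrors the early return when no artifacts path has been set.
    if artifacts_directory is None:
        return None
    artifact_directory_path = artifacts_directory / name
    artifact_directory_path.mkdir(parents=True, exist_ok=True)
    # The caller decides what to write inside the artifact directory.
    write_function(artifact_directory_path)
    # metadata.json is always written, as {} when no metadata is given.
    (artifact_directory_path / "metadata.json").write_text(json.dumps(metadata or {}))
    return artifact_directory_path


def write_counts(path: Path) -> None:
    # Hypothetical artifact: a small JSON file with word counts.
    (path / "counts.json").write_text(json.dumps({"hello": 2, "world": 1}))


with tempfile.TemporaryDirectory() as tmp:
    saved = save_artifact_sketch(Path(tmp), "word-counts", write_counts, {"rows": 3})
    files = sorted(p.name for p in saved.iterdir())
    meta = json.loads((saved / "metadata.json").read_text())

print(files)  # ['counts.json', 'metadata.json']
print(meta)  # {'rows': 3}
```

Passing a write_function keeps the method agnostic of the artifact format: the step only chooses the directory and records the metadata next to whatever the function writes.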
impute_step_outputs(step_output)

Imputes the output columns of the step that are not present in the step output.

Source code in src/distilabel/steps/base.py
def impute_step_outputs(
    self, step_output: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Imputes the output columns of the step that are not present in the step output.
    """
    result = []
    for row in step_output:
        data = row.copy()
        for output in self.get_outputs().keys():
            data[output] = None
        result.append(data)
    return result
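A plain-Python sketch of the same loop, with output_columns standing in for self.get_outputs().keys(), makes the behavior visible: as written, every output column is set to None, including one the row already contains, while the original rows are left untouched. The function name and column names are hypothetical.

```python
from typing import Any, Dict, List


def impute_outputs(
    step_output: List[Dict[str, Any]], output_columns: List[str]
) -> List[Dict[str, Any]]:
    result = []
    for row in step_output:
        data = row.copy()  # copy so the caller's rows are not modified
        for output in output_columns:
            data[output] = None  # every output column becomes None
        result.append(data)
    return result


rows = [{"instruction": "Say hi", "generation": "hi"}]
imputed = impute_outputs(rows, ["generation", "model_name"])
print(imputed)
# [{'instruction': 'Say hi', 'generation': None, 'model_name': None}]
print(rows)
# [{'instruction': 'Say hi', 'generation': 'hi'}]
```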

Step

Bases: _Step, ABC

Base class for the steps that can be included in a Pipeline.

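Attributes

Name Type Description
input_batch_size RuntimeParameter[PositiveInt]

The number of rows contained in each batch processed by the step. Defaults to 50.

Runtime parameters
  • input_batch_size: the number of rows contained in each batch processed by the step. Defaults to 50.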
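To illustrate what input_batch_size controls, the hypothetical helper iter_batches below splits rows into consecutive batches of at most that many rows. It only shows the size semantics; how the pipeline actually assembles batches from the outputs of previous steps is not modeled. Since input_batch_size is a pydantic field, it can also be set when the step is instantiated.

```python
from typing import Any, Dict, Iterator, List


def iter_batches(
    rows: List[Dict[str, Any]], input_batch_size: int = 50
) -> Iterator[List[Dict[str, Any]]]:
    # Hypothetical helper: consecutive slices of at most input_batch_size rows.
    for start in range(0, len(rows), input_batch_size):
        yield rows[start : start + input_batch_size]


rows = [{"instruction": f"question {i}"} for i in range(120)]
print([len(batch) for batch in iter_batches(rows)])
# [50, 50, 20]
```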
Source code in src/distilabel/steps/base.py
class Step(_Step, ABC):
    """Base class for the steps that can be included in a `Pipeline`.

    Attributes:
        input_batch_size: The number of rows that will contain the batches processed by
            the step. Defaults to `50`.

    Runtime parameters:
        - `input_batch_size`: The number of rows that will contain the batches processed
            by the step. Defaults to `50`.
    """

    input_batch_size: RuntimeParameter[PositiveInt] = Field(
        default=DEFAULT_INPUT_BATCH_SIZE,
        description="The number of rows that will contain the batches processed by the"
        " step.",
    )

    @abstractmethod
    def process(self, *inputs: StepInput) -> "StepOutput":
        """Method that defines the processing logic of the step. It should yield the
        output rows.

        Args:
            *inputs: An argument used to receive the outputs of the previous steps. The
                number of arguments depends on the number of previous steps. It doesn't
                need to be an `*args` argument, it can be a regular argument annotated
                with `StepInput` if the step has only one previous step.
        """
        pass

    def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput":
        """Runs the `process` method of the step applying the `input_mappings` to the input
        rows and the `outputs_mappings` to the output rows. This is the function that
        should be used to run the processing logic of the step.

        Yields:
            The output rows.
        """

        inputs, overriden_inputs = (
            self._apply_input_mappings(args)
            if self.input_mappings
            else (args, [{} for _ in range(len(args[0]))])
        )

        # If the `Step` was built using the `@step` decorator, then we need to pass
        # the runtime parameters as kwargs, so they can be used within the processing
        # function
        generator = (
            self.process(*inputs)
            if not self._built_from_decorator
            else self.process(*inputs, **self._runtime_parameters)
        )

        for output_rows in generator:
            restored = []
            for i, row in enumerate(output_rows):
                # Correct the index here because we don't know the num_generations from the llm
                # ahead of time. For example, if we have `len(overriden_inputs)==5` and `len(row)==10`,
                # from `num_generations==2` and `group_generations=False` in the LLM:
                # The loop will use indices 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
                ntimes_i = i % len(overriden_inputs)
                restored.append(
                    self._apply_mappings_and_restore_overriden(
                        row, overriden_inputs[ntimes_i]
                    )
                )
            yield restored

    def _apply_input_mappings(
        self, inputs: Tuple[List[Dict[str, Any]], ...]
    ) -> Tuple[Tuple[List[Dict[str, Any]], ...], List[Dict[str, Any]]]:
        """Applies the `input_mappings` to the input rows.

        Args:
            inputs: The input rows.

        Returns:
            The input rows with the `input_mappings` applied and the overriden values
                that were replaced by the `input_mappings`.
        """
        reverted_input_mappings = {v: k for k, v in self.input_mappings.items()}

        renamed_inputs = []
        overriden_inputs = []
        for i, row_inputs in enumerate(inputs):
            renamed_row_inputs = []
            for row in row_inputs:
                overriden_keys = {}
                renamed_row = {}
                for k, v in row.items():
                    renamed_key = reverted_input_mappings.get(k, k)

                    if renamed_key not in renamed_row or k != renamed_key:
                        renamed_row[renamed_key] = v

                        if k != renamed_key and renamed_key in row and len(inputs) == 1:
                            overriden_keys[renamed_key] = row[renamed_key]

                if i == 0:
                    overriden_inputs.append(overriden_keys)
                renamed_row_inputs.append(renamed_row)
            renamed_inputs.append(renamed_row_inputs)
        return tuple(renamed_inputs), overriden_inputs

    def _apply_mappings_and_restore_overriden(
        self, row: Dict[str, Any], overriden: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Reverts the `input_mappings` applied to the input rows and applies the `output_mappings`
        to the output rows. In addition, it restores the overriden values that were replaced
        by the `input_mappings`.

        Args:
            row: The output row.
            overriden: The overriden values that were replaced by the `input_mappings`.

        Returns:
            The output row with the `output_mappings` applied and the overriden values
            restored.
        """
        result = {}
        for k, v in row.items():
            mapped_key = (
                self.output_mappings.get(k, None)
                or self.input_mappings.get(k, None)
                or k
            )
            result[mapped_key] = v

        # Restore overriden values
        for k, v in overriden.items():
            if k not in result:
                result[k] = v

        return result
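The following sketch re-implements _apply_input_mappings and _apply_mappings_and_restore_overriden for the single-previous-step case, so the len(inputs) == 1 and i == 0 checks are dropped. The mapping values and row contents are hypothetical. It walks through the round trip: a dataset column named in input_mappings is renamed to the step's input name before process runs, names are mapped back afterwards, output_mappings renames new columns, and a pre-existing column that the renaming shadowed is restored.

```python
from typing import Any, Dict, List, Tuple

Row = Dict[str, Any]


def apply_input_mappings(
    rows: List[Row], input_mappings: Dict[str, str]
) -> Tuple[List[Row], List[Row]]:
    # input_mappings maps step input name -> dataset column, so invert it to
    # rename dataset columns to the names the step expects.
    reverted = {v: k for k, v in input_mappings.items()}
    renamed_rows: List[Row] = []
    overridden_rows: List[Row] = []
    for row in rows:
        renamed: Row = {}
        overridden: Row = {}
        for k, v in row.items():
            renamed_key = reverted.get(k, k)
            if renamed_key not in renamed or k != renamed_key:
                renamed[renamed_key] = v
                # The row already had a column with the step's input name:
                # set its value aside so it can be restored after processing.
                if k != renamed_key and renamed_key in row:
                    overridden[renamed_key] = row[renamed_key]
        renamed_rows.append(renamed)
        overridden_rows.append(overridden)
    return renamed_rows, overridden_rows


def restore(
    row: Row,
    overridden: Row,
    input_mappings: Dict[str, str],
    output_mappings: Dict[str, str],
) -> Row:
    result: Row = {}
    for k, v in row.items():
        # output_mappings wins; otherwise map step input names back to dataset columns.
        result[output_mappings.get(k) or input_mappings.get(k) or k] = v
    # Put back values that the input renaming had shadowed.
    for k, v in overridden.items():
        if k not in result:
            result[k] = v
    return result


input_mappings = {"instruction": "prompt"}
output_mappings = {"generation": "answer"}

rows = [{"prompt": "Say hi", "instruction": "old value"}]
renamed, overridden = apply_input_mappings(rows, input_mappings)
print(renamed)  # [{'instruction': 'Say hi'}]
print(overridden)  # [{'instruction': 'old value'}]

# Pretend process() returned the input row plus a new "generation" column.
processed = {**renamed[0], "generation": "hi!"}
restored = restore(processed, overridden[0], input_mappings, output_mappings)
print(restored)  # {'prompt': 'Say hi', 'answer': 'hi!', 'instruction': 'old value'}
```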
process(*inputs) abstractmethod

Method that defines the processing logic of the step. It should yield the output rows.

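Parameters

Name Type Description Default
*inputs StepInput

An argument used to receive the outputs of the previous steps. The number of arguments depends on the number of previous steps. It doesn't need to be an *args argument: if the step has only one previous step, it can be a regular argument annotated with StepInput.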

()
Source code in src/distilabel/steps/base.py
@abstractmethod
def process(self, *inputs: StepInput) -> "StepOutput":
    """Method that defines the processing logic of the step. It should yield the
    output rows.

    Args:
        *inputs: An argument used to receive the outputs of the previous steps. The
            number of arguments depends on the number of previous steps. It doesn't
            need to be an `*args` argument, it can be a regular argument annotated
            with `StepInput` if the step has only one previous step.
    """
    pass
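A plain-function sketch of a process body that receives one list of rows per previous step through *inputs and yields a single batch. Merging rows by position is only an illustrative choice, and the column names are hypothetical. In a real step, this logic would live in the process method of a Step subclass, with the arguments annotated as StepInput and the return type as StepOutput.

```python
from typing import Any, Dict, Iterator, List

Row = Dict[str, Any]


def process(*inputs: List[Row]) -> Iterator[List[Row]]:
    # One list of rows per previous step; rows at the same position are merged.
    merged: List[Row] = []
    for rows in zip(*inputs):
        combined: Row = {}
        for row in rows:
            combined.update(row)
        merged.append(combined)
    yield merged  # process is a generator: it yields batches of output rows


batch_from_step_a = [{"instruction": "Say hi"}, {"instruction": "Say bye"}]
batch_from_step_b = [{"generation": "hi!"}, {"generation": "bye!"}]
results = list(process(batch_from_step_a, batch_from_step_b))
print(results)
# [[{'instruction': 'Say hi', 'generation': 'hi!'}, {'instruction': 'Say bye', 'generation': 'bye!'}]]
```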
process_applying_mappings(*args)

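Runs the process method of the step, applying input_mappings to the input rows and output_mappings to the output rows. This is the function that should be used to run the processing logic of the step.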

Yields

Type Description
StepOutput

The output rows.

Source code in src/distilabel/steps/base.py
def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput":
    """Runs the `process` method of the step applying the `input_mappings` to the input
    rows and the `outputs_mappings` to the output rows. This is the function that
    should be used to run the processing logic of the step.

    Yields:
        The output rows.
    """

    inputs, overriden_inputs = (
        self._apply_input_mappings(args)
        if self.input_mappings
        else (args, [{} for _ in range(len(args[0]))])
    )

    # If the `Step` was built using the `@step` decorator, then we need to pass
    # the runtime parameters as kwargs, so they can be used within the processing
    # function
    generator = (
        self.process(*inputs)
        if not self._built_from_decorator
        else self.process(*inputs, **self._runtime_parameters)
    )

    for output_rows in generator:
        restored = []
        for i, row in enumerate(output_rows):
            # Correct the index here because we don't know the num_generations from the llm
            # ahead of time. For example, if we have `len(overriden_inputs)==5` and `len(row)==10`,
            # from `num_generations==2` and `group_generations=False` in the LLM:
            # The loop will use indices 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
            ntimes_i = i % len(overriden_inputs)
            restored.append(
                self._apply_mappings_and_restore_overriden(
                    row, overriden_inputs[ntimes_i]
                )
            )
        yield restored
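The index correction in the loop can be checked in isolation. This sketch reproduces only the arithmetic from the source comment, with five input rows and ten output rows (the num_generations=2, group_generations=False case it describes). The row contents are hypothetical, and the sketch assumes, as the comment does, that output row i belongs to input row i % n.

```python
from typing import Any, Dict, List

# Five input rows, each with the values that input_mappings overrode.
overridden_inputs: List[Dict[str, Any]] = [{"id": i} for i in range(5)]
# Ten flat output rows produced for those five inputs.
output_rows = [{"generation": f"g{i}"} for i in range(10)]

# Same correction as in process_applying_mappings: wrap around the inputs.
indices = [i % len(overridden_inputs) for i, _ in enumerate(output_rows)]
print(indices)
# [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```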