跳到内容

RuntimeParametersMixin

RuntimeParametersMixin

基类: BaseModel

对于具有 RuntimeParameter 属性的类的 Mixin。

属性

名称 类型 描述
_runtime_parameters Dict[str, Any]

一个字典,包含类的运行时参数的值。此属性旨在内部使用,不应直接访问。

源代码位于 src/distilabel/mixins/runtime_parameters.py
class RuntimeParametersMixin(BaseModel):
    """Mixin for classes that have `RuntimeParameter`s attributes.

    Attributes:
        _runtime_parameters: A dictionary containing the values of the runtime parameters
            of the class. This attribute is meant to be used internally and should not be
            accessed directly.
    """

    _runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict)

    @property
    def runtime_parameters_names(self) -> "RuntimeParametersNames":
        """Returns the names of the runtime parameters of the class.

        The mapping is rebuilt from `model_fields` on every access, so callers that
        need it several times should keep the result in a local variable.

        Returns:
            A dictionary keyed by field name. For a field that `_is_runtime_parameter`
            reports as a runtime parameter the value is the "is optional" flag it
            returned. For a field holding a `RuntimeParametersMixin` the value is that
            object's own `runtime_parameters_names`, and for a non-empty list whose
            first item is a `RuntimeParametersMixin` it is a dictionary mapping the
            stringified index of each item to the item's `runtime_parameters_names`.
        """

        runtime_parameters = {}

        for name, field_info in self.model_fields.items():  # type: ignore
            # `field: RuntimeParameter[Any]` or `field: Optional[RuntimeParameter[Any]]`
            is_runtime_param, is_optional = _is_runtime_parameter(field_info)
            if is_runtime_param:
                runtime_parameters[name] = is_optional
                continue

            attr = getattr(self, name)

            # `field: RuntimeParametersMixin`
            if isinstance(attr, RuntimeParametersMixin):
                runtime_parameters[name] = attr.runtime_parameters_names
            # `field: List[RuntimeParametersMixin]` (the list type is decided by
            # looking at its first item, so an empty list is skipped)
            elif (
                isinstance(attr, list)
                and attr
                and isinstance(attr[0], RuntimeParametersMixin)
            ):
                runtime_parameters[name] = {
                    str(i): item.runtime_parameters_names for i, item in enumerate(attr)
                }

        return runtime_parameters

    def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]:
        """Gets the information of the runtime parameters of the class such as the name and
        the description. This function is meant to include the information of the runtime
        parameters in the serialized data of the class.

        Returns:
            A list with one entry per name found in `runtime_parameters_names`, in
            `model_fields` order. Entries for `RuntimeParametersMixin` fields (or
            non-empty lists of them) carry a nested `runtime_parameters_info`; every
            other entry carries the `optional` flag and, when the field has one, its
            `description`.
        """
        # The property rebuilds the mapping on each access; evaluate it once.
        names = self.runtime_parameters_names

        runtime_parameters_info = []
        for name, field_info in self.model_fields.items():  # type: ignore
            if name not in names:
                continue

            attr = getattr(self, name)

            if isinstance(attr, list):
                # Get runtime parameters info for `List[RuntimeParametersMixin]` field.
                # An empty list has no first item to inspect, so it is reported below
                # as a plain runtime parameter instead of raising `IndexError`.
                if attr and isinstance(attr[0], RuntimeParametersMixin):
                    runtime_parameters_info.append(
                        {
                            "name": name,
                            "runtime_parameters_info": {
                                str(i): item.get_runtime_parameters_info()
                                for i, item in enumerate(attr)
                            },
                        }
                    )
                    continue
            elif isinstance(attr, RuntimeParametersMixin):
                # Get runtime parameters info for `RuntimeParametersMixin` field
                runtime_parameters_info.append(
                    {
                        "name": name,
                        "runtime_parameters_info": attr.get_runtime_parameters_info(),
                    }
                )
                continue

            info = {"name": name, "optional": names[name]}
            if field_info.description is not None:
                info["description"] = field_info.description
            runtime_parameters_info.append(info)
        return runtime_parameters_info

    def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
        """Sets the runtime parameters of the class using the provided values. If the attr
        to be set is a `RuntimeParametersMixin`, it will call `set_runtime_parameters` on
        the attr.

        Names that are not in `runtime_parameters_names` are skipped; when the object
        has a truthy `pipeline` attribute a warning suggesting close matches is logged
        through `self.pipeline._logger`.

        Args:
            runtime_parameters: A dictionary containing the values of the runtime parameters
                to set.
        """
        # The property rebuilds the mapping on each access; evaluate it once.
        names = self.runtime_parameters_names
        runtime_parameters_names = list(names)

        for name, value in runtime_parameters.items():
            if name not in names:
                # Check done just to ensure the unit tests for the mixin run
                if getattr(self, "pipeline", None):
                    closest = difflib.get_close_matches(
                        name, runtime_parameters_names, cutoff=0.5
                    )
                    msg = (
                        f"⚠️  Runtime parameter '{name}' unknown in step '{self.name}'."  # type: ignore
                    )
                    if closest:
                        msg += f" Did you mean any of: {closest}"
                    else:
                        msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
                    self.pipeline._logger.warning(msg)  # type: ignore
                continue

            attr = getattr(self, name)

            if isinstance(attr, list):
                # Set runtime parameters for `List[RuntimeParametersMixin]` field.
                # An empty list has no first item to inspect, so it is handled below
                # as a plain value instead of raising `IndexError`.
                if attr and isinstance(attr[0], RuntimeParametersMixin):
                    for i, item in enumerate(attr):
                        # Items without an entry in `value` receive no parameters.
                        item_value = value.get(str(i), {})
                        item.set_runtime_parameters(item_value)
                    self._runtime_parameters[name] = value
                    continue
            elif isinstance(attr, RuntimeParametersMixin):
                # Set runtime parameters for `RuntimeParametersMixin` field
                attr.set_runtime_parameters(value)
                self._runtime_parameters[name] = value
                continue

            # Handle settings values for `_SecretField`
            field_info = self.model_fields[name]
            inner_type = extract_annotation_inner_type(field_info.annotation)
            if is_type_pydantic_secret_field(inner_type):
                value = inner_type(value)

            # Set the value of the runtime parameter
            setattr(self, name, value)
            self._runtime_parameters[name] = value

runtime_parameters_names property

返回一个字典,以类的运行时参数的名称作为键,以该参数是否可选作为值。

返回

类型 描述
RuntimeParametersNames

一个字典,以类的运行时参数的名称作为键,以该参数是否可选作为值。

get_runtime_parameters_info()

获取类的运行时参数的信息,例如名称和描述。此函数旨在将运行时参数的信息包含在类的序列化数据中。

返回

类型 描述
List[RuntimeParameterInfo]

一个列表,其中包含类的每个运行时参数的信息。

源代码位于 src/distilabel/mixins/runtime_parameters.py
def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]:
    """Gets the information of the runtime parameters of the class such as the name and
    the description. This function is meant to include the information of the runtime
    parameters in the serialized data of the class.

    Returns:
        A list with one entry per name found in `runtime_parameters_names`, in
        `model_fields` order. Entries for `RuntimeParametersMixin` fields (or
        non-empty lists of them) carry a nested `runtime_parameters_info`; every
        other entry carries the `optional` flag and, when the field has one, its
        `description`.
    """
    # `runtime_parameters_names` is read once and reused for every field.
    names = self.runtime_parameters_names

    runtime_parameters_info = []
    for name, field_info in self.model_fields.items():  # type: ignore
        if name not in names:
            continue

        attr = getattr(self, name)

        if isinstance(attr, list):
            # Get runtime parameters info for `List[RuntimeParametersMixin]` field.
            # An empty list has no first item to inspect, so it is reported below
            # as a plain runtime parameter instead of raising `IndexError`.
            if attr and isinstance(attr[0], RuntimeParametersMixin):
                runtime_parameters_info.append(
                    {
                        "name": name,
                        "runtime_parameters_info": {
                            str(i): item.get_runtime_parameters_info()
                            for i, item in enumerate(attr)
                        },
                    }
                )
                continue
        elif isinstance(attr, RuntimeParametersMixin):
            # Get runtime parameters info for `RuntimeParametersMixin` field
            runtime_parameters_info.append(
                {
                    "name": name,
                    "runtime_parameters_info": attr.get_runtime_parameters_info(),
                }
            )
            continue

        info = {"name": name, "optional": names[name]}
        if field_info.description is not None:
            info["description"] = field_info.description
        runtime_parameters_info.append(info)
    return runtime_parameters_info

set_runtime_parameters(runtime_parameters)

使用提供的值设置类的运行时参数。如果要设置的属性是 RuntimeParametersMixin,它将对该属性调用 set_runtime_parameters。

参数

名称 类型 描述 默认值
runtime_parameters Dict[str, Any]

一个字典,其中包含要设置的运行时参数的值。

必需
源代码位于 src/distilabel/mixins/runtime_parameters.py
def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
    """Sets the runtime parameters of the class using the provided values. If the attr
    to be set is a `RuntimeParametersMixin`, it will call `set_runtime_parameters` on
    the attr.

    Names that are not in `runtime_parameters_names` are skipped; when the object
    has a truthy `pipeline` attribute a warning suggesting close matches is logged
    through `self.pipeline._logger`.

    Args:
        runtime_parameters: A dictionary containing the values of the runtime parameters
            to set.
    """
    # `runtime_parameters_names` is read once and reused for every entry.
    names = self.runtime_parameters_names
    runtime_parameters_names = list(names)

    for name, value in runtime_parameters.items():
        if name not in names:
            # Check done just to ensure the unit tests for the mixin run
            if getattr(self, "pipeline", None):
                closest = difflib.get_close_matches(
                    name, runtime_parameters_names, cutoff=0.5
                )
                msg = (
                    f"⚠️  Runtime parameter '{name}' unknown in step '{self.name}'."  # type: ignore
                )
                if closest:
                    msg += f" Did you mean any of: {closest}"
                else:
                    msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
                self.pipeline._logger.warning(msg)  # type: ignore
            continue

        attr = getattr(self, name)

        if isinstance(attr, list):
            # Set runtime parameters for `List[RuntimeParametersMixin]` field.
            # An empty list has no first item to inspect, so it is handled below
            # as a plain value instead of raising `IndexError`.
            if attr and isinstance(attr[0], RuntimeParametersMixin):
                for i, item in enumerate(attr):
                    # Items without an entry in `value` receive no parameters.
                    item_value = value.get(str(i), {})
                    item.set_runtime_parameters(item_value)
                self._runtime_parameters[name] = value
                continue
        elif isinstance(attr, RuntimeParametersMixin):
            # Set runtime parameters for `RuntimeParametersMixin` field
            attr.set_runtime_parameters(value)
            self._runtime_parameters[name] = value
            continue

        # Handle settings values for `_SecretField`
        field_info = self.model_fields[name]
        inner_type = extract_annotation_inner_type(field_info.annotation)
        if is_type_pydantic_secret_field(inner_type):
            value = inner_type(value)

        # Set the value of the runtime parameter
        setattr(self, name, value)
        self._runtime_parameters[name] = value