class RuntimeParametersMixin(BaseModel):
    """Mixin for classes that have `RuntimeParameter`s attributes.

    Attributes:
        _runtime_parameters: A dictionary containing the values of the runtime parameters
            of the class. This attribute is meant to be used internally and should not be
            accessed directly.
    """

    _runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict)

    @property
    def runtime_parameters_names(self) -> "RuntimeParametersNames":
        """Returns a dictionary containing the name of the runtime parameters of the class
        as keys and whether the parameter is required or not as values.

        Fields holding a `RuntimeParametersMixin` map to the nested dictionary of that
        instance, and fields holding a non-empty list of `RuntimeParametersMixin` map to
        a dictionary keyed by the stringified index of each item.

        Returns:
            A dictionary containing the name of the runtime parameters of the class as keys
            and whether the parameter is required or not as values.
        """
        runtime_parameters = {}

        for name, field_info in self.model_fields.items():  # type: ignore
            # `field: RuntimeParameter[Any]` or `field: Optional[RuntimeParameter[Any]]`
            is_runtime_param, is_optional = _is_runtime_parameter(field_info)
            if is_runtime_param:
                runtime_parameters[name] = is_optional
                continue

            attr = getattr(self, name)

            # `field: RuntimeParametersMixin`
            if isinstance(attr, RuntimeParametersMixin):
                runtime_parameters[name] = attr.runtime_parameters_names
                continue

            # `field: List[RuntimeParametersMixin]`; the emptiness check keeps `attr[0]`
            # from raising `IndexError`
            if (
                isinstance(attr, list)
                and attr
                and isinstance(attr[0], RuntimeParametersMixin)
            ):
                runtime_parameters[name] = {
                    str(i): item.runtime_parameters_names for i, item in enumerate(attr)
                }

        return runtime_parameters

    def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]:
        """Gets the information of the runtime parameters of the class such as the name and
        the description. This function is meant to include the information of the runtime
        parameters in the serialized data of the class.

        Returns:
            A list containing the information for each runtime parameter of the class.
        """
        # The property rebuilds the (possibly nested) mapping on every access, so it is
        # computed once here instead of once per field.
        runtime_parameters_names = self.runtime_parameters_names
        runtime_parameters_info = []

        for name, field_info in self.model_fields.items():  # type: ignore
            if name not in runtime_parameters_names:
                continue

            attr = getattr(self, name)

            # Get runtime parameters info for `RuntimeParametersMixin` field
            if isinstance(attr, RuntimeParametersMixin):
                runtime_parameters_info.append(
                    {
                        "name": name,
                        "runtime_parameters_info": attr.get_runtime_parameters_info(),
                    }
                )
                continue

            # Get runtime parameters info for `List[RuntimeParametersMixin]` field. An
            # empty list (e.g. a list-valued runtime parameter with no items) must fall
            # through to the plain branch below instead of failing on `attr[0]`.
            if (
                isinstance(attr, list)
                and attr
                and isinstance(attr[0], RuntimeParametersMixin)
            ):
                runtime_parameters_info.append(
                    {
                        "name": name,
                        "runtime_parameters_info": {
                            str(i): item.get_runtime_parameters_info()
                            for i, item in enumerate(attr)
                        },
                    }
                )
                continue

            info = {"name": name, "optional": runtime_parameters_names[name]}
            if field_info.description is not None:
                info["description"] = field_info.description
            runtime_parameters_info.append(info)

        return runtime_parameters_info

    def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
        """Sets the runtime parameters of the class using the provided values. If the attr
        to be set is a `RuntimeParametersMixin`, it will call `set_runtime_parameters` on
        the attr.

        Names that are not runtime parameters of the class are skipped; when the instance
        has a truthy `pipeline` attribute, a warning is logged through its logger.

        Args:
            runtime_parameters: A dictionary containing the values of the runtime parameters
                to set.
        """
        # Computed once: the property rebuilds the mapping on every access.
        known_parameters = self.runtime_parameters_names
        runtime_parameters_names = list(known_parameters.keys())

        for name, value in runtime_parameters.items():
            if name not in known_parameters:
                # Check done just to ensure the unit tests for the mixin run
                if getattr(self, "pipeline", None):
                    closest = difflib.get_close_matches(
                        name, runtime_parameters_names, cutoff=0.5
                    )
                    msg = (
                        f"⚠️  Runtime parameter '{name}' unknown in step '{self.name}'."  # type: ignore
                    )
                    if closest:
                        msg += f" Did you mean any of: {closest}"
                    else:
                        msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
                    self.pipeline._logger.warning(msg)  # type: ignore
                continue

            attr = getattr(self, name)

            # Set runtime parameters for `RuntimeParametersMixin` field
            if isinstance(attr, RuntimeParametersMixin):
                attr.set_runtime_parameters(value)
                self._runtime_parameters[name] = value
                continue

            # Set runtime parameters for `List[RuntimeParametersMixin]` field. The
            # emptiness check avoids an `IndexError` when the current value of a
            # list-valued runtime parameter is `[]`.
            if (
                isinstance(attr, list)
                and attr
                and isinstance(attr[0], RuntimeParametersMixin)
            ):
                for i, item in enumerate(attr):
                    # Items without an entry in `value` receive no parameters
                    item_value = value.get(str(i), {})
                    item.set_runtime_parameters(item_value)
                self._runtime_parameters[name] = value
                continue

            # Handle settings values for `_SecretField`
            field_info = self.model_fields[name]
            inner_type = extract_annotation_inner_type(field_info.annotation)
            if is_type_pydantic_secret_field(inner_type):
                value = inner_type(value)

            # Set the value of the runtime parameter
            setattr(self, name, value)
            self._runtime_parameters[name] = value