# Source code for kale.marshal.decorator

# Copyright 2026 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import marshal as marshal_utils
import sys
from typing import Any, NamedTuple

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


[docs] class PipelineParam(NamedTuple): """A pipeline parameter.""" param_type: str param_value: Any
[docs] def marshal( ins: list, outs: list, parameters: dict[str, PipelineParam | Any] = None, marshal_dir: str = None, introspect: bool = False, ): """Decorator that ensures proper marshalling happens when the fn is run.""" _params = { k: (v if isinstance(v, PipelineParam) else PipelineParam(type(v), v)) for k, v in parameters.items() } def _marshal(func): return Marshaller(func, ins, outs, _params, marshal_dir, introspect) return _marshal
[docs] class Marshaller: """Wrap a function to perform marshalling around its execution. This class acts as a wrapper around a function that runs in a pipeline step and needs input arguments to be loaded from a marshal directory and its outputs saved likewise. """ def __init__( self, func, ins: list, outs: list, parameters: dict[str, PipelineParam] = None, marshal_dir=None, introspect=False, ): self._introspect = introspect if introspect: self._func = _persistent_locals(func) else: self._func = func self._ins = ins self._outs = outs self._parameters = parameters or {} marshal_utils.set_data_dir(marshal_dir)
[docs] def __call__(self): """Run the function by passing loaded vars and saving the results.""" loads = self._load() log.newline(lines=2) results = self._func(*loads) log.newline(lines=2) self._save(results)
def _load(self): loads = [] # load in the same order as in self._ins. for var_name in self._ins: if var_name not in self._parameters: loads.append(marshal_utils.load(var_name)) else: loads.append(self._parameters[var_name].param_value) return loads def _save(self, values): if self._introspect: # get vars from function locals for var_name in self._outs: if var_name not in self._func.locals: raise RuntimeError(f"Variable {var_name} not found in function's locals") marshal_utils.save(self._func.locals[var_name], var_name) else: # get vars from return value if len(self._outs) == 0: return if isinstance(values, tuple): if len(values) != len(self._outs): raise RuntimeError( "There is a mismatch between the tuple" " returned by the functions and its" " expected outs. If the functions is" " returning a tuple, make sure the " " return value it is properly" " unpacked." ) for name, value in dict(zip(self._outs, values, strict=False)).items(): marshal_utils.save(value, name) else: # any other object? if len(self._outs) > 1: raise RuntimeError( "The function returned a single object," " but there are multiple expected outs:" f" {str(self._outs)}" ) marshal_utils.save(values, self._outs[0])
class _persistent_locals: """Function decorator to expose local variables after execution. Modify the function such that, at the exit of the function (regular exit or exceptions), the local dictionary is copied to a read-only function property 'locals'. This decorator wraps the function in a callable object, and modifies its bytecode by adding an external try...finally statement equivalent to the following: ``` def f(self, *args, **kwargs): try: ... old code ... finally: self._locals = locals().copy() del self._locals['self'] ``` Refer to the docstring of instances for help about the wrapped function. """ def __init__(self, func): self._locals = {} self._func = func def __call__(self, *args, **kwargs): def tracer(frame, event, arg): if event == "return": self._locals = frame.f_locals.copy() # keep old profile old_profile = sys.getprofile() # tracer is activated on next call, return or exception sys.setprofile(tracer) try: # trace the function call res = self._func(*args, **kwargs) finally: # disable tracer and replace with old one sys.setprofile(old_profile) return res @property def locals(self): return self._locals