# Copyright 2026 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Marshal dispatcher and backend base class.
Defines :class:`MarshalBackend`, the abstract base class for type-aware
serializers that Kale uses to pass data between pipeline steps, and
:class:`Dispatcher`, which routes objects to the correct backend at
save/load time.
"""
import logging
import os
import re
from typing import Any
from kale.common import utils
log = logging.getLogger(__name__)
__DATA_DIR = os.path.curdir
[docs]
def set_data_dir(path):
"""Set the data directory where marshalling happens."""
global __DATA_DIR
__DATA_DIR = path
# create dir if not exists
if not os.path.isdir(__DATA_DIR):
os.makedirs(__DATA_DIR, exist_ok=True)
[docs]
def get_data_dir():
"""Get the data directory where marshalling happens."""
global __DATA_DIR
return __DATA_DIR
[docs]
class MarshalBackend:
"""Base class for marshalling Python objects.
This class is supposed to be subclassed by specialized backends that
implement the `save` and `load` functions to marshal library-specific
objects.
A backend registers itself to specific objects/file types using the
following class attributes:
* `file_type`: The file extension of the files/folders the backend is able
to restore. NOTE: Currently this can be just *one* ext.
* `obj_type_regex`: A regex which is matched against the `type` of an
object.
Take a look at `backend.py` for some examples on how to create custom
marshal backends.
"""
name: str = "Default backend"
display_name: str = "generic" # This is supposed to tbe the library name
file_type: str = "dillpkl"
obj_type_regex: str = None
predictor_type: str = None # Used for creating serving predictors
# Set to False if you want your backend not to use the default backend
# in case of a missing library.
fallback_on_missing_lib = True
def __init__(
self,
name: str = None,
display_name: str = None,
obj_type_regex: str = None,
file_type: str = None,
):
self.name = name or self.name
self.display_name = display_name or self.display_name
self.obj_type_regex = obj_type_regex or self.obj_type_regex
self.file_type = file_type or self.file_type
[docs]
def wrapped_save(self, obj: Any, name: str):
"""Wrapper around the public `save` function.
This function provides common logging and exception handling for every
class that extends the base `MarshalBackend`. `Dispatcher` calls
directly this function instead of `save`.
Returns the path (<data_dir>/<basename>.<backend_extension>) to the
saved file.
"""
abs_path = os.path.join(get_data_dir(), name + "." + self.file_type)
log.debug(
"Saving %s object using %s: %s to %s", self.display_name, self.name, name, abs_path
)
try:
self.save(obj, abs_path)
except ImportError as e:
if not self.fallback_on_missing_lib:
raise e
log.warning(
"Failed to import %s (%s). Falling back to default backend.", self.display_name, e
)
self._default_save(obj, name) # always try the default save
return abs_path
[docs]
def save(self, obj: Any, path: str):
"""Save `obj` to file."""
self._default_save(obj, path)
@staticmethod
def _default_save(obj: Any, path: str):
import dill
with open(path, "wb") as f:
dill.dump(obj, f)
[docs]
def wrapped_load(self, name: str) -> Any:
"""Wrapper around the public `load` function.
This function provides common logging and exception handling for every
class that extends the base `MarshalBackend`. `Dispatcher` calls
directly this function instead of `load`.
"""
abs_path = os.path.join(get_data_dir(), name + "." + self.file_type)
log.debug("Loading %s file using %s: %s", self.display_name, self.name, name)
try:
return self.load(abs_path)
except ImportError as e:
if not self.fallback_on_missing_lib:
raise e
log.warning(
"Failed to import %s (%s). Falling back to default backend.", self.display_name, e
)
return self._default_load(abs_path) # always try the default load
[docs]
def load(self, file_path: str) -> Any:
"""Restore `file_path` to memory."""
return self._default_load(file_path)
@staticmethod
def _default_load(file_path: str) -> Any:
import dill
with open(file_path, "rb") as f:
return dill.load(f)
dispatcher = None
[docs]
def get_dispatcher():
"""Get the unique instance of dispatcher.
This is preferred since Dispatcher registered all MarshalBackends that
are decorated with the `register` function. We don't want the registration
process to happen all the time.
"""
global dispatcher
if not dispatcher:
dispatcher = Dispatcher()
return dispatcher
[docs]
class Dispatcher:
"""Dispatch backend classes based on obj types or file extensions.
This class holds a reference to all the marshalling backends that register
themselves with the `register` function. `Dispatcher` is the main mechanism
with which a specialized backend is chosen to either save or load an
and object to/from memory.
The public functions that users should be aware of:
* `save`: Dispatches to a specialized backend based on the input object
type, by filtering through the backends' `obj_type_regex`
attribute.
* `load`: Dispatches to a specialized backend based on the input file path
by filtering through the backends' `file_type` attribute.
Users and external code are not supposed to interact directly with the
singleton instance of this class. Rather, they should just call the
two publicly exposed functions `save` and `load` like so:
```
from kale.marshal import save, load
```
"""
END_USER_EXC_MSG = (
"\n\nThe error was:\n%s\n\nPlease help us improve Kale"
" by opening a new issue at:"
"\nhttps://github.com/kubeflow-kale/kale/issues."
)
def __init__(self):
self.backends: dict[str, MarshalBackend] = {}
[docs]
def register(self, cls: type[MarshalBackend]) -> type[MarshalBackend]:
"""Register a new marshalling backend.
Args:
cls: Marshal backend class
Returns: the class itself
"""
if cls.__name__ not in self.backends:
self.backends[cls.__name__] = cls()
return cls
[docs]
def get_backend(self, obj: Any):
"""Get the backend registered for the input object type."""
return self._dispatch_obj_type(obj)
[docs]
def get_backends(self) -> dict[str, MarshalBackend]:
"""Get all registered backends."""
# FIXME: How can we make this dict readonly? We don't want external
# code to mess with it.
return dict(self.backends)
[docs]
def get_backend_by_name(self, name: str):
"""Get a registered backend by its display name."""
return self.backends[name]
[docs]
def save(self, obj: Any, obj_name: str):
"""Save an object to file.
Args:
obj: Object to be marshalled
obj_name: Name of the object to be saved
"""
try:
return self._dispatch_obj_type(obj).wrapped_save(obj, obj_name)
except Exception as e:
error_msg = (
"During data passing, Kale could not marshal the"
f" following object:\n\n - path: '{obj_name}'\n - type: '{type(obj)}'"
)
log.error(error_msg + self.END_USER_EXC_MSG % e)
log.debug("Original Traceback", exc_info=e.__traceback__)
utils.graceful_exit(1)
[docs]
def load(self, basename: str):
"""Restore a file to memory.
Args:
basename: The name of the serialized object to be loaded
Returns: restored object
"""
try:
entry_name = self._unique_ls(basename)
return self._dispatch_file_type(entry_name).wrapped_load(basename)
except Exception as e:
error_msg = (
"During data passing, Kale could not load the"
f" following file:\n\n\n - name: '{basename}'"
)
log.error(error_msg + self.END_USER_EXC_MSG % e)
log.debug("Original Traceback", exc_info=e.__traceback__)
utils.graceful_exit(1)
[docs]
def get_path(self, basename: str):
"""Get a serialized kfp artifact abs path.
Args:
basename: The name of the artifact to retrieve its abs path
Returns: the marshalled artifact path
"""
return os.path.join(get_data_dir(), self._unique_ls(basename))
@staticmethod
def _unique_ls(basename: str):
# get the unique file/folder inside _DATA_DIR: there could be
# multiple files with the same name and different extension.
entries = [
ls
for ls in os.listdir(get_data_dir())
if (
(
os.path.isfile(os.path.join(get_data_dir(), ls))
or os.path.isdir(os.path.join(get_data_dir(), ls))
)
and os.path.splitext(ls)[0] == basename
)
]
log.debug("Found %d entries for basename '%s': %s", len(entries), basename, entries)
if not entries:
log.debug(
"Looking for unique file/folder with basename '%s' in %s", basename, get_data_dir()
)
raise ValueError(
f"No file or folder found with basename '{basename}' in {get_data_dir()}"
)
if len(entries) > 1:
raise ValueError(f"Found multiple files/folders with name {basename}: {entries}")
return entries[0]
def _dispatch_obj_type(self, obj: Any) -> MarshalBackend:
"""Dispatch to a backend based on the object's type matching regex.
Args:
obj: any Python object
"""
# object types are printed as <class 'obj type'>
_type = re.sub(r"'>$", "", re.sub(r"^<class '", "", str(type(obj))))
# type of base class
_type_base = re.sub(r"'>$", "", re.sub(r"^<class '", "", str(obj.__class__.__bases__[0])))
_backends = [
backend
for backend in self.backends.values()
if re.match(backend.obj_type_regex, _type)
or re.match(backend.obj_type_regex, _type_base)
]
if len(_backends) > 1:
raise RuntimeError(
"Too many matching marshalling backends for"
f" object type {_type} (base type {_type_base}): {_backends}"
)
if not _backends:
log.debug(
f"No backends found for type {_type} ({_type_base}). Falling back to"
" default backend."
)
return MarshalBackend()
else:
return _backends[0]
def _dispatch_file_type(self, filename: str) -> MarshalBackend:
"""Dispatch to a backend based on the matching file type.
Note: The "file" could be a folder. Some backends marshal an object
inside a folder. That is why we don't explicitly check if the path
points to a file and instead we just get the extension.
Args:
filename (str): filename whose extension must be matched.
"""
_backends = [
backend
for backend in self.backends.values()
if os.path.splitext(filename)[1].lstrip(".") == backend.file_type
]
if len(_backends) > 1:
raise RuntimeError(
"Too many matching marshalling backends for"
f" file {os.path.basename(filename)} : {_backends}"
)
if not _backends:
log.debug(f"No backends found for '{filename}'. Falling back to default backend.")
return MarshalBackend()
else:
return _backends[0]