# Source code for kale.marshal.backends

# Copyright 2026 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Built-in marshal backends.

Concrete :class:`~kale.marshal.backend.MarshalBackend` implementations for the
Python types Kale supports out of the box: numpy arrays, pandas DataFrames,
scikit-learn estimators, PyTorch / Keras / TensorFlow models, XGBoost
boosters and DMatrices, and plain Python functions.
"""

import logging

from kale.marshal.backend import MarshalBackend, get_dispatcher

# Logger scoped to this module's import path.
log = logging.getLogger(name=__name__)


# Class-decorator shortcut: every backend below is passed to the global
# dispatcher's ``register`` method (as returned by ``get_dispatcher()``).
register_backend = get_dispatcher().register


@register_backend
class FunctionBackend(MarshalBackend):
    """Marshal plain Python functions.

    This backend only declares its metadata; it defines no ``save`` or
    ``load`` of its own, so both come from :class:`MarshalBackend`.
    """

    obj_type_regex = r"function"
    file_type = "pyfn"
    display_name = "function"
    name = "Function backend"
@register_backend
class SKLearnBackend(MarshalBackend):
    """Marshal scikit-learn objects through ``joblib``."""

    name = "SKLearn backend"
    display_name = "scikit-learn"
    file_type = "joblib"
    obj_type_regex = r"sklearn\..*"
    predictor_type = "sklearn"
    # joblib is distributed separately from scikit-learn and has to be
    # installed on its own. Falling back to dill is disabled on purpose:
    # artifacts written that way would break model serving.
    fallback_on_missing_lib = False

    def save(self, obj, path):
        """Persist ``obj`` at ``path`` with ``joblib.dump``."""
        from joblib import dump
        dump(obj, path)

    def load(self, file_path):
        """Read back an object previously written by ``joblib.dump``."""
        # NOTE(review): joblib.load unpickles; only use on trusted artifacts.
        from joblib import load as joblib_load
        restored = joblib_load(file_path)
        return restored
@register_backend
class NumpyBackend(MarshalBackend):
    """Marshal NumPy objects using the ``.npy`` format."""

    name = "Numpy backend"
    display_name = "numpy"
    file_type = "npy"
    obj_type_regex = r"numpy\..*"

    def save(self, obj, path):
        """Save a NumPy object with :func:`numpy.save`.

        ``np.save`` pickles object-dtype arrays (``allow_pickle=True`` is its
        default, passed explicitly here), so :meth:`load` must be able to
        read pickled payloads back.
        """
        import numpy as np
        # NOTE(review): np.save appends ".npy" when `path` lacks that
        # suffix; this assumes callers pass a path that already ends in
        # ".npy" so that `load` finds the same file — confirm.
        np.save(path, obj, allow_pickle=True)

    def load(self, file_path):
        """Restore a NumPy object written by :meth:`save`.

        ``np.load`` defaults to ``allow_pickle=False``, which rejects the
        object-dtype arrays that ``save`` writes without complaint. Pickle
        loading is enabled explicitly so that the round trip is symmetric.
        """
        import numpy as np
        # SECURITY: allow_pickle=True can execute arbitrary code from a
        # crafted file, the same as pd.read_pickle / joblib.load. Only load
        # artifacts this pipeline produced itself.
        return np.load(file_path, allow_pickle=True)
@register_backend
class PandasBackend(MarshalBackend):
    """Marshal pandas DataFrame and Series objects as pickles."""

    name = "Pandas backend"
    display_name = "pandas"
    file_type = "pdpkl"
    obj_type_regex = r"pandas\..*(DataFrame|Series)"

    def save(self, obj, path):
        """Pickle a pandas object to ``path`` through its ``to_pickle``."""
        # The module is otherwise unused; presumably the import is kept so
        # that a missing pandas installation fails here.
        import pandas  # noqa: F401
        obj.to_pickle(path)

    def load(self, file_path):
        """Unpickle a pandas object stored at ``file_path``."""
        from pandas import read_pickle
        return read_pickle(file_path)
@register_backend
class XGBoostModelBackend(MarshalBackend):
    """Marshal XGBoost ``Booster`` models."""

    name = "XGBoost Model backend"
    display_name = "xgboost"
    file_type = "json"
    obj_type_regex = r"xgboost\.core\.Booster"
    predictor_type = "xgboost"

    def save(self, obj, path):
        """Write the booster to ``path`` via ``Booster.save_model``."""
        # ``obj`` is already a Booster, so xgboost need not be imported here.
        obj.save_model(path)

    def load(self, file_path):
        """Create an empty booster and populate it from ``file_path``."""
        import xgboost
        booster = xgboost.Booster()
        booster.load_model(file_path)
        return booster
@register_backend
class XGBoostDMatrixBackend(MarshalBackend):
    """Marshal XGBoost ``DMatrix`` data containers."""

    name = "XGBoost DMatrix backend"
    display_name = "xgboost-dmatrix"
    file_type = "dmatrix"
    obj_type_regex = r"xgboost\.core\.DMatrix"

    def save(self, obj, path):
        """Dump the DMatrix to ``path`` with ``DMatrix.save_binary``."""
        obj.save_binary(path)

    def load(self, file_path):
        """Rebuild a DMatrix from a file produced by ``save_binary``."""
        from xgboost import DMatrix
        return DMatrix(file_path)
@register_backend
class PyTorchBackend(MarshalBackend):
    """Marshal PyTorch modules as TorchScript archives."""

    name = "PyTorch backend"
    display_name = "pytorch"
    file_type = "pt"
    obj_type_regex = r"torch\.nn\.modules\.module\.Module"

    def save(self, obj, path):
        """Compile ``obj`` with ``torch.jit.script`` and save it to ``path``."""
        import torch
        torch.jit.script(obj).save(path)

    def load(self, file_path):
        """Load a TorchScript archive and return it wrapped, in eval mode."""
        import torch
        scripted = torch.jit.load(file_path)
        # ``jit.load`` hands back a ``ScriptModule``. Making it the single
        # child of an ``nn.Sequential`` container turns it back into a
        # regular PyTorch ``Module``: the container feeds its input to its
        # children in order, so with one child it just wraps the model.
        # NOTE(review): nn.Sequential.forward accepts a single input, so
        # models whose forward takes several arguments cannot be called
        # through this wrapper — confirm that is acceptable.
        wrapper = torch.nn.Sequential(scripted)
        wrapper.eval()
        return wrapper
@register_backend
class KerasBackend(MarshalBackend):
    """Marshal objects from the standalone ``keras`` package."""

    name = "Keras backend"
    display_name = "keras"
    file_type = "keras"
    obj_type_regex = r"keras\..*"

    def save(self, obj, path):
        """Save a Keras object by calling its own ``save`` method."""
        import keras  # noqa: F401
        obj.save(path)

    def load(self, file_path):
        """Load a Keras model with ``keras.models.load_model``."""
        import keras.models
        return keras.models.load_model(file_path)
@register_backend
class TensorflowKerasBackend(MarshalBackend):
    """Marshal ``tf.keras`` objects into versioned folders."""

    name = "Tensorflow backend"
    display_name = "tensorflow"
    file_type = "tfkeras"
    obj_type_regex = r"tensorflow\.python\.keras.*"
    predictor_type = "tensorflow"

    def save(self, obj, path):
        """Save the model under the ``1`` version sub-folder of ``path``."""
        import tensorflow.keras  # noqa: F401
        # XXX: TensorFlow Serving expects models to live inside a numbered
        # version folder, hence the hard-coded ``/1`` suffix.
        versioned_path = path + "/1"
        obj.save(versioned_path)

    def load(self, file_path):
        """Load an uncompiled model from ``file_path`` or ``file_path/1``."""
        from tensorflow.keras.models import load_model
        try:
            return load_model(file_path, compile=False)
        except OSError:
            # XXX: models written by ``save`` sit inside the versioned
            # sub-folder used for TensorFlow Serving.
            return load_model(file_path + "/1", compile=False)