Source code for kale.pipeline

# Copyright 2026 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline data model.

Defines :class:`Pipeline` — a :class:`networkx.DiGraph` of :class:`~kale.step.Step`
nodes — along with the configuration classes that describe pipeline-level
settings such as the pipeline name, KFP host, volumes, and Katib experiments.
"""

from collections.abc import Iterable
import copy
import logging
import os

import networkx as nx

from kale.common import graphutils, podutils, utils
from kale.config import Config, Field, validators
from kale.step import PipelineParam, Step

log = logging.getLogger(name=__name__)


# Short access-mode names accepted by the ``volume_access_mode`` config fields,
# mapped to the list of Kubernetes access modes each one expands to.
VOLUME_ACCESS_MODE_MAP = dict(
    rom=["ReadOnlyMany"],
    rwo=["ReadWriteOnce"],
    rwm=["ReadWriteMany"],
)
# Access mode used when the pipeline metadata does not specify one.
# NOTE: this is the very same list object as VOLUME_ACCESS_MODE_MAP["rwm"].
DEFAULT_VOLUME_ACCESS_MODE = VOLUME_ACCESS_MODE_MAP["rwm"]
# Fallback image when no base image is given and none is provided via env.
DEFAULT_BASE_IMAGE = "python:3.12"


class VolumeConfig(Config):
    """Used for validating the `volumes` field of NotebookConfig.

    ``_postprocess`` normalizes two fields:

    * ``annotations`` is converted from a list of ``{'key': k, 'value': v}``
      dicts into a ``{k: v}`` dict. Entries whose key or value is an empty
      string are dropped.
    * ``volume_access_mode`` (a short name such as ``"rwm"``) is expanded into
      its list of Kubernetes access modes via ``VOLUME_ACCESS_MODE_MAP``.
    """

    name = Field(type=str, required=True, validators=[validators.K8sNameValidator])
    mount_point = Field(type=str, required=True)
    snapshot = Field(type=bool, default=False)
    snapshot_name = Field(type=str)
    size = Field(type=int)  # fixme: validation for this field?
    size_type = Field(type=str)  # fixme: validation for this field?
    type = Field(type=str, required=True, validators=[validators.VolumeTypeValidator])
    annotations = Field(type=list, default=[])
    storage_class_name = Field(type=str, validators=[validators.K8sNameValidator])
    volume_access_mode = Field(
        type=str, validators=[validators.IsLowerValidator, validators.VolumeAccessModeValidator]
    )

    def _parse_annotations(self):
        """Convert annotations to a ``{k: v}`` dictionary.

        Raises:
            ValueError: If an annotation entry lacks a ``'key'`` or ``'value'``
                item.
        """
        # TODO: Make JupyterLab annotate with {k: v} instead of
        #  {'key': k, 'value': v}
        try:
            self.annotations = {
                a["key"]: a["value"]
                for a in self.annotations
                if a["key"] != "" and a["value"] != ""
            }
        except KeyError as e:
            if str(e) in ["'key'", "'value'"]:
                # Chain the original KeyError so the offending entry stays
                # visible in the traceback.
                raise ValueError(
                    "Volume spec: volume annotations must be a list of {'key': k, 'value': v} dicts"
                ) from e
            raise

    def _parse_access_mode(self):
        """Expand the short access-mode name into a list of K8s access modes."""
        if self.volume_access_mode:
            # Copy so the volume never shares (and can never mutate) the
            # module-level list stored in VOLUME_ACCESS_MODE_MAP.
            self.volume_access_mode = list(VOLUME_ACCESS_MODE_MAP[self.volume_access_mode])

    def _postprocess(self):
        self._parse_annotations()
        self._parse_access_mode()
class KatibConfig(Config):
    """Validate the `katib_metadata` field of NotebookConfig.

    Holds the Katib experiment settings. ``parameters``, ``objective`` and
    ``algorithm`` are passed through as plain lists/dicts. The remaining
    fields are the integer trial-count limits.
    """

    # fixme: improve validation of single fields
    parameters = Field(default=[], type=list)
    objective = Field(default={}, type=dict)
    algorithm = Field(default={}, type=dict)
    # fixme: Change these names to be Pythonic (need to change how the
    #  labextension passes them)
    maxTrialCount = Field(default=12, type=int)
    maxFailedTrialCount = Field(default=3, type=int)
    parallelTrialCount = Field(default=3, type=int)
class SecurityContextConfig(Config):
    """Configuration for Kubernetes security context settings.

    These settings control the security context applied to all pipeline steps.
    Can be configured via JupyterLab settings or ``KALE_*`` environment
    variables.
    """

    enabled = Field(type=bool, default=True)
    run_as_user = Field(type=int, default=65534)
    run_as_group = Field(type=int, default=0)
    run_as_non_root = Field(type=bool, default=True)

    def __eq__(self, value):
        """Compare two security contexts field by field.

        Returns ``NotImplemented`` for other types so Python can try the
        reflected comparison (and ultimately fall back to identity).
        Because ``__eq__`` is defined without ``__hash__``, instances are
        unhashable.
        """
        if not isinstance(value, SecurityContextConfig):
            return NotImplemented
        return (
            self.enabled == value.enabled
            and self.run_as_user == value.run_as_user
            and self.run_as_group == value.run_as_group
            and self.run_as_non_root == value.run_as_non_root
        )
class PipelineConfig(Config):
    """Main config class to validate the pipeline metadata.

    ``_postprocess`` fills in derived settings: the base image, per-volume
    storage class and access mode, volume ordering, the absolute working
    directory, the marshal location and the security context.
    """

    pipeline_name = Field(type=str, required=True, validators=[validators.PipelineNameValidator])
    experiment_name = Field(type=str, required=True)
    pipeline_description = Field(type=str, default="")
    base_image = Field(type=str, default="")
    enable_caching = Field(type=bool, default=True)
    volumes = Field(type=list, items_config_type=VolumeConfig, default=[])
    katib_run = Field(type=bool, default=False)
    katib_metadata = Field(type=KatibConfig)
    abs_working_dir = Field(type=str, default="")
    marshal_volume = Field(type=bool, default=True)
    marshal_path = Field(type=str, default="/tmp/marshal")
    steps_defaults = Field(type=dict, default={})
    kfp_host = Field(type=str)
    storage_class_name = Field(type=str, validators=[validators.K8sNameValidator])
    volume_access_mode = Field(
        type=str, validators=[validators.IsLowerValidator, validators.VolumeAccessModeValidator]
    )
    timeout = Field(type=int, validators=[validators.PositiveIntegerValidator])
    security_context = Field(type=SecurityContextConfig, default=None)
    output_path = Field(type=str, default="", validators=[validators.OutputPathValidator])

    @property
    def source_path(self):
        """Get the path to the main entry point script."""
        return utils.get_main_source_path()

    def _postprocess(self):
        # self._randomize_pipeline_name()
        self._set_base_image()
        self._set_volume_storage_class()
        self._set_volume_access_mode()
        self._sort_volumes()
        self._set_abs_working_dir()
        self._set_marshal_path()
        self._set_security_context()

    def _randomize_pipeline_name(self):
        self.pipeline_name = f"{self.pipeline_name}-{utils.random_string()}"

    def _set_base_image(self):
        # Precedence: explicit metadata > env-provided default > DEFAULT_BASE_IMAGE.
        if not self.base_image:
            self.base_image = utils.get_default_base_image_from_env() or DEFAULT_BASE_IMAGE

    def _set_volume_storage_class(self):
        """Propagate the pipeline-level storage class to volumes lacking one."""
        if not self.storage_class_name:
            return
        for v in self.volumes:
            if not v.storage_class_name:
                v.storage_class_name = self.storage_class_name

    def _set_volume_access_mode(self):
        """Expand the pipeline access mode and propagate it to volumes lacking one."""
        if not self.volume_access_mode:
            self.volume_access_mode = list(DEFAULT_VOLUME_ACCESS_MODE)
        else:
            self.volume_access_mode = list(VOLUME_ACCESS_MODE_MAP[self.volume_access_mode])
        for v in self.volumes:
            if not v.volume_access_mode:
                # Give each volume its own list: sharing one list object (or
                # the module-level constants) would let a mutation on one
                # volume leak into every other volume.
                v.volume_access_mode = list(self.volume_access_mode)

    def _sort_volumes(self):
        # The Jupyter Web App assumes the first volume of the notebook is the
        # working directory, so we make sure to make it appear first in the
        # spec. (sorted() is stable, so other volumes keep their order.)
        self.volumes = sorted(
            self.volumes, reverse=True, key=lambda _v: podutils.is_workspace_dir(_v.mount_point)
        )

    def _set_abs_working_dir(self):
        if not self.abs_working_dir:
            self.abs_working_dir = utils.abs_working_dir(self.source_path)

    def _set_marshal_path(self):
        """Place marshal data inside the working dir when it lives on a volume.

        If the working directory is inside one of the mounted volumes, data is
        marshalled into a hidden ``.<script>.kale.marshal.dir`` folder in the
        working directory and no dedicated marshal volume is created.
        Otherwise ``marshal_volume`` / ``marshal_path`` are left unchanged.
        """
        wd = os.path.realpath(self.abs_working_dir)

        def _mount_contains_wd(volume):
            # Compare whole path components: a plain string prefix check would
            # wrongly treat "/home/jovyan2" as being under "/home/jovyan".
            mount = os.path.normpath(volume.mount_point)
            return wd == mount or wd.startswith(mount.rstrip(os.sep) + os.sep)

        if any(_mount_contains_wd(v) for v in self.volumes):
            basename = os.path.basename(self.source_path)
            marshal_dir = f".{basename}.kale.marshal.dir"
            self.marshal_volume = False
            self.marshal_path = os.path.join(wd, marshal_dir)

    def _set_security_context(self):
        """Initialize security context from env vars if not set from metadata.

        Precedence: JupyterLab metadata > env vars > defaults
        """
        # Only consult the environment when metadata did not provide a value.
        if self.security_context is None:
            self.security_context = utils.get_security_context_from_env()
class Pipeline(nx.DiGraph):
    """A Pipeline that can be converted into a KFP pipeline.

    This class is used to define a pipeline, its steps and all its
    configurations. It extends nx.DiGraph to exploit some graph-related
    algorithms but provides helper functions to work with Step objects
    instead of standard networkx "nodes". This makes it simpler to access
    the steps of the pipeline and their attributes.

    Each graph node is keyed by the step name and stores the Step object
    under the ``"step"`` node attribute.
    """

    def __init__(self, config: PipelineConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.pipeline_parameters: dict[str, PipelineParam] = {}
        self.processor = None

    def run(self):
        """Run the steps locally, in topological order."""
        for step in self.steps:
            step.run(self.pipeline_parameters)

    def add_step(self, step: Step):
        """Add a new Step to the pipeline.

        Raises:
            RuntimeError: If ``step`` is not a Step, or a Step with the same
                name has already been added.
        """
        if not isinstance(step, Step):
            raise RuntimeError("Not of type Step.")
        # Direct O(1) graph lookup instead of topologically sorting every step
        # on each insertion. Nodes created implicitly by add_edge carry no
        # "step" attribute and are treated as placeholders to be filled in.
        if step.name in self.nodes and "step" in self.nodes[step.name]:
            raise RuntimeError(f"Step with name '{step.name}' already exists")
        self.add_node(step.name, step=step)

    def add_dependency(self, parent: Step, child: Step):
        """Link two Steps in the pipeline (``parent`` runs before ``child``)."""
        self.add_edge(parent.name, child.name)

    def get_step(self, name: str) -> Step:
        """Get the Step with the provided name."""
        return self.nodes()[name]["step"]

    @property
    def steps(self) -> Iterable[Step]:
        """Get the Steps objects sorted topologically."""
        return (self.nodes()[x]["step"] for x in self.steps_names)

    @property
    def steps_names(self) -> list[str]:
        """Get all Steps' names, sorted topologically."""
        return [step.name for step in self._topological_sort()]

    @property
    def all_steps_parameters(self) -> dict[str, list[str]]:
        """Create a dict with step names and their sorted parameter names."""
        return {step: sorted(self.get_step(step).parameters.keys()) for step in self.steps_names}

    @property
    def pipeline_dependencies_tasks(self) -> dict[str, list[str]]:
        """Map each step name to the names of the steps it depends on."""
        return {step_name: list(self.predecessors(step_name)) for step_name in self.steps_names}

    @property
    def pps_names(self) -> list[str]:
        """Get the names of the pipeline parameters sorted."""
        # Computed on every access: caching the result went stale whenever
        # ``pipeline_parameters`` was reassigned after the first read.
        return sorted(self.pipeline_parameters.keys())

    @property
    def pps_types(self) -> list:
        """Get the types of the pipeline parameters, sorted by name."""
        return [self.pipeline_parameters[n].param_type for n in self.pps_names]

    @property
    def pps_values(self) -> list:
        """Get the values of the pipeline parameters, sorted by name."""
        return [self.pipeline_parameters[n].param_value for n in self.pps_names]

    def _topological_sort(self) -> Iterable[Step]:
        return self._steps_iterable(nx.topological_sort(self))

    def get_ordered_ancestors(self, step_name: str) -> Iterable[Step]:
        """Return the ancestors of a step in an ordered manner.

        Wrapper of graphutils.get_ordered_ancestors.

        Returns:
            Iterable[Step]: A Steps iterable.
        """
        return self._steps_iterable(graphutils.get_ordered_ancestors(self, step_name))

    def _steps_iterable(self, step_names: Iterable[str]) -> Iterable[Step]:
        # Lazily resolve node names to their Step objects.
        for name in step_names:
            yield self.get_step(name)

    def get_leaf_steps(self) -> list[Step]:
        """Get the list of leaf steps of the pipeline.

        A step is considered a leaf when its out-degree is 0, i.e. no other
        step depends on it. Isolated steps (no dependencies either way)
        therefore also count as leaves.

        Returns (list): A list of leaf Steps, in topological order.
        """
        return [x for x in self.steps if self.out_degree(x.name) == 0]

    def override_pipeline_parameters_from_kwargs(self, **kwargs):
        """Overwrite the current pipeline parameters with provided inputs.

        Each provided value keeps the declared type of the existing parameter.

        Raises:
            RuntimeError: If a keyword is not an existing pipeline parameter.
        """
        _pipeline_parameters = copy.deepcopy(self.pipeline_parameters)
        for k, v in kwargs.items():
            if k not in self.pipeline_parameters:
                raise RuntimeError(
                    f"Running pipeline '{self.config.pipeline_name}' with"
                    " an input argument that is not in its"
                    f" parameters: {k}"
                )
            # replace default value with the provided one
            _type = _pipeline_parameters[k].param_type
            _pipeline_parameters[k] = PipelineParam(_type, v)
        self.pipeline_parameters = _pipeline_parameters

    def show(self):
        """Print the pipeline nodes and dependencies in a table."""
        from tabulate import tabulate

        data = []
        for step in self.steps:
            data.append([step.name, list(self.predecessors(step.name))])
        log.info("Pipeline layout:")
        log.info("\n" + tabulate(data, headers=["Step", "Depends On"]) + "\n")