Skip to main content
Steps are the building blocks of ZenML pipelines. Each step is a discrete unit of work that takes inputs, performs a computation, and produces outputs. Well-designed steps make your pipelines modular, testable, and reusable.

Basic Step Structure

A step is created using the @step decorator:
from zenml import step

@step
def my_step(input_data: str) -> str:
    """Uppercase the given input string.

    Args:
        input_data: Text to transform.

    Returns:
        The input converted to uppercase.
    """
    return input_data.upper()

Creating Your First Step

Step 1: Import the Step Decorator
2
from zenml import step
from typing import Annotated
Step 2: Define Step Function
Create a function with clear inputs and outputs:
5
@step
def load_data(data_path: str) -> Annotated[dict, "dataset"]:
    """Load data from a file.

    Args:
        data_path: Path to the data file

    Returns:
        Loaded dataset as a dictionary
    """
    # Placeholder for real loading logic; data_path is not yet used here.
    dataset = {"values": [1, 2, 3, 4, 5]}
    return dataset
Step 3: Add Type Hints
Always use type hints for inputs and outputs:
8
@step
def process_data(
    dataset: dict,
    multiplier: int = 2
) -> Annotated[list, "processed_data"]:
    """Scale every value in the dataset by a constant factor.

    Args:
        dataset: Input dataset dictionary with a "values" list
        multiplier: Factor applied to each value

    Returns:
        List of scaled values
    """
    return [multiplier * item for item in dataset["values"]]
Step 4: Use Steps in a Pipeline
Connect steps in your pipeline:
11
from zenml import pipeline

@pipeline
def data_pipeline(data_path: str):
    """Pipeline that loads and processes data."""
    dataset = load_data(data_path)
    processed = process_data(dataset, multiplier=3)
    return processed

Step Inputs and Outputs

Type Annotations

Use Annotated to give artifacts meaningful names:
from typing import Annotated

@step
def train_model(
    train_data: dict,
    learning_rate: float
) -> Annotated[object, "trained_model"]:
    """Fit a model on the training data.

    The Annotated return type names the output artifact 'trained_model',
    which makes it easy to locate in the ZenML dashboard.

    Args:
        train_data: Training dataset
        learning_rate: Optimizer learning rate

    Returns:
        The trained model object
    """
    # Stand-in for real training logic.
    return {"lr": learning_rate, "trained": True}

Multiple Outputs

Return multiple artifacts from a step:
from typing import Tuple, Annotated

@step
def split_data(
    dataset: dict,
    test_size: float = 0.2
) -> Tuple[
    Annotated[dict, "train_data"],
    Annotated[dict, "test_data"]
]:
    """Partition a dataset into training and test subsets.

    Args:
        dataset: Full dataset to split
        test_size: Fraction of data reserved for testing

    Returns:
        Tuple of (train_data, test_data)
    """
    values = dataset["values"]
    cut = int(len(values) * (1 - test_size))
    return {"values": values[:cut]}, {"values": values[cut:]}
Use multiple outputs in a pipeline:
@pipeline
def ml_pipeline():
    """End-to-end pipeline: load, split, train, and evaluate."""
    full_dataset = load_data("data.csv")
    train_split, test_split = split_data(full_dataset, test_size=0.3)

    trained = train_model(train_split, learning_rate=0.001)
    scores = evaluate_model(trained, test_split)

Optional Outputs

Use Optional for conditional outputs:
from typing import Optional, Annotated

@step
def validate_and_process(
    data: dict,
    strict: bool = False
) -> Tuple[
    Annotated[dict, "processed_data"],
    Annotated[Optional[dict], "validation_errors"]
]:
    """Validate and process data, optionally returning errors.

    Args:
        data: Input data to validate
        strict: Whether to perform strict validation

    Returns:
        Tuple of (processed_data, validation_errors);
        validation_errors is None when no problems were found
    """
    validation_errors = None
    # Strict mode flags datasets that look too small to be useful.
    if strict and len(data["values"]) < 10:
        validation_errors = {"error": "Insufficient data"}

    processed_data = {"values": data["values"], "validated": True}
    return processed_data, validation_errors

Step Parameters

Make steps configurable with parameters:
@step
def preprocess_text(
    text: str,
    lowercase: bool = True,
    remove_punctuation: bool = True,
    max_length: int = 1000
) -> Annotated[str, "processed_text"]:
    """Preprocess text with configurable options.

    Args:
        text: Input text to process
        lowercase: Convert the text to lowercase
        remove_punctuation: Keep only alphanumeric and whitespace characters
        max_length: Truncate the result to this many characters

    Returns:
        Processed text
    """
    result = text.lower() if lowercase else text

    if remove_punctuation:
        result = ''.join(c for c in result if c.isalnum() or c.isspace())

    # Truncation always runs last, after any other transformations.
    return result[:max_length]

Step Configuration

Resource Settings

Specify compute resources for a step:
from zenml import step
from zenml.config import ResourceSettings

@step(settings={"resources": ResourceSettings(cpu_count=4, memory="8GB")})
def train_large_model(data: dict) -> object:
    """Train a model that requires significant resources.

    Args:
        data: Training dataset

    Returns:
        The trained model object
    """
    # BUGFIX: the original returned the undefined name `model`, which
    # would raise NameError when executed. Use an explicit placeholder.
    model = {"trained": True}
    return model

Disabling Cache

Disable caching for specific steps:
@step(enable_cache=False)
def fetch_latest_data() -> Annotated[dict, "fresh_data"]:
    """Fetch fresh data on every run.

    enable_cache=False forces this step to execute even when its
    inputs are identical to a previous run.
    """
    return fetch_from_api()

Step Operators

Run steps on different infrastructure:
@step(step_operator="gpu_operator")
def train_model_on_gpu(data: dict) -> object:
    """Train a model using GPU resources via the named step operator.

    Args:
        data: Training dataset

    Returns:
        The trained model object
    """
    # BUGFIX: the original returned the undefined name `model`, which
    # would raise NameError when executed. Use an explicit placeholder.
    model = {"trained": True, "device": "gpu"}
    return model

Best Practices for Steps

Keep Steps Focused

Each step should have a single, clear purpose:
@step
def load_data(path: str) -> dict:
    """Read the dataset stored at *path*."""
    dataset = load_from_file(path)
    return dataset

@step
def clean_data(data: dict) -> dict:
    """Strip null entries out of the loaded dataset."""
    cleaned = remove_nulls(data)
    return cleaned

@step
def transform_data(data: dict) -> dict:
    """Apply the configured transformations to the cleaned dataset."""
    transformed = apply_transformations(data)
    return transformed

Use Meaningful Names

Choose descriptive names for steps and artifacts:
# Good names
@step
def calculate_feature_importance(
    model: object,
    test_data: dict
) -> Annotated[dict, "feature_importance_scores"]:
    """Compute a per-feature importance score for the given model.

    Args:
        model: Trained model to inspect
        test_data: Held-out data used to score features

    Returns:
        Mapping of feature name to importance score
    """
    # Stub: the implementation is intentionally omitted in this example.
    pass

# Avoid vague names like:
# def process(), def step1(), def do_stuff()

Add Comprehensive Docstrings

@step
def evaluate_model(
    model: object,
    test_data: dict,
    threshold: float = 0.8
) -> Annotated[dict, "evaluation_metrics"]:
    """Score a trained model against held-out test data.

    Produces accuracy, precision, recall, and F1 score. Whenever
    accuracy drops below *threshold*, a warning is logged.

    Args:
        model: Trained model to evaluate
        test_data: Test dataset dictionary with 'X' and 'y' keys
        threshold: Minimum acceptable accuracy (default: 0.8)

    Returns:
        Dictionary containing all evaluation metrics

    Raises:
        ValueError: If test_data is empty or malformed
    """
    # Stub: the implementation is intentionally omitted in this example.
    pass

Handle Errors Gracefully

@step
def process_user_input(
    input_data: str
) -> Annotated[dict, "validated_data"]:
    """Process and validate user input.

    Args:
        input_data: Raw input string from user

    Returns:
        Validated and processed data

    Raises:
        ValueError: If input_data is empty or cannot be parsed
    """
    # Guard clause: reject empty or whitespace-only input up front.
    if not input_data or not input_data.strip():
        raise ValueError("Input data cannot be empty")

    try:
        processed = parse_input(input_data)
    except Exception as e:
        # FIX: chain with `from e` so the original traceback/root cause
        # is preserved instead of being discarded by the re-raise.
        raise ValueError(f"Failed to process input: {str(e)}") from e
    # Keep the try body minimal: only the call that can actually raise.
    return {"data": processed, "valid": True}

Common Step Patterns

Data Loading Step

import pandas as pd
from typing import Annotated

@step
def data_loader(
    data_path: str,
    file_format: str = "csv"
) -> Annotated[pd.DataFrame, "dataset"]:
    """Load data from various file formats.

    Args:
        data_path: Path to data file
        file_format: Format of the file (csv, parquet, json)

    Returns:
        Loaded data as DataFrame
    """
    # Dispatch table maps each supported format to its pandas reader.
    readers = {
        "csv": pd.read_csv,
        "parquet": pd.read_parquet,
        "json": pd.read_json,
    }
    if file_format not in readers:
        raise ValueError(f"Unsupported format: {file_format}")

    df = readers[file_format](data_path)
    print(f"Loaded {len(df)} records from {data_path}")
    return df

Model Training Step

from sklearn.linear_model import LogisticRegression
import pandas as pd

@step
def model_trainer(
    train_data: pd.DataFrame,
    target_column: str,
    learning_rate: float = 0.01,
    max_iter: int = 100
) -> Annotated[LogisticRegression, "trained_model"]:
    """Train a logistic regression model.

    Args:
        train_data: Training dataset
        target_column: Name of target variable
        learning_rate: Retained for interface compatibility; scikit-learn's
            LogisticRegression has no learning-rate hyperparameter, so this
            value is intentionally not passed to the estimator
        max_iter: Maximum iterations

    Returns:
        Trained model
    """
    X = train_data.drop(columns=[target_column])
    y = train_data[target_column]

    # BUGFIX: LogisticRegression does not accept a `learning_rate` keyword;
    # passing it raised TypeError. The parameter stays in the signature for
    # backward compatibility but is unused.
    model = LogisticRegression(max_iter=max_iter)
    model.fit(X, y)

    print(f"Model trained with {len(X)} samples")
    return model

Model Evaluation Step

import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score

@step
def model_evaluator(
    model: object,
    test_data: pd.DataFrame,
    target_column: str
) -> Annotated[dict, "metrics"]:
    """Compute accuracy, precision, and recall on a test set.

    Args:
        model: Trained model to evaluate
        test_data: Test dataset
        target_column: Name of target variable

    Returns:
        Dictionary of evaluation metrics
    """
    features = test_data.drop(columns=[target_column])
    labels = test_data[target_column]
    predictions = model.predict(features)

    metrics = {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average='weighted'),
        "recall": recall_score(labels, predictions, average='weighted'),
    }

    print(f"Model Accuracy: {metrics['accuracy']:.3f}")
    return metrics

Next Steps

Step Context

Access runtime information and metadata within steps

Artifact Management

Learn how ZenML tracks and manages step artifacts

Creating Pipelines

Connect steps together in pipelines

Build docs developers (and LLMs) love