Skip to main content
Materializers handle the serialization and deserialization of artifacts passed between pipeline steps. Custom materializers allow you to work with any data type, from simple objects to complex ML models.

Understanding Materializers

When data flows between pipeline steps:
  1. Save: The materializer serializes the artifact to the artifact store
  2. Load: The materializer deserializes the artifact for the next step
Materializers also handle:
  • Extracting metadata for tracking
  • Creating visualizations for the dashboard
  • Computing content hashes for caching
  • Loading specific items from collections

The BaseMaterializer Interface

All materializers inherit from `BaseMaterializer` and implement its key methods. A simplified sketch of the interface:
from zenml.materializers import BaseMaterializer
from typing import Type, Any

class BaseMaterializer:
    """Base class for all materializers (simplified sketch).

    A materializer is bound to one artifact location (``uri``) and reads and
    writes that location through an artifact store.
    """
    
    # Required class attributes
    ASSOCIATED_TYPES = ()  # Types this materializer handles
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.BASE  # Category
    
    def __init__(self, uri: str, artifact_store=None):
        """Initialize with storage location.
        
        Args:
            uri: URI where artifact will be stored
            artifact_store: The artifact store instance
        """
        self.uri = uri
        self._artifact_store = artifact_store

    @property
    def artifact_store(self):
        """The artifact store passed to ``__init__``.

        Subclasses read and write files via ``self.artifact_store.open(...)``,
        so the store must be reachable under this public name. It is ``None``
        when no store was passed to ``__init__``.
        """
        return self._artifact_store
    
    def save(self, data: Any) -> None:
        """Save artifact data to self.uri.

        Args:
            data: The object to serialize.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
    
    def load(self, data_type: Type[Any]) -> Any:
        """Load artifact data from self.uri.

        Args:
            data_type: The type the caller expects back.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

Creating a Simple Materializer

Let’s create a materializer for a custom data class:
from dataclasses import dataclass
from typing import Type, Any
import json
import os

@dataclass
class ModelMetrics:
    """Plain record of four evaluation scores for a trained model.

    Fields are declared in the order accuracy, precision, recall, f1_score.
    This is also the positional-argument order of the generated ``__init__``.
    """

    accuracy: float   # accuracy score
    precision: float  # precision score
    recall: float     # recall score
    f1_score: float   # F1 score

Step 1: Define the Materializer Class

from zenml.materializers import BaseMaterializer
from zenml.enums import ArtifactType
from typing import ClassVar, Tuple, Type, Any

class ModelMetricsMaterializer(BaseMaterializer):
    """Persist ModelMetrics objects as a small JSON document (``metrics.json``)."""
    
    # Types routed to this materializer, and the artifact category it reports
    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (ModelMetrics,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
    
    def save(self, metrics: ModelMetrics) -> None:
        """Serialize the four scores to ``<uri>/metrics.json``.
        
        Args:
            metrics: The ModelMetrics object to save
        """
        field_names = ("accuracy", "precision", "recall", "f1_score")
        # Key order in the JSON follows field_names
        payload = {name: getattr(metrics, name) for name in field_names}
        
        target = os.path.join(self.uri, "metrics.json")
        with self.artifact_store.open(target, "w") as handle:
            json.dump(payload, handle, indent=2)
    
    def load(self, data_type: Type[Any]) -> ModelMetrics:
        """Rebuild a ModelMetrics object from ``<uri>/metrics.json``.
        
        Args:
            data_type: The expected type (ModelMetrics)
            
        Returns:
            The loaded ModelMetrics object

        Raises:
            KeyError: If one of the four score keys is missing from the file.
        """
        field_names = ("accuracy", "precision", "recall", "f1_score")
        source = os.path.join(self.uri, "metrics.json")
        with self.artifact_store.open(source, "r") as handle:
            payload = json.load(handle)
        
        # Only the known keys are read; any extra keys in the file are ignored
        return ModelMetrics(**{name: payload[name] for name in field_names})

Step 2: Use in a Pipeline

from zenml import pipeline, step

@step
def train_model() -> ModelMetrics:
    """Return the metrics of a (placeholder) training run.

    Real training code would go here; this example returns fixed scores.
    """
    # Training logic here
    scores = dict(accuracy=0.95, precision=0.93, recall=0.97, f1_score=0.95)
    return ModelMetrics(**scores)

@step
def evaluate_model(metrics: ModelMetrics) -> None:
    """Print the model's accuracy as a percentage with two decimals."""
    accuracy_text = format(metrics.accuracy, ".2%")
    print("Model achieved " + accuracy_text + " accuracy")

@pipeline
def training_pipeline():
    """Train, then evaluate the metrics produced by training."""
    # The ModelMetrics output of train_model is passed straight to evaluation
    evaluate_model(train_model())

# ZenML automatically uses ModelMetricsMaterializer
training_pipeline()

Advanced Features

Extracting Metadata

Metadata appears in the ZenML dashboard alongside your artifacts:
from zenml.metadata.metadata_types import MetadataType
from typing import Dict

class ModelMetricsMaterializer(BaseMaterializer):
    # ... save() and load() methods ...
    
    def extract_metadata(
        self, metrics: ModelMetrics
    ) -> Dict[str, MetadataType]:
        """Report each score, plus the largest of them, as dashboard metadata.
        
        Args:
            metrics: The ModelMetrics object
            
        Returns:
            Mapping with the four scores followed by ``best_metric``
        """
        scores = {
            "accuracy": metrics.accuracy,
            "precision": metrics.precision,
            "recall": metrics.recall,
            "f1_score": metrics.f1_score,
        }
        # best_metric is simply the maximum over the four scores above
        return {**scores, "best_metric": max(scores.values())}

Creating Visualizations

Visualizations appear in the dashboard for interactive exploration:
from zenml.enums import VisualizationType
import matplotlib.pyplot as plt
import os

class ModelMetricsMaterializer(BaseMaterializer):
    # ... other methods ...
    
    def save_visualizations(
        self, metrics: ModelMetrics
    ) -> Dict[str, VisualizationType]:
        """Create a bar chart and an HTML table for the dashboard.
        
        Args:
            metrics: The ModelMetrics object
            
        Returns:
            Dictionary mapping visualization URIs to their types (chart first,
            then the HTML report)
        """
        return {
            **self._save_metrics_chart(metrics),
            **self._save_metrics_report(metrics),
        }

    def _save_metrics_chart(
        self, metrics: ModelMetrics
    ) -> Dict[str, VisualizationType]:
        """Render the four scores as a PNG bar chart in the artifact store."""
        metric_names = ["Accuracy", "Precision", "Recall", "F1 Score"]
        metric_values = [
            metrics.accuracy,
            metrics.precision,
            metrics.recall,
            metrics.f1_score,
        ]
        # Normalize separators so the URI stays valid when built on Windows
        chart_path = os.path.join(self.uri, "metrics_chart.png")
        chart_path = chart_path.replace("\\", "/")

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.bar(metric_names, metric_values, color='steelblue')
            ax.set_ylabel('Score')
            ax.set_title('Model Metrics')
            ax.set_ylim([0, 1])
            with self.artifact_store.open(chart_path, "wb") as f:
                # Save this figure explicitly rather than pyplot's "current" one
                fig.savefig(f, format='png', bbox_inches='tight')
        finally:
            # Close this specific figure even if drawing or writing fails, so
            # pyplot does not keep it alive in the process
            plt.close(fig)

        return {chart_path: VisualizationType.IMAGE}

    def _save_metrics_report(
        self, metrics: ModelMetrics
    ) -> Dict[str, VisualizationType]:
        """Write an HTML table of the four scores to the artifact store."""
        html_path = os.path.join(self.uri, "metrics_report.html")
        html_path = html_path.replace("\\", "/")
        
        html_content = f"""
        <html>
        <head><title>Model Metrics Report</title></head>
        <body>
            <h1>Model Performance Metrics</h1>
            <table border="1">
                <tr><th>Metric</th><th>Value</th></tr>
                <tr><td>Accuracy</td><td>{metrics.accuracy:.4f}</td></tr>
                <tr><td>Precision</td><td>{metrics.precision:.4f}</td></tr>
                <tr><td>Recall</td><td>{metrics.recall:.4f}</td></tr>
                <tr><td>F1 Score</td><td>{metrics.f1_score:.4f}</td></tr>
            </table>
        </body>
        </html>
        """
        
        with self.artifact_store.open(html_path, "w") as f:
            f.write(html_content)
        
        return {html_path: VisualizationType.HTML}

Computing Content Hashes

Content hashes enable caching: a step can skip re-execution when its inputs haven’t changed.
import hashlib
import json

class ModelMetricsMaterializer(BaseMaterializer):
    # ... other methods ...
    
    def compute_content_hash(self, metrics: ModelMetrics) -> str:
        """Return a SHA-256 hex digest of the four metric values.

        The scores are serialized to JSON with sorted keys, so equal metrics
        always yield the same digest.
        
        Args:
            metrics: The ModelMetrics object
            
        Returns:
            Hex-encoded SHA-256 digest used for cache comparison
        """
        payload = {}
        for name in ("accuracy", "precision", "recall", "f1_score"):
            payload[name] = getattr(metrics, name)
        canonical = json.dumps(payload, sort_keys=True)
        digest = hashlib.sha256(canonical.encode())
        return digest.hexdigest()

Real-World Example: NumPy Arrays

Let’s look at how ZenML’s NumpyMaterializer handles arrays:
import numpy as np
import os
from typing import Any, Type, Dict, Optional
from zenml.materializers import BaseMaterializer
from zenml.enums import ArtifactType, VisualizationType
from zenml.metadata.metadata_types import DType, MetadataType

class NumpyMaterializer(BaseMaterializer):
    """Materializer for numpy arrays."""
    
    ASSOCIATED_TYPES = (np.ndarray,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA
    
    def save(self, arr: np.ndarray) -> None:
        """Save the array to ``<uri>/data.npy`` in numpy's native format.

        Args:
            arr: The array to persist.
        """
        filepath = os.path.join(self.uri, "data.npy")
        with self.artifact_store.open(filepath, "wb") as f:
            np.save(f, arr)
    
    def load(self, data_type: Type[Any]) -> np.ndarray:
        """Load the array written by ``save``.

        Args:
            data_type: The expected type (``np.ndarray``); not used here.

        Returns:
            The deserialized array.
        """
        filepath = os.path.join(self.uri, "data.npy")
        with self.artifact_store.open(filepath, "rb") as f:
            # SECURITY: allow_pickle=True lets object-dtype arrays round-trip,
            # but it unpickles data from the file. Only load artifacts from an
            # artifact store you trust.
            return np.load(f, allow_pickle=True)
    
    def extract_metadata(self, arr: np.ndarray) -> Dict[str, MetadataType]:
        """Extract the shape and, for numeric arrays, dtype and statistics.

        Args:
            arr: The array being saved.

        Returns:
            ``shape`` for non-numeric arrays. For numeric arrays, also
            ``dtype``, plus ``mean``/``std``/``min``/``max`` when non-empty.
        """
        if not np.issubdtype(arr.dtype, np.number):
            return {"shape": tuple(arr.shape)}
        
        metadata: Dict[str, MetadataType] = {
            "shape": tuple(arr.shape),
            "dtype": DType(arr.dtype.type),
        }
        # np.min/np.max raise ValueError on zero-size arrays (and mean/std
        # would only produce NaN), so statistics are skipped for them.
        if arr.size == 0:
            return metadata

        metadata["mean"] = float(np.mean(arr))
        metadata["std"] = float(np.std(arr))
        metadata["min"] = float(np.min(arr))
        metadata["max"] = float(np.max(arr))
        return metadata
    
    def save_visualizations(
        self, arr: np.ndarray
    ) -> Dict[str, VisualizationType]:
        """Create a visualization suited to the array's shape.

        1D numeric arrays get a histogram. 3D numeric arrays whose last axis
        has 3 or 4 entries are saved as an image. Everything else gets none.

        Args:
            arr: The array being saved.

        Returns:
            Mapping from visualization URI to its type (possibly empty).
        """
        if not np.issubdtype(arr.dtype, np.number):
            return {}
        
        if arr.ndim == 1:
            return self._save_histogram(arr)
        
        if arr.ndim == 3 and arr.shape[2] in (3, 4):
            return self._save_image(arr)
        
        return {}

    def _save_histogram(
        self, arr: np.ndarray
    ) -> Dict[str, VisualizationType]:
        """Render a 50-bin histogram of the finite values of a 1D array."""
        import io

        import matplotlib.pyplot as plt

        # Auto-ranged histograms reject NaN/inf, so only finite values are
        # plotted; with none left there is nothing to show.
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            return {}

        # Use a dedicated figure instead of pyplot's current one, which may
        # already contain unrelated plots, and always close it.
        fig, ax = plt.subplots()
        try:
            ax.hist(finite, bins=50)
            ax.set_xlabel('Value')
            ax.set_ylabel('Frequency')
            ax.set_title('Data Distribution')
            buffer = io.BytesIO()
            fig.savefig(buffer, format='png')
        finally:
            plt.close(fig)

        hist_path = os.path.join(self.uri, "histogram.png").replace("\\", "/")
        with self.artifact_store.open(hist_path, "wb") as f:
            f.write(buffer.getvalue())
        return {hist_path: VisualizationType.IMAGE}

    def _save_image(self, arr: np.ndarray) -> Dict[str, VisualizationType]:
        """Save an (H, W, 3|4) array as a PNG image, if matplotlib accepts it."""
        import io

        from matplotlib.image import imsave

        # Render in memory first so a rejected array leaves no empty file
        # behind in the artifact store.
        buffer = io.BytesIO()
        try:
            imsave(buffer, arr, format='png')
        except ValueError:
            # imsave only accepts uint8 data or floats in [0, 1] for RGB(A).
            # Visualizations are optional, so other arrays are not rendered
            # instead of failing the whole save.
            return {}

        img_path = os.path.join(self.uri, "image.png").replace("\\", "/")
        with self.artifact_store.open(img_path, "wb") as f:
            f.write(buffer.getvalue())
        return {img_path: VisualizationType.IMAGE}

Handling Collections

For collections (lists, DataFrames), implement `get_item_count` and `load_item` so that individual items can be loaded:
from typing import Optional

class PandasMaterializer(BaseMaterializer):
    # ... save/load methods ...
    
    def get_item_count(self, data: Any) -> Optional[int]:
        """Return number of items in the artifact.
        
        Args:
            data: The pandas DataFrame or Series
            
        Returns:
            Number of items (elements for a Series, columns for a DataFrame),
            or None for any other type
        """
        import pandas as pd
        
        if isinstance(data, pd.Series):
            return len(data)
        elif isinstance(data, pd.DataFrame):
            return int(data.shape[1])  # Number of columns
        return None
    
    def load_item(self, data_type: Type[Any], index: int) -> Any:
        """Load one item, using the same notion of "item" as get_item_count.
        
        Items are addressed by position. For a Series this is the
        ``index``-th element; for a DataFrame it is the ``index``-th column.
        
        Args:
            data_type: Expected type
            index: Zero-based position of the item to load
            
        Returns:
            The specific item
        
        Raises:
            IndexError: If ``index`` is out of range for a Series/DataFrame.
        """
        import pandas as pd

        # The full artifact is loaded for every item requested
        data = self.load(data_type)
        
        if isinstance(data, pd.DataFrame):
            # Positional column access: data[index] would look up a column
            # *label* and fail for DataFrames with named columns.
            return data.iloc[:, index]
        if isinstance(data, pd.Series):
            # Positional access: Series[int] is label-based for non-integer
            # indexes.
            return data.iloc[index]
        return data[index]
This enables mapping operations over artifacts:
@step
def process_row(row: pd.Series) -> dict:
    """Convert one mapped pandas Series item into a plain dict."""
    as_mapping = row.to_dict()
    return as_mapping

@pipeline
def batch_processing_pipeline():
    """Load a DataFrame and map ``process_row`` over its items."""
    # load_dataframe is assumed to be a step defined elsewhere
    df = load_dataframe()
    # Map over each row
    # NOTE(review): PandasMaterializer.get_item_count counts DataFrame
    # *columns*, so each mapped item may be a column Series rather than a
    # row — confirm which is intended.
    results = process_row.map(df)

Materialization Best Practices

Use Efficient Formats

Prefer binary formats (Parquet, `.npy`) over text formats (CSV, JSON) for large data; they are typically smaller and faster to read and write.

Handle Errors Gracefully

Add try-except blocks with helpful error messages for missing dependencies

Version Compatibility

Support loading artifacts created with older materializer versions

Normalize Paths

Always replace backslashes with forward slashes in artifact paths built with `os.path.join`, so the paths stay valid URIs when a step runs on Windows.

Handling Missing Dependencies

class CustomMaterializer(BaseMaterializer):
    """Example materializer that prefers ``special_library`` when installed.

    Availability is probed once per instance. Without the library, data is
    written with a JSON fallback instead.
    """

    def __init__(self, uri: str, artifact_store=None):
        super().__init__(uri, artifact_store)
        self.special_library_available = self._probe_special_library()

    @staticmethod
    def _probe_special_library() -> bool:
        """Return whether ``special_library`` imports; warn when it does not."""
        try:
            import special_library  # noqa: F401 - imported only to test availability
        except ImportError:
            from zenml.logger import get_logger
            get_logger(__name__).warning(
                "special_library not installed. Install it with: "
                "pip install special_library"
            )
            return False
        return True
    
    def save(self, data: Any) -> None:
        """Save ``data`` with the efficient format if available, else JSON."""
        writer = (
            self._save_with_special_library
            if self.special_library_available
            else self._save_with_json
        )
        writer(data)

Temporary Directories

For materializers that need local file operations:
class ComplexMaterializer(BaseMaterializer):
    def save(self, data: Any) -> None:
        """Save using a temporary directory for intermediate files.

        Intermediate work happens in a local temporary directory. The final
        file is then copied into the artifact store as ``<uri>/data.bin``.

        Args:
            data: The object to serialize.
        """
        import shutil

        # Temporary directory requested with delete_after_step_execution=True,
        # i.e. to be cleaned up after the step finishes
        with self.get_temporary_directory(
            delete_at_exit=False,
            delete_after_step_execution=True,
        ) as temp_dir:
            # Do complex processing in temp directory
            temp_file = os.path.join(temp_dir, "intermediate.tmp")
            # ... processing ...
            
            # Copy final result to artifact store
            final_path = os.path.join(self.uri, "data.bin")
            with open(temp_file, "rb") as src, self.artifact_store.open(
                final_path, "wb"
            ) as dst:
                # Stream in chunks instead of src.read(), which would hold the
                # whole intermediate file in memory at once
                shutil.copyfileobj(src, dst)

Next Steps

Custom Orchestrators

Build orchestrators for any execution backend

Dynamic Pipelines

Create pipelines with runtime-determined execution graphs

Resource Configuration

Configure CPU, memory, and GPU for your steps

Containerization

Package your code with Docker for reproducible execution

Build docs developers (and LLMs) love