The Model Control Plane is ZenML’s system for managing the entire lifecycle of machine learning models. It provides a centralized registry that links models with their artifacts, metadata, pipeline runs, and deployment information.

What is a Model?

In ZenML, a “Model” is not just the trained weights or serialized object. It’s a namespace that groups together everything related to an ML model:
  • Model artifacts (trained models, preprocessors, tokenizers)
  • Data artifacts (training data, test data, feature sets)
  • Metadata (metrics, hyperparameters, training info)
  • Pipeline runs (training runs, evaluation runs)
  • Deployment information (where and how the model is deployed)
  • Model versions (different iterations and stages)
Think of a Model as a project that contains multiple versions, where each version represents a specific iteration trained at a specific time.
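This grouping can be sketched in plain Python. The following is a hypothetical illustration of the concept only, not ZenML's implementation; `SketchModel` and `SketchModelVersion` are made-up names:

```python
# Hypothetical sketch of the concept, NOT ZenML internals: a Model is a
# namespace keyed by version, and each version groups artifacts, metadata,
# and the pipeline runs that produced it.
from dataclasses import dataclass, field


@dataclass
class SketchModelVersion:
    number: int
    artifacts: dict = field(default_factory=dict)      # name -> artifact reference
    metadata: dict = field(default_factory=dict)       # metrics, hyperparameters, ...
    pipeline_runs: list = field(default_factory=list)  # runs linked to this version


@dataclass
class SketchModel:
    name: str
    versions: dict = field(default_factory=dict)  # version number -> version

    def new_version(self) -> SketchModelVersion:
        # Mimics auto-incrementing version numbers: 1, 2, 3, ...
        number = max(self.versions, default=0) + 1
        version = SketchModelVersion(number=number)
        self.versions[number] = version
        return version


model = SketchModel(name="customer_churn_predictor")
v1 = model.new_version()
v1.artifacts["trained_model"] = "path/to/model.pkl"
v1.metadata["accuracy"] = 0.95
```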

Model vs. Model Version

  • Model: A named entity (e.g., “customer_churn_predictor”)
  • Model Version: A specific iteration of that model (e.g., version 1, 2, 3, or “production”)
from zenml import Model

# Model name
model = Model(name="customer_churn_predictor")

# Specific version
model_v1 = Model(name="customer_churn_predictor", version="1")

# Named stage
model_prod = Model(name="customer_churn_predictor", version="production")

Creating Models

Models are typically created by attaching them to pipelines:

In Pipeline Decorator

from zenml import pipeline, Model

@pipeline(
    model=Model(
        name="fraud_detection",
        version="v1.0",
        description="Credit card fraud detection model",
        tags=["classification", "production"],
        license="Apache 2.0",
        audience="Risk management team",
        use_cases="Detect fraudulent credit card transactions",
        limitations="May have reduced accuracy on international transactions",
    )
)
def training_pipeline():
    """Pipeline runs are automatically linked to this model version."""
    data = load_data()
    model = train_model(data)
    metrics = evaluate(model)
    return model

Dynamic Model Versioning

Use placeholders for dynamic version names:
from zenml import pipeline, Model

@pipeline(
    model=Model(
        name="recommendation_engine",
        version="v{date}_{time}",  # e.g., "v20240309_143022"
    ),
    substitutions={
        # Custom placeholder, referenced as {custom_tag} in templated names
        "custom_tag": "production"
    }
)
def dynamic_versioning_pipeline():
    train_model()
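To make the placeholder behavior concrete, here is a hedged sketch of how such a template could be resolved. The real substitution happens inside ZenML; `resolve_version` is a hypothetical helper, with `{date}` and `{time}` derived from the run start time and custom substitutions supplying extra keys:

```python
# Hypothetical helper showing how version templates could resolve
# (not ZenML's actual substitution logic).
from datetime import datetime


def resolve_version(template: str, run_start: datetime, substitutions: dict) -> str:
    values = {
        "date": run_start.strftime("%Y%m%d"),
        "time": run_start.strftime("%H%M%S"),
        **substitutions,  # custom keys, e.g. {"custom_tag": "production"}
    }
    return template.format(**values)


resolved = resolve_version(
    "v{date}_{time}",
    datetime(2024, 3, 9, 14, 30, 22),
    {"custom_tag": "production"},
)
# resolved == "v20240309_143022"
```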

Creating Models Programmatically

from zenml.client import Client

client = Client()

# Create a new model
model = client.create_model(
    name="sentiment_analyzer",
    description="BERT-based sentiment analysis",
    tags=["nlp", "bert", "sentiment"],
)

print(f"Created model: {model.id}")

Model Versions

Model versions represent specific iterations of a model.

Automatic Versioning

from zenml import pipeline, Model

@pipeline(
    model=Model(
        name="price_predictor",
        # No version specified - auto-increments: 1, 2, 3, ...
    )
)
def auto_versioned_pipeline():
    train_model()

# First run: creates version 1
# Second run: creates version 2
# Third run: creates version 3

Named Versions

@pipeline(
    model=Model(
        name="price_predictor",
        version="baseline",  # Custom version name
    )
)
def named_version_pipeline():
    train_model()

Stage-based Versions

from zenml.enums import ModelStages
from zenml import Model

# Reference the production version
model = Model(
    name="price_predictor",
    version=ModelStages.PRODUCTION
)

# Other stages: STAGING, ARCHIVED, LATEST
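The lifecycle these stages imply can be modeled as a small state machine. The sketch below is one hypothetical promotion policy, not ZenML's actual rules (ZenML lets you set stages directly); `Stage`, `ALLOWED`, and `promote` are made-up names:

```python
# Hypothetical stage lifecycle sketch, NOT ZenML internals: a version moves
# staging -> production, and production versions are archived when superseded.
# "latest" is a pointer to the newest version rather than a stage you set.
from enum import Enum


class Stage(Enum):
    NONE = "none"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"


ALLOWED = {
    Stage.NONE: {Stage.STAGING, Stage.PRODUCTION, Stage.ARCHIVED},
    Stage.STAGING: {Stage.PRODUCTION, Stage.ARCHIVED},
    Stage.PRODUCTION: {Stage.ARCHIVED},
    Stage.ARCHIVED: set(),
}


def promote(current: Stage, target: Stage) -> Stage:
    if target not in ALLOWED[current]:
        raise ValueError(f"Cannot move from {current.value} to {target.value}")
    return target


stage = promote(Stage.NONE, Stage.STAGING)
stage = promote(stage, Stage.PRODUCTION)
```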

Accessing Model Context

Access model information within steps:
from typing import Any

import pandas as pd

from zenml import step, get_step_context

@step
def training_step(data: pd.DataFrame) -> Any:
    """Step with access to model context."""
    context = get_step_context()
    
    # Access current model version
    model_version = context.model
    
    print(f"Model: {model_version.name}")
    print(f"Version: {model_version.version}")
    print(f"Version number: {model_version.number}")
    
    # Train model
    trained_model = train(data)
    
    return trained_model

Linking Artifacts to Models

Artifacts are automatically linked when using model context:
from zenml import pipeline, step, Model
from typing import Annotated, Any
from zenml.artifacts import ArtifactConfig
from zenml.enums import ArtifactType
import pandas as pd

@step
def train_step(data: pd.DataFrame) -> Annotated[
    Any,
    ArtifactConfig(
        name="trained_model",
        artifact_type=ArtifactType.MODEL
    )
]:
    """Train and return model artifact."""
    model = train_model(data)
    return model

@pipeline(
    model=Model(name="churn_predictor", version="v1")
)
def training_pipeline():
    data = load_data()
    model = train_step(data)  # Automatically linked to model version

Loading Model Artifacts

Retrieve artifacts from specific model versions:
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Get production model version
model = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id=ModelStages.PRODUCTION
)

# Load specific artifact from this version
trained_model = model.load_artifact("trained_model")

# Or get all model artifacts
artifacts = model.model_artifacts
for artifact in artifacts:
    print(f"Artifact: {artifact.name}")

Loading in Pipelines

from zenml import pipeline, step, Model, get_step_context
from zenml.enums import ModelStages
import pandas as pd

@step
def inference_step(test_data: pd.DataFrame) -> dict:
    """Load production model for inference."""
    context = get_step_context()
    
    # Load artifact from production model
    model = context.model.load_artifact("trained_model")
    
    # Run inference
    predictions = model.predict(test_data)
    return {"predictions": predictions}

@pipeline(
    model=Model(
        name="churn_predictor",
        version=ModelStages.PRODUCTION  # Reference production version
    )
)
def inference_pipeline():
    test_data = load_data()
    predictions = inference_step(test_data)

Model Metadata

Attach rich metadata to models:
from typing import Any

import pandas as pd

from zenml import step, log_model_metadata

@step
def training_with_metadata(data: pd.DataFrame) -> Any:
    """Train model and log metadata."""
    model = train_model(data)
    
    # Log metadata to current model version
    log_model_metadata(
        metadata={
            "accuracy": 0.95,
            "precision": 0.93,
            "recall": 0.91,
            "f1_score": 0.92,
            "training_samples": len(data),
            "features": list(data.columns),
            "hyperparameters": {
                "learning_rate": 0.001,
                "batch_size": 32,
                "epochs": 10
            },
            "framework": "scikit-learn",
            "framework_version": "1.3.0"
        }
    )
    
    return model

Model Stages

Promote models through lifecycle stages:
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Get a specific model version
model_version = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id="3"
)

# Promote to staging
model_version.set_stage(ModelStages.STAGING)

# After validation, promote to production
model_version.set_stage(ModelStages.PRODUCTION)

# Archive old version
old_version = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id="2"
)
old_version.set_stage(ModelStages.ARCHIVED)

Stage-based Promotion Pipeline

from typing import Any

from zenml import pipeline, step, Model, get_step_context
from zenml.enums import ModelStages

@step
def evaluate_and_promote(model: Any, metrics: dict) -> None:
    """Promote model if it meets criteria."""
    if metrics["accuracy"] >= 0.95:
        context = get_step_context()
        model_version = context.model
        
        # Promote to production
        model_version.set_stage(ModelStages.PRODUCTION)
        print(f"Promoted model version {model_version.number} to production")
    else:
        print("Model did not meet promotion criteria")

@pipeline(
    model=Model(name="churn_predictor")
)
def training_with_promotion():
    data = load_data()
    model = train_model(data)
    metrics = evaluate_model(model)
    evaluate_and_promote(model, metrics)

Model Registry Integration

When a model registry is in your stack, models are automatically registered:
from zenml import pipeline, Model

@pipeline(
    model=Model(
        name="recommendation_model",
        save_models_to_registry=True  # Default is True
    )
)
def pipeline_with_registry():
    """Models automatically registered to model registry."""
    model = train_model()
    return model

Querying Models

Find and filter models:
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# List all models
models = client.list_models()
for model in models:
    print(f"Model: {model.name}")

# Get specific model
model = client.get_model("churn_predictor")

# List versions of a model
versions = client.list_model_versions(
    model_name_or_id="churn_predictor"
)

# Filter by stage
production_versions = client.list_model_versions(
    model_name_or_id="churn_predictor",
    stage=ModelStages.PRODUCTION
)

# Get latest version
latest = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id=ModelStages.LATEST
)

Model Lineage

Track the complete lineage of model versions:
from zenml.client import Client

client = Client()

# Get model version
model_version = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id="production"
)

# Get all artifacts linked to this model version
model_artifacts = model_version.model_artifacts
data_artifacts = model_version.data_artifacts

print("Model Artifacts:")
for artifact in model_artifacts:
    print(f"  - {artifact.name}: {artifact.version}")

print("Data Artifacts:")
for artifact in data_artifacts:
    print(f"  - {artifact.name}: {artifact.version}")

# Get pipeline runs that produced this version
pipeline_runs = model_version.pipeline_runs
for run in pipeline_runs:
    print(f"Run: {run.name} at {run.created}")

Model Deletion

Delete models or specific versions:
from zenml.client import Client

client = Client()

# Delete a specific version
client.delete_model_version(
    model_name_or_id="churn_predictor",
    version_name_or_id="v0.1"
)

# Delete entire model (all versions)
client.delete_model("old_experimental_model")
Note: deleting a model version doesn’t delete the associated artifacts. They remain in the artifact store for lineage tracking.

Best Practices

Semantic Naming

Use descriptive model names that indicate the use case: fraud_detector, product_recommender.

Rich Metadata

Log comprehensive metadata including metrics, hyperparameters, data info, and training details.

Stage Management

Use stages to manage model lifecycle: staging for validation, production for deployed models.

Version Strategy

Use auto-versioning during experimentation, semantic versioning for releases.
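As a sketch of the release side of that strategy, a hypothetical helper like `bump_version` (not part of ZenML) can compute the next pinned version string to pass as Model(version=...):

```python
# Hypothetical helper, not part of ZenML: compute the next semantic release
# version so release pipelines can pin it via Model(version=...).
def bump_version(last: str, part: str = "minor") -> str:
    major, minor, patch = (int(x) for x in last.lstrip("v").split("."))
    if part == "major":
        major, minor, patch = major + 1, 0, 0
    elif part == "minor":
        minor, patch = minor + 1, 0
    else:  # "patch"
        patch += 1
    return f"v{major}.{minor}.{patch}"


# During experimentation, leave the version unset so ZenML auto-increments;
# at release time, pin e.g. bump_version("v1.2.3") -> "v1.3.0".
```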

Code Reference

  • Model class: src/zenml/model/model.py:47
  • Model stages enum: src/zenml/enums.py
  • Client model methods: src/zenml/client.py
