Understanding ZenML’s Model Control Plane for ML model management
The Model Control Plane is ZenML’s system for managing the entire lifecycle of machine learning models. It provides a centralized registry that links models with their artifacts, metadata, pipeline runs, and deployment information.
Model: A named entity (e.g., “customer_churn_predictor”)
Model Version: A specific iteration of that model, referenced by number (e.g., 1, 2, 3) or by a stage name such as "production"
from zenml import Model

# Reference a model by name only.
model = Model(name="customer_churn_predictor")

# Reference one specific numbered version.
model_v1 = Model(name="customer_churn_predictor", version="1")

# Reference a version by its named stage.
model_prod = Model(name="customer_churn_predictor", version="production")
from zenml import pipeline, Model


@pipeline(
    model=Model(
        name="fraud_detection",
        version="v1.0",
        description="Credit card fraud detection model",
        tags=["classification", "production"],
        license="Apache 2.0",
        audience="Risk management team",
        use_cases="Detect fraudulent credit card transactions",
        limitations="May have reduced accuracy on international transactions",
    )
)
def training_pipeline():
    """Pipeline runs are automatically linked to this model version."""
    data = load_data()
    model = train_model(data)
    metrics = evaluate(model)
    return model
from zenml import pipeline, Model


@pipeline(
    model=Model(
        name="price_predictor",
        # No version specified - auto-increments: 1, 2, 3, ...
    )
)
def auto_versioned_pipeline():
    train_model()


# First run: creates version 1
# Second run: creates version 2
# Third run: creates version 3
from zenml.enums import ModelStages
from zenml import Model

# Reference whichever version currently holds the production stage.
model = Model(
    name="price_predictor",
    version=ModelStages.PRODUCTION,
)
# Other stages: STAGING, ARCHIVED, LATEST
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Fetch the model version currently in the production stage.
model = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id=ModelStages.PRODUCTION,
)

# Load one named artifact from that version...
trained_model = model.load_artifact("trained_model")

# ...or enumerate every model artifact linked to it.
artifacts = model.model_artifacts
for artifact in artifacts:
    print(f"Artifact: {artifact.name}")
# FIX: the original example called get_step_context() without importing it.
from zenml import pipeline, step, Model, get_step_context
from zenml.enums import ModelStages


@step
def inference_step() -> dict:
    """Load the production model version and run inference.

    Returns:
        A dict with a "predictions" key holding the model's output.
    """
    # The step context exposes the Model configured on the pipeline.
    context = get_step_context()
    # Load the trained-model artifact linked to the production version.
    model = context.model.load_artifact("trained_model")
    # Run inference.
    # NOTE(review): `test_data` is not defined in this example — in real code
    # it should be a step input or loaded inside the step. TODO confirm.
    predictions = model.predict(test_data)
    return {"predictions": predictions}


@pipeline(
    model=Model(
        name="churn_predictor",
        version=ModelStages.PRODUCTION,  # Reference the production version
    )
)
def inference_pipeline():
    predictions = inference_step()
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Look up version 3 of the churn predictor.
model_version = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id="3",
)

# Promote it to staging first...
model_version.set_stage(ModelStages.STAGING)

# ...and, once validated, on to production.
model_version.set_stage(ModelStages.PRODUCTION)

# Retire the previously active version.
old_version = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id="2",
)
old_version.set_stage(ModelStages.ARCHIVED)
# FIX: the original example used Any, Model, and get_step_context() without
# importing them, and imported Client without using it.
from typing import Any

from zenml import pipeline, step, Model, get_step_context
from zenml.enums import ModelStages


@step
def evaluate_and_promote(model: Any, metrics: dict) -> None:
    """Promote the current model version to production if accuracy passes.

    Args:
        model: The trained model object (unused here beyond lineage linking).
        metrics: Evaluation metrics; must contain an "accuracy" float.
    """
    if metrics["accuracy"] >= 0.95:
        # The step context exposes the model version this run is linked to.
        context = get_step_context()
        model_version = context.model
        # Promote to production.
        model_version.set_stage(ModelStages.PRODUCTION)
        print(f"Promoted model version {model_version.number} to production")
    else:
        print("Model did not meet promotion criteria")


@pipeline(
    model=Model(name="churn_predictor")
)
def training_with_promotion():
    # NOTE(review): load_data / train_model / evaluate_model are placeholder
    # steps defined elsewhere in a real project.
    data = load_data()
    model = train_model(data)
    metrics = evaluate_model(model)
    evaluate_and_promote(model, metrics)
When a model registry is in your stack, models are automatically registered:
from zenml import pipeline, Model


@pipeline(
    model=Model(
        name="recommendation_model",
        save_models_to_registry=True,  # Default is True
    )
)
def pipeline_with_registry():
    """Models automatically registered to model registry."""
    model = train_model()
    return model
from zenml.client import Client
from zenml.enums import ModelStages

client = Client()

# Enumerate every registered model.
models = client.list_models()
for model in models:
    print(f"Model: {model.name}")

# Fetch one model by name.
model = client.get_model("churn_predictor")

# Enumerate all versions of that model.
versions = client.list_model_versions(
    model_name_or_id="churn_predictor",
)

# Restrict the listing to versions in a given stage.
production_versions = client.list_model_versions(
    model_name_or_id="churn_predictor",
    stage=ModelStages.PRODUCTION,
)

# Fetch the most recently created version.
latest = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id=ModelStages.LATEST,
)
from zenml.client import Client

client = Client()

# Fetch the version currently staged as "production".
model_version = client.get_model_version(
    model_name_or_id="churn_predictor",
    model_version_name_or_number_or_id="production",
)

# Inspect every artifact linked to this version, split by kind.
model_artifacts = model_version.model_artifacts
data_artifacts = model_version.data_artifacts

print("Model Artifacts:")
for artifact in model_artifacts:
    print(f" - {artifact.name}: {artifact.version}")

print("Data Artifacts:")
for artifact in data_artifacts:
    print(f" - {artifact.name}: {artifact.version}")

# Trace back to the pipeline runs that produced this version.
pipeline_runs = model_version.pipeline_runs
for run in pipeline_runs:
    print(f"Run: {run.name} at {run.created}")
from zenml.client import Client

client = Client()

# Remove a single version of a model.
# NOTE(review): keyword names here follow the original example — verify they
# match the Client API of the ZenML release in use.
client.delete_model_version(
    model_name_or_id="churn_predictor",
    version_name_or_id="v0.1",
)

# Remove the model and every one of its versions.
client.delete_model("old_experimental_model")
Deleting a model version doesn’t delete the associated artifacts. Artifacts remain in the artifact store for lineage tracking.