Overview

Production ML systems require continuous monitoring to ensure models perform as expected, detect degradation, and identify when retraining is needed. This guide covers comprehensive monitoring strategies for AQI prediction systems.

Monitoring Architecture

┌─────────────────┐
│   Prediction    │
│    Service      │
└────────┬────────┘
         │
         ├──────────▶ Metrics (Prometheus)
         ├──────────▶ Logs (ELK/Loki)
         ├──────────▶ Traces (Jaeger)
         └──────────▶ Data Store
                           │
                           ▼
              ┌───────────────────────┐
              │   Monitoring Agent    │
              │  - Drift Detection    │
              │  - Performance Checks │
              │  - Data Quality       │
              └───────────┬───────────┘
                          │
                          ▼
              ┌───────────────────────┐
              │  Alerting System      │
              │  - Slack/PagerDuty    │
              │  - Email              │
              └───────────────────────┘

Key Metrics

Track prediction accuracy and model performance:
from prometheus_client import Gauge, Histogram, Counter
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Define metrics
prediction_mae = Gauge('aqi_prediction_mae', 'Mean Absolute Error')
prediction_rmse = Gauge('aqi_prediction_rmse', 'Root Mean Squared Error')
prediction_r2 = Gauge('aqi_prediction_r2', 'R² Score')
prediction_error_distribution = Histogram(
    'aqi_prediction_error',
    'Distribution of prediction errors',
    buckets=[0, 5, 10, 20, 30, 50, 100, float('inf')]
)

def track_predictions(y_true, y_pred):
    """Track prediction metrics."""
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    
    prediction_mae.set(mae)
    prediction_rmse.set(rmse)
    prediction_r2.set(r2)
    
    # Track error distribution
    for true, pred in zip(y_true, y_pred):
        error = abs(true - pred)
        prediction_error_distribution.observe(error)
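
Expose these metrics over HTTP so Prometheus can scrape them. A minimal sketch, assuming ground-truth AQI labels arrive in periodic batches; the port and the fetch_labelled_batch helper are illustrative:

from prometheus_client import start_http_server
import time

# Serve the metrics defined above on :9100/metrics (port is an example)
start_http_server(9100)

while True:
    # fetch_labelled_batch() is a hypothetical helper that returns recent
    # predictions together with their ground-truth AQI values
    y_true, y_pred = fetch_labelled_batch()
    track_predictions(y_true, y_pred)
    time.sleep(300)  # refresh every 5 minutes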

Data Quality Monitoring

Input Validation

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple
from dataclasses import dataclass

@dataclass
class ValidationResult:
    is_valid: bool
    errors: List[str]
    warnings: List[str]

class InputValidator:
    """Validate input features for AQI prediction."""
    
    def __init__(self):
        self.feature_ranges = {
            'pm25': (0, 500),
            'pm10': (0, 600),
            'no2': (0, 200),
            'so2': (0, 100),
            'co': (0, 50),
            'o3': (0, 300),
            'temperature': (-50, 60),
            'humidity': (0, 100),
            'wind_speed': (0, 50),
            'pressure': (900, 1100)
        }
        
        self.missing_threshold = 0.1  # 10% missing values
    
    def validate(self, data: pd.DataFrame) -> ValidationResult:
        """Validate input data."""
        errors = []
        warnings = []
        
        # Check missing values
        missing_pct = data.isnull().sum() / len(data)
        for col, pct in missing_pct.items():
            if pct > self.missing_threshold:
                errors.append(f"{col}: {pct:.1%} missing values")
        
        # Check value ranges
        for col, (min_val, max_val) in self.feature_ranges.items():
            if col not in data.columns:
                errors.append(f"Missing required column: {col}")
                continue
            
            out_of_range = (
                (data[col] < min_val) | (data[col] > max_val)
            ).sum()
            
            if out_of_range > 0:
                pct = out_of_range / len(data)
                if pct > 0.01:  # More than 1%
                    errors.append(
                        f"{col}: {out_of_range} values out of range "
                        f"[{min_val}, {max_val}]"
                    )
                else:
                    warnings.append(
                        f"{col}: {out_of_range} values out of range"
                    )
        
        # Check for suspicious patterns
        for col in data.select_dtypes(include=[np.number]).columns:
            # Check for constant values
            if data[col].nunique() == 1:
                warnings.append(f"{col}: constant value {data[col].iloc[0]}")
            
            # Check for suspicious zeros
            zero_pct = (data[col] == 0).sum() / len(data)
            if zero_pct > 0.5:
                warnings.append(f"{col}: {zero_pct:.1%} zero values")
        
        is_valid = len(errors) == 0
        return ValidationResult(is_valid, errors, warnings)

Set up alerts for critical data quality issues. Invalid inputs can lead to unreliable predictions.
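
One way to wire the validator into alerting, reusing the AlertManager defined later in this guide (a sketch; alert_manager and incoming_batch are placeholders for your own objects):

validator = InputValidator()
result = validator.validate(incoming_batch)  # incoming_batch: pandas DataFrame of features

if not result.is_valid:
    alert_manager.trigger_alert(Alert(
        severity='error',
        title='Input validation failed',
        message='; '.join(result.errors),
        metrics={
            'error_count': len(result.errors),
            'warning_count': len(result.warnings)
        }
    ))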

Prediction Drift Detection

Statistical Drift Detection

Use the Population Stability Index (PSI) to quantify distribution shift between a reference window and current data:
import numpy as np
import pandas as pd

def calculate_psi(expected: np.ndarray, actual: np.ndarray, bins=10) -> float:
    """Calculate Population Stability Index."""
    # Create bins based on expected distribution
    breakpoints = np.percentile(expected, np.linspace(0, 100, bins + 1))
    breakpoints = np.unique(breakpoints)
    
    # Calculate distribution for each dataset
    expected_percents = np.histogram(expected, breakpoints)[0] / len(expected)
    actual_percents = np.histogram(actual, breakpoints)[0] / len(actual)
    
    # Avoid division by zero
    expected_percents = np.where(expected_percents == 0, 0.0001, expected_percents)
    actual_percents = np.where(actual_percents == 0, 0.0001, actual_percents)
    
    # Calculate PSI
    psi = np.sum((actual_percents - expected_percents) * 
                 np.log(actual_percents / expected_percents))
    
    return psi

def interpret_psi(psi_value: float) -> str:
    """Interpret PSI value."""
    if psi_value < 0.1:
        return "No significant change"
    elif psi_value < 0.25:
        return "Small change detected"
    else:
        return "Major shift detected - retrain model"

# Example usage (assumes a trained `model` plus reference and current feature sets)
reference_predictions = model.predict(X_reference)
current_predictions = model.predict(X_current)

psi = calculate_psi(reference_predictions, current_predictions)
print(f"PSI: {psi:.4f} - {interpret_psi(psi)}")

Use multiple drift detection methods for robust monitoring. Different methods detect different types of drift.
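
For example, a two-sample Kolmogorov-Smirnov test (via scipy.stats) complements PSI by flagging shifts in the overall shape of the prediction distribution; the 0.05 significance level below is a common but adjustable choice:

from scipy.stats import ks_2samp

def detect_drift_ks(reference: np.ndarray, current: np.ndarray,
                    alpha: float = 0.05) -> bool:
    """Return True if the KS test rejects 'same distribution' at level alpha."""
    statistic, p_value = ks_2samp(reference, current)
    return p_value < alpha

if detect_drift_ks(reference_predictions, current_predictions):
    print("KS test: prediction distribution shift detected")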

Performance Degradation Alerts

Alert Configuration

import smtplib
from email.mime.text import MIMEText
from typing import List, Dict
import requests
from dataclasses import dataclass

@dataclass
class Alert:
    severity: str  # 'warning', 'error', 'critical'
    title: str
    message: str
    metrics: Dict

class AlertManager:
    """Manage alerts for model monitoring."""
    
    def __init__(self, config: Dict):
        self.config = config
        self.alert_history = []
    
    def send_email_alert(self, alert: Alert):
        """Send email alert."""
        msg = MIMEText(f"{alert.message}\n\nMetrics: {alert.metrics}")
        msg['Subject'] = f"[{alert.severity.upper()}] {alert.title}"
        msg['From'] = self.config['email_from']
        msg['To'] = ', '.join(self.config['email_to'])
        
        with smtplib.SMTP(self.config['smtp_host']) as server:
            server.send_message(msg)
    
    def send_slack_alert(self, alert: Alert):
        """Send Slack alert."""
        color_map = {
            'warning': '#FFA500',
            'error': '#FF0000',
            'critical': '#8B0000'
        }
        
        payload = {
            'attachments': [{
                'color': color_map[alert.severity],
                'title': alert.title,
                'text': alert.message,
                'fields': [
                    {'title': k, 'value': str(v), 'short': True}
                    for k, v in alert.metrics.items()
                ]
            }]
        }
        
        requests.post(self.config['slack_webhook'], json=payload)
    
    def send_pagerduty_alert(self, alert: Alert):
        """Send PagerDuty alert."""
        if alert.severity != 'critical':
            return  # Only critical alerts to PagerDuty
        
        payload = {
            'routing_key': self.config['pagerduty_key'],
            'event_action': 'trigger',
            'payload': {
                'summary': alert.title,
                'severity': 'critical',
                'source': 'aqi-predictor',
                'custom_details': alert.metrics
            }
        }
        
        requests.post(
            'https://events.pagerduty.com/v2/enqueue',
            json=payload
        )
    
    def trigger_alert(self, alert: Alert):
        """Trigger alert through configured channels."""
        self.alert_history.append(alert)
        
        # Send through all channels based on severity
        if alert.severity in ['error', 'critical']:
            self.send_email_alert(alert)
            self.send_slack_alert(alert)
        
        if alert.severity == 'critical':
            self.send_pagerduty_alert(alert)
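
A sketch of how degraded accuracy metrics could feed into the AlertManager; the threshold values and severity escalation rule are illustrative:

def check_performance(alert_manager: AlertManager, metrics: Dict,
                      mae_threshold: float = 15.0, r2_threshold: float = 0.7):
    """Raise an alert when tracked metrics cross the configured thresholds."""
    if metrics['mae'] > mae_threshold or metrics['r2'] < r2_threshold:
        alert_manager.trigger_alert(Alert(
            severity='critical' if metrics['mae'] > 2 * mae_threshold else 'error',
            title='AQI model performance degraded',
            message=f"MAE={metrics['mae']:.2f}, R2={metrics['r2']:.3f}",
            metrics=metrics
        ))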

Logging and Tracing

Structured Logging

import logging
import json
from datetime import datetime
from typing import Dict, Any

class StructuredLogger:
    """Structured logging for ML predictions."""
    
    def __init__(self, name: str):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        
        # JSON formatter
        handler = logging.StreamHandler()
        handler.setFormatter(self.JSONFormatter())
        self.logger.addHandler(handler)
    
    class JSONFormatter(logging.Formatter):
        # Attributes present on every LogRecord; anything else came in via `extra=`
        _STANDARD_ATTRS = set(vars(logging.makeLogRecord({})))

        def format(self, record):
            log_data = {
                'timestamp': datetime.utcnow().isoformat(),
                'level': record.levelname,
                'message': record.getMessage(),
                'logger': record.name
            }

            # Fields passed via `extra=` are attached directly to the record
            # (not under a single `extra` attribute), so collect them here
            for key, value in record.__dict__.items():
                if key not in self._STANDARD_ATTRS:
                    log_data[key] = value

            return json.dumps(log_data, default=str)
    
    def log_prediction(self, request_id: str, features: Dict, 
                      prediction: float, metadata: Dict):
        """Log prediction with full context."""
        self.logger.info(
            'Prediction made',
            extra={
                'request_id': request_id,
                'features': features,
                'prediction': prediction,
                'model_version': metadata.get('model_version'),
                'processing_time_ms': metadata.get('processing_time'),
                'confidence': metadata.get('confidence')
            }
        )
    
    def log_error(self, request_id: str, error: Exception, context: Dict):
        """Log error with context."""
        self.logger.error(
            f'Prediction error: {str(error)}',
            extra={
                'request_id': request_id,
                'error_type': type(error).__name__,
                'context': context
            }
        )

# Example usage
logger = StructuredLogger('aqi_predictor')

logger.log_prediction(
    request_id='abc-123',
    features={'pm25': 35.5, 'temperature': 22.5},
    prediction=75.3,
    metadata={'model_version': 'v1.2', 'processing_time': 45.2}
)

Distributed Tracing

from opentelemetry import trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
import time

# Setup tracing
trace.set_tracer_provider(TracerProvider())
jaeger_exporter = JaegerExporter(
    agent_host_name='localhost',
    agent_port=6831,
)
trace.get_tracer_provider().add_span_processor(
    BatchSpanProcessor(jaeger_exporter)
)

tracer = trace.get_tracer(__name__)

def predict_with_tracing(features):
    """Make prediction with distributed tracing."""
    with tracer.start_as_current_span('predict') as span:
        span.set_attribute('model.version', 'v1.2')
        span.set_attribute('features.count', len(features))
        
        # Preprocess (preprocess, model and postprocess stand in for the service's own functions)
        with tracer.start_as_current_span('preprocess'):
            processed = preprocess(features)
        
        # Inference
        with tracer.start_as_current_span('inference'):
            prediction = model.predict(processed)
        
        # Postprocess
        with tracer.start_as_current_span('postprocess'):
            result = postprocess(prediction)
        
        span.set_attribute('prediction.value', result)
        
        return result

Structured logging and distributed tracing are essential for debugging production issues and understanding system behavior.

Dashboard Setup

Grafana Dashboard

{
  "dashboard": {
    "title": "AQI Predictor Monitoring",
    "panels": [
      {
        "title": "Prediction MAE",
        "targets": [{
          "expr": "aqi_prediction_mae"
        }],
        "alert": {
          "conditions": [{
            "evaluator": {"params": [15], "type": "gt"},
            "query": {"params": ["A", "5m", "now"]}
          }]
        }
      },
      {
        "title": "Request Rate",
        "targets": [{
          "expr": "rate(aqi_api_requests_total[5m])"
        }]
      },
      {
        "title": "P95 Latency",
        "targets": [{
          "expr": "histogram_quantile(0.95, aqi_api_request_duration_seconds)"
        }]
      },
      {
        "title": "Error Rate",
        "targets": [{
          "expr": "rate(prediction_errors_total[5m])"
        }]
      }
    ]
  }
}
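
The request-rate and latency panels above query aqi_api_requests_total and aqi_api_request_duration_seconds, which the prediction service must expose. A minimal sketch of instrumenting them (the status label and wrapper function are assumptions):

from prometheus_client import Counter, Histogram
import time

api_requests_total = Counter(
    'aqi_api_requests_total', 'Total prediction API requests', ['status']
)
api_request_duration = Histogram(
    'aqi_api_request_duration_seconds', 'Prediction request latency in seconds'
)

def handle_predict_request(features):
    """Wrap the prediction call with request-count and latency metrics."""
    start = time.perf_counter()
    try:
        result = predict_with_tracing(features)  # from the tracing section above
        api_requests_total.labels(status='success').inc()
        return result
    except Exception:
        api_requests_total.labels(status='error').inc()
        raise
    finally:
        api_request_duration.observe(time.perf_counter() - start)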

Automated Retraining

import numpy as np
from typing import Dict, Tuple

class RetrainingTrigger:
    """Trigger model retraining based on monitoring signals."""
    
    def __init__(self, config: Dict):
        self.config = config
        # OnlineDriftDetector is assumed to be defined elsewhere in the project
        self.drift_detector = OnlineDriftDetector()
        self.performance_window = []
    
    def should_retrain(self, metrics: Dict) -> Tuple[bool, str]:
        """Determine if retraining is needed."""
        reasons = []
        
        # Check performance degradation
        if metrics['mae'] > self.config['max_mae']:
            reasons.append(f"MAE {metrics['mae']:.2f} exceeds threshold")
        
        # Check drift
        if metrics.get('drift_detected'):
            reasons.append("Data drift detected")
        
        # Check staleness
        days_since_training = metrics.get('days_since_training', 0)
        if days_since_training > self.config['max_model_age_days']:
            reasons.append(f"Model is {days_since_training} days old")
        
        # Check error rate trend
        self.performance_window.append(metrics['mae'])
        if len(self.performance_window) > 7:
            self.performance_window.pop(0)
            trend = np.polyfit(range(7), self.performance_window, 1)[0]
            if trend > 0.5:  # Increasing error trend
                reasons.append("Performance degrading over time")
        
        should_retrain = len(reasons) > 0
        reason_str = "; ".join(reasons)
        
        return should_retrain, reason_str
    
    def trigger_retraining_pipeline(self, reason: str):
        """Trigger automated retraining pipeline."""
        print(f"Triggering retraining: {reason}")
        
        # Call training pipeline (e.g., Airflow, Kubeflow)
        # requests.post('http://airflow:8080/api/v1/dags/aqi_training/dagRuns')
        pass
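
Example of evaluating the trigger on a daily metrics snapshot (the config values and metrics dict are illustrative, and OnlineDriftDetector is assumed to be available):

trigger = RetrainingTrigger(config={'max_mae': 15.0, 'max_model_age_days': 30})

should_retrain, reason = trigger.should_retrain({
    'mae': 18.2,
    'drift_detected': True,
    'days_since_training': 12
})

if should_retrain:
    trigger.trigger_retraining_pipeline(reason)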

Best Practices

  • Monitor performance, system, and business metrics
  • Set up alerts for critical degradation
  • Use multiple drift detection methods
  • Implement gradual rollout for model updates
  • Maintain shadow deployments for testing
  • Log all predictions with full feature context
  • Store ground truth labels when available
  • Track feature distributions over time
  • Monitor data quality continuously
  • Implement sampling for high-volume systems
  • Set appropriate thresholds to avoid false alarms
  • Use escalating severity levels
  • Implement alert aggregation
  • Review and adjust thresholds regularly
  • Document response procedures
  • Analyze monitoring data for insights
  • A/B test model improvements
  • Automate retraining when possible
  • Conduct regular model audits
  • Document lessons learned
