
Overview

The parity_check.py script validates that ONNX models produce numerically equivalent predictions to the original scikit-learn model, preventing silent regressions during deployment.

Prerequisites

1. Train the scikit-learn model:

python -m src.train

2. Export the model to ONNX:

python deployment/export_onnx.py

3. (Optional) Quantize the model. If testing a quantized model, run quantization first:

python deployment/quantize_onnx.py

Usage

python deployment/parity_check.py [OPTIONS]

Command-Line Arguments

--abs-tol (float, default: 0.04)
Maximum allowed absolute difference between predictions.
When to adjust:
  • Tighten (0.01-0.02) for high-stakes predictions
  • Widen (0.06-0.10) after quantization or with float32 precision

--mean-tol (float, default: 0.01)
Maximum allowed mean absolute difference across all predictions.
When to adjust:
  • Tighten (0.005) to catch systematic bias
  • Widen (0.02) if a few outliers are acceptable

--batch-size (int, default: 256)
Number of test samples to validate. Larger batches increase coverage but take longer.
Recommended:
  • Development: 256-512 samples
  • CI pipeline: 1000+ samples
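The argument handling described above can be sketched with argparse. This is an illustrative reconstruction of how parity_check.py likely parses its flags, not the script's actual source; the flag names and defaults are taken from the table above.

```python
import argparse

# Hypothetical sketch of the CLI parsing in parity_check.py;
# names and defaults mirror the documented options.
parser = argparse.ArgumentParser(
    description="Check sklearn vs ONNX prediction parity"
)
parser.add_argument("--abs-tol", type=float, default=0.04,
                    help="Maximum allowed absolute difference per prediction")
parser.add_argument("--mean-tol", type=float, default=0.01,
                    help="Maximum allowed mean absolute difference")
parser.add_argument("--batch-size", type=int, default=256,
                    help="Number of test samples to validate")

# Simulate a CI-style invocation with tightened abs-tol and a larger batch
args = parser.parse_args(["--abs-tol", "0.02", "--batch-size", "1000"])
print(args.abs_tol, args.mean_tol, args.batch_size)  # 0.02 0.01 1000
```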

How It Works

1. Load Models and Data

deployment/parity_check.py
config = load_config("config.yaml")
model_path = Path(config["artifacts"]["model_dir"]) / config["artifacts"]["model_file"]
onnx_path = Path("artifacts/model.onnx")

model = joblib.load(model_path)  # scikit-learn
sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])  # ONNX

df = load_dataset(config)
_, X_test, _, _ = split_data(df, config)
X = X_test.head(args.batch_size)

2. Generate Predictions

Scikit-learn:
deployment/parity_check.py
sk = model.predict_proba(X)[:, 1]  # Probability of positive class
ONNX:
deployment/parity_check.py
onx_outputs = sess.run(None, _to_onnx_inputs(X))
ox = _extract_proba(onx_outputs)  # Extract P(y=1) from ONNX outputs
The _extract_proba function handles different ONNX output formats:
  • 2D array with shape (n_samples, 2) → Extract column 1
  • List of dicts [{0: p0, 1: p1}, ...] → Extract value for key 1

3. Compute Differences

deployment/parity_check.py
abs_diff = np.abs(sk - ox)
report = {
    "samples": int(len(sk)),
    "max_abs_diff": float(abs_diff.max()),
    "mean_abs_diff": float(abs_diff.mean()),
    "abs_tol": args.abs_tol,
    "mean_tol": args.mean_tol,
    "passed": bool(abs_diff.max() <= args.abs_tol and abs_diff.mean() <= args.mean_tol),
}

4. Write Report and Exit

deployment/parity_check.py
Path("artifacts/parity_report.json").write_text(json.dumps(report, indent=2))
print(json.dumps(report, indent=2))

if not report["passed"]:
    raise SystemExit(1)  # Fail with exit code 1

Output Artifacts

artifacts/parity_report.json

{
  "samples": 256,
  "max_abs_diff": 0.0234,
  "mean_abs_diff": 0.0067,
  "abs_tol": 0.04,
  "mean_tol": 0.01,
  "passed": true
}
samples (int)
Number of test samples compared.

max_abs_diff (float)
Maximum absolute difference across all predictions: max(|sklearn - onnx|).
Critical for: catching worst-case outliers near the decision boundary.

mean_abs_diff (float)
Mean absolute difference across all predictions: mean(|sklearn - onnx|).
Critical for: detecting systematic bias or drift.

abs_tol (float)
Configured maximum allowed absolute difference (from --abs-tol).

mean_tol (float)
Configured maximum allowed mean difference (from --mean-tol).

passed (bool)
true if both max_abs_diff <= abs_tol and mean_abs_diff <= mean_tol.
The script exits with code 0 if passed, 1 if failed.
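Downstream tooling can consume this report directly. A minimal sketch (the temp-file setup here stands in for the real artifacts/parity_report.json written by parity_check.py):

```python
import json
import tempfile
from pathlib import Path

# Stand-in for artifacts/parity_report.json; in practice parity_check.py
# has already written this file and you would just read it.
sample = {"samples": 256, "max_abs_diff": 0.0234, "mean_abs_diff": 0.0067,
          "abs_tol": 0.04, "mean_tol": 0.01, "passed": True}
path = Path(tempfile.mkdtemp()) / "parity_report.json"
path.write_text(json.dumps(sample, indent=2))

# Gate a downstream step on the report's verdict
report = json.loads(path.read_text())
status = "deploy" if report["passed"] else "investigate"
print(status)  # deploy
```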

Example Usage

python deployment/parity_check.py --abs-tol 0.02 --mean-tol 0.005 --batch-size 1000

Expected Output

Passing Validation

$ python deployment/parity_check.py --abs-tol 0.04 --mean-tol 0.01
{
  "samples": 256,
  "max_abs_diff": 0.0234,
  "mean_abs_diff": 0.0067,
  "abs_tol": 0.04,
  "mean_tol": 0.01,
  "passed": true
}
$ echo $?
0

Failing Validation

$ python deployment/parity_check.py --abs-tol 0.02 --mean-tol 0.01
{
  "samples": 256,
  "max_abs_diff": 0.0521,
  "mean_abs_diff": 0.0134,
  "abs_tol": 0.02,
  "mean_tol": 0.01,
  "passed": false
}
$ echo $?
1

Validation Workflow

1. Run parity check. Execute with default or custom tolerances:

python deployment/parity_check.py --abs-tol 0.04 --mean-tol 0.01

2. Inspect the report. Check artifacts/parity_report.json for metrics:
  • passed: true → Proceed to deployment
  • passed: false → Investigate the root cause (see below)

3. Debug failures (if needed). See Debugging Parity Failures below.

4. Re-validate after fixes. After addressing issues, re-run the parity check to confirm.

Tolerance Guidelines

Strict: --abs-tol 0.01 --mean-tol 0.005
Use when:
  • Model decisions have significant business/regulatory impact
  • Near-zero tolerance for false positives/negatives
  • Predictions are used for ranking or calibration
Trade-off: may fail after quantization or with float32 precision

Debugging Parity Failures

1. Identify Outlier Samples

Modify parity_check.py to log samples with large differences:
abs_diff = np.abs(sk - ox)
outliers = np.where(abs_diff > args.abs_tol)[0]
worst = outliers[np.argsort(abs_diff[outliers])[::-1]]  # sort by difference, largest first
for idx in worst[:5]:  # show the 5 worst offenders
    print(f"Sample {idx}: sklearn={sk[idx]:.4f}, onnx={ox[idx]:.4f}, diff={abs_diff[idx]:.4f}")
    print(X.iloc[idx])

2. Check Feature Preprocessing

Symptom: Large differences on samples with categorical features.
Cause: OneHotEncoder or OrdinalEncoder handles unknown categories differently.
Fix:
  • Ensure handle_unknown='infrequent_if_exist' or 'ignore' in sklearn
  • Verify the ONNX conversion preserves the encoding logic

Symptom: Differences on samples with NaN values.
Cause: SimpleImputer strategy not preserved in ONNX.
Fix:
  • Check the imputer strategy (mean, median, constant)
  • Verify the ONNX graph includes imputation nodes

Symptom: Systematic bias (high mean_abs_diff, moderate max_abs_diff).
Cause: StandardScaler fit on different data or not saved correctly.
Fix:
  • Ensure the scaler is part of the sklearn pipeline
  • Re-export ONNX after re-training
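The unknown-category behavior can be checked in isolation. An illustrative sketch (toy data, not project code) showing why handle_unknown='ignore' is the safe setting:

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# With handle_unknown='ignore', an unseen category encodes to an all-zero
# row instead of raising at inference time, which also converts cleanly
# to ONNX. The default ('error') would break on unseen values.
enc = OneHotEncoder(handle_unknown="ignore")
enc.fit(np.array([["red"], ["green"], ["blue"]]))

row = enc.transform(np.array([["purple"]])).toarray()
print(row)  # [[0. 0. 0.]] -- unseen category maps to zeros
```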

3. Inspect ONNX Graph

Visualize ONNX graph to identify missing or incorrect nodes:
pip install netron
netron artifacts/model.onnx
Check for:
  • Missing preprocessing nodes (imputer, scaler, encoder)
  • Incorrect input types (FloatTensorType vs StringTensorType)
  • Operator version mismatches

4. Compare Float Precision

Test if FP64 → FP32 conversion causes drift:
import numpy as np

# Force FP32 in sklearn
X_fp32 = X.astype(np.float32)
sk_fp32 = model.predict_proba(X_fp32)[:, 1]
print(f"FP64 vs FP32 diff: {np.abs(sk - sk_fp32).max():.6f}")
If the difference is large (>0.01), the model is sensitive to float precision.

5. Validate Quantized Model

If parity fails after quantization:
# Compare non-quantized ONNX first
python deployment/parity_check.py --abs-tol 0.04 --mean-tol 0.01

# If passes, quantization is the issue
# Widen tolerances or skip quantization
python deployment/parity_check.py --abs-tol 0.06 --mean-tol 0.02

Integration with CI/CD

Use parity check as a deployment gate:
.github/workflows/deploy.yml
- name: Validate ONNX parity
  run: |
    python deployment/export_onnx.py
    python deployment/parity_check.py --abs-tol 0.04 --mean-tol 0.01 --batch-size 1000
  
- name: Validate quantized parity (if quantization enabled)
  run: |
    python deployment/quantize_onnx.py
    python deployment/parity_check.py --abs-tol 0.06 --mean-tol 0.02 --batch-size 1000
Parity check exits with code 1 on failure, causing CI pipeline to halt. This prevents deploying models with silent regressions.

Common Failure Scenarios

Diagnosis: A few outlier predictions with large errors.
Action:
  1. Inspect outlier samples (see Identify Outlier Samples)
  2. Check whether outliers have unusual feature values (extremes, rare categories)
  3. Widen --abs-tol if the outliers are acceptable, or fix the preprocessing

Diagnosis: Systematic bias across all predictions.
Action:
  1. Check for feature scaling drift (StandardScaler fit on the wrong data)
  2. Verify the preprocessing pipeline is identical between sklearn and ONNX
  3. Re-export ONNX after re-training

Diagnosis: Major preprocessing mismatch or incorrect ONNX conversion.
Action:
  1. Verify the sklearn pipeline structure matches the ONNX graph (use netron)
  2. Check for custom transformers not supported by skl2onnx
  3. Re-export ONNX with the correct initial_types
  4. If this occurs after quantization, check whether the model is sensitive to INT8 precision

Diagnosis: Rare edge cases or data distribution shift.
Action:
  1. Increase --batch-size to 1000+ samples
  2. Stratify the test set to ensure coverage of rare categories
  3. Inspect the samples where the failure occurs
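Stratification keeps rare classes represented in the validation batch. A sketch with synthetic stand-ins for the project's real X and y:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data: 90% class 0, 10% class 1.
y = np.array([0] * 90 + [1] * 10)
X = np.arange(100).reshape(-1, 1)

# stratify=y preserves the class ratio in the test split, so the rare
# class still appears in the parity-check batch.
_, X_test, _, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
print(int(y_test.sum()), len(y_test))  # 2 20 -- minority class kept at 10%
```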

Best Practices

Run parity check before and after quantization with different tolerances
Use larger batch sizes (1000+) in CI pipelines for comprehensive coverage
Store parity_report.json as deployment artifact for audit trail
Re-tune tolerances when feature engineering changes materially
Combine parity validation with A/B testing in production for full confidence
Passing parity check does NOT guarantee production readiness. Always monitor model performance in production and compare against baseline.

Next Steps

CPU Inference

Benchmark inference performance after validation

Deployment Overview

Return to deployment workflow overview
