The system uses a centralized YAML configuration file (config.yaml) to control all aspects of training, deployment, and monitoring. This configuration-driven approach ensures reproducibility and eliminates hard-coded parameters.

Configuration File Structure

The complete config.yaml structure:
config.yaml
seed: 42

data:
  path: ml_datasource.csv
  target: purchased
  test_size: 0.2

features:
  epsilon: 1.0e-06
  engagement:
    minutes_watched_weight: 0.6
    days_on_platform_weight: 0.3
    courses_started_weight: 10.0

preprocessing:
  outlier_factor: 1.5
  numeric_imputer: median
  categorical_imputer: most_frequent

models:
  logistic_regression:
    max_iter: 2000
  knn:
    n_neighbors: 7
  svm:
    C: 1.0
    kernel: rbf
    gamma: scale
  decision_tree:
    max_depth: 8
    min_samples_leaf: 10
  random_forest:
    n_estimators: 400
    min_samples_leaf: 2

cv:
  n_splits: 5

business:
  target_precision: 0.9

artifacts:
  model_dir: artifacts
  model_file: best_model.joblib
  threshold_file: threshold.txt
  metrics_file: metrics.json
  drift_baseline_file: drift_baseline.json
  lineage_file: lineage.json

monitoring:
  prediction_log_file: artifacts/prediction_log.jsonl
  drift_min_samples: 50
  drift_zscore_threshold: 3.0
  drift_min_features: 2
  class_rate_shift_threshold: 0.1

benchmarking:
  repeated_runs: 10
  batch_size: 256
  parity_abs_tolerance: 0.04
  parity_mean_tolerance: 0.01

Configuration Sections

Seed

Controls all random number generation for reproducibility.
seed: 42
seed (integer, required): Global random seed for NumPy, Python random, and scikit-learn
Usage:
src/data.py
import random

import numpy as np


def set_global_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)

Data

Defines dataset location, target variable, and train/test split.
data:
  path: ml_datasource.csv
  target: purchased
  test_size: 0.2
data.path (string, required): Path to the CSV dataset file
data.target (string, required): Name of the target column for prediction
data.test_size (float, required): Fraction of data reserved for testing (0.0 to 1.0)
Usage:
src/data.py
def load_dataset(config: dict) -> pd.DataFrame:
    data_path = config["data"]["path"]
    df = pd.read_csv(data_path)
    # fcfg is the FeatureConfig built from config["features"]
    # (see the Features section below)
    return add_engineered_features(df, fcfg)

def split_data(df: pd.DataFrame, config: dict) -> Tuple:
    target = config["data"]["target"]
    X = df.drop(columns=[target])
    y = df[target]
    return train_test_split(
        X, y,
        test_size=float(config["data"]["test_size"]),
        random_state=int(config["seed"]),
        stratify=y,
    )

Features

Controls feature engineering parameters.
features:
  epsilon: 1.0e-06
  engagement:
    minutes_watched_weight: 0.6
    days_on_platform_weight: 0.3
    courses_started_weight: 10.0
features.epsilon (float, required): Small constant to prevent division by zero in feature calculations
features.engagement.minutes_watched_weight (float, required): Weight for minutes_watched in engagement score computation
features.engagement.days_on_platform_weight (float, required): Weight for days_on_platform in engagement score computation
features.engagement.courses_started_weight (float, required): Weight for courses_started in engagement score computation
Usage:
src/data.py
fcfg = FeatureConfig(
    epsilon=float(config["features"]["epsilon"]),
    minutes_watched_weight=float(config["features"]["engagement"]["minutes_watched_weight"]),
    days_on_platform_weight=float(config["features"]["engagement"]["days_on_platform_weight"]),
    courses_started_weight=float(config["features"]["engagement"]["courses_started_weight"]),
)
src/features.py
out["engagement_score"] = (
    out["minutes_watched"] * cfg.minutes_watched_weight
    + out["days_on_platform"] * cfg.days_on_platform_weight
    + out["courses_started"] * cfg.courses_started_weight
)

out["exam_success_rate"] = np.where(
    out["practice_exams_started"] > 0,
    out["practice_exams_passed"] / (out["practice_exams_started"] + cfg.epsilon),
    0.0,
)

Preprocessing

Defines preprocessing strategies. Apart from outlier_factor, which drives the IQRClipper, these values are currently reserved for future use.
preprocessing:
  outlier_factor: 1.5
  numeric_imputer: median
  categorical_imputer: most_frequent
preprocessing.outlier_factor (float): IQR factor for outlier clipping (used in IQRClipper)
preprocessing.numeric_imputer (string): Strategy for imputing missing numeric values: median, mean, or constant
preprocessing.categorical_imputer (string): Strategy for imputing missing categorical values: most_frequent or constant
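The IQRClipper itself is not shown in this reference; a minimal sketch of such an IQR-based clipping transformer follows. The class body here is an assumption for illustration, so the actual implementation in the repository may differ, but it shows how outlier_factor would be applied: bounds are learned on fit as [Q1 - factor*IQR, Q3 + factor*IQR] and enforced on transform.
src/preprocessing.py (illustrative sketch)
```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class IQRClipper(BaseEstimator, TransformerMixin):
    """Clip each numeric column to [Q1 - factor*IQR, Q3 + factor*IQR]."""

    def __init__(self, factor: float = 1.5):
        self.factor = factor

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        q1 = np.nanpercentile(X, 25, axis=0)
        q3 = np.nanpercentile(X, 75, axis=0)
        iqr = q3 - q1
        # Bounds learned on training data only, to avoid test-set leakage
        self.lower_ = q1 - self.factor * iqr
        self.upper_ = q3 + self.factor * iqr
        return self

    def transform(self, X):
        return np.clip(np.asarray(X, dtype=float), self.lower_, self.upper_)
```
Because it follows the scikit-learn fit/transform protocol, a transformer like this can slot directly into the preprocessing Pipeline shown in the Cross-Validation section.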

Models

Defines hyperparameters for all candidate models.
models:
  logistic_regression:
    max_iter: 2000
  knn:
    n_neighbors: 7
  svm:
    C: 1.0
    kernel: rbf
    gamma: scale
  decision_tree:
    max_depth: 8
    min_samples_leaf: 10
  random_forest:
    n_estimators: 400
    min_samples_leaf: 2
models.logistic_regression.max_iter (integer): Maximum iterations for solver convergence
models.knn.n_neighbors (integer): Number of neighbors for the KNN classifier
models.svm.C (float): Regularization parameter (inverse of regularization strength)
models.svm.kernel (string): Kernel type: linear, poly, rbf, or sigmoid
models.svm.gamma (string): Kernel coefficient: scale, auto, or a numeric value
models.decision_tree.max_depth (integer): Maximum tree depth (limits overfitting)
models.decision_tree.min_samples_leaf (integer): Minimum samples required at leaf nodes
models.random_forest.n_estimators (integer): Number of trees in the forest
models.random_forest.min_samples_leaf (integer): Minimum samples required at leaf nodes
Usage:
src/train.py
def build_models(config: dict) -> dict:
    seed = int(config["seed"])
    return {
        "Logistic Regression": LogisticRegression(
            max_iter=int(config["models"]["logistic_regression"]["max_iter"]),
            random_state=seed,
        ),
        "KNN": KNeighborsClassifier(
            n_neighbors=int(config["models"]["knn"]["n_neighbors"])
        ),
        "SVM": SVC(
            C=float(config["models"]["svm"]["C"]),
            kernel=config["models"]["svm"]["kernel"],
            gamma=config["models"]["svm"]["gamma"],
            probability=True,
            random_state=seed,
        ),
        "Decision Tree": DecisionTreeClassifier(
            max_depth=int(config["models"]["decision_tree"]["max_depth"]),
            min_samples_leaf=int(config["models"]["decision_tree"]["min_samples_leaf"]),
            random_state=seed,
        ),
        "Random Forest": RandomForestClassifier(
            n_estimators=int(config["models"]["random_forest"]["n_estimators"]),
            min_samples_leaf=int(config["models"]["random_forest"]["min_samples_leaf"]),
            random_state=seed,
            n_jobs=-1,
        ),
    }

Cross-Validation

Controls stratified k-fold cross-validation.
cv:
  n_splits: 5
cv.n_splits (integer, required): Number of cross-validation folds for model selection
Usage:
src/train.py
cv = StratifiedKFold(
    n_splits=int(config["cv"]["n_splits"]),
    shuffle=True,
    random_state=int(config["seed"]),
)

for name, model in models.items():
    pipe = Pipeline(steps=[("preprocessor", preprocessor), ("model", model)])
    scores = cross_validate(pipe, X_train, y_train, cv=cv, scoring=scoring)

Business Constraints

Defines business requirements for model calibration.
business:
  target_precision: 0.9
business.target_precision (float, required): Minimum precision required for the positive class (0.0 to 1.0)
Usage:
src/train.py
probs = best_pipeline.predict_proba(X_test)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_test, probs)
target_precision = float(config["business"]["target_precision"])

candidates = [i for i, p in enumerate(precisions[:-1]) if p >= target_precision]
if candidates:
    idx = max(candidates, key=lambda i: recalls[i])
else:
    idx = int(np.argmax(precisions[:-1]))

threshold = float(thresholds[idx])

Artifacts

Defines output locations for trained artifacts.
artifacts:
  model_dir: artifacts
  model_file: best_model.joblib
  threshold_file: threshold.txt
  metrics_file: metrics.json
  drift_baseline_file: drift_baseline.json
  lineage_file: lineage.json
artifacts.model_dir (string, required): Directory for storing all training artifacts
artifacts.model_file (string, required): Filename for the serialized model pipeline
artifacts.threshold_file (string, required): Filename for the calibrated decision threshold
artifacts.metrics_file (string, required): Filename for test-set performance metrics
artifacts.drift_baseline_file (string, required): Filename for training distribution statistics
artifacts.lineage_file (string, required): Filename for the SHA256 lineage manifest
Usage:
src/train.py
out_dir = Path(config["artifacts"]["model_dir"])
out_dir.mkdir(parents=True, exist_ok=True)

model_path = out_dir / config["artifacts"]["model_file"]
threshold_path = out_dir / config["artifacts"]["threshold_file"]
metrics_path = out_dir / config["artifacts"]["metrics_file"]
lineage_path = out_dir / config["artifacts"]["lineage_file"]

joblib.dump(best_pipeline, model_path)
threshold_path.write_text(str(threshold), encoding="utf-8")
metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
lineage_path.write_text(json.dumps(lineage, indent=2), encoding="utf-8")

Monitoring

Defines drift detection parameters for production monitoring.
monitoring:
  prediction_log_file: artifacts/prediction_log.jsonl
  drift_min_samples: 50
  drift_zscore_threshold: 3.0
  drift_min_features: 2
  class_rate_shift_threshold: 0.1
monitoring.prediction_log_file (string, required): Path to the JSONL file for logging all predictions
monitoring.drift_min_samples (integer, required): Minimum samples required before computing drift statistics
monitoring.drift_zscore_threshold (float, required): Z-score threshold for detecting feature distribution drift
monitoring.drift_min_features (integer, required): Minimum number of drifted features to trigger retraining
monitoring.class_rate_shift_threshold (float, required): Maximum allowed shift in predicted positive rate before triggering retraining
Usage:
src/api.py
def _compute_drift_status() -> DriftStatusResponse:
    monitoring_cfg = _CONFIG.get("monitoring", {}) if _CONFIG else {}
    min_samples = int(monitoring_cfg.get("drift_min_samples", 100))
    z_threshold = float(monitoring_cfg.get("drift_zscore_threshold", 3.0))
    min_drifted_features = int(monitoring_cfg.get("drift_min_features", 2))
    class_rate_shift = float(monitoring_cfg.get("class_rate_shift_threshold", 0.10))

    # feature_sums and samples are aggregated from the prediction log;
    # baseline_stats is loaded from drift_baseline.json
    for feature, s in feature_sums.items():
        current_mean = float(s) / samples
        base_mean = float(baseline_stats[feature]["mean"])
        base_std = max(float(baseline_stats[feature]["std"]), 1e-6)
        abs_z = abs((current_mean - base_mean) / base_std)
        if abs_z >= z_threshold:
            drifted_features.append(feature)

    # Retraining is triggered once enough features drift
    if len(drifted_features) >= min_drifted_features:
        should_retrain = True
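Stripped of the API plumbing, the per-feature drift rule reduces to a single z-score test: a feature counts as drifted when the shift in its mean exceeds drift_zscore_threshold baseline standard deviations. A minimal standalone sketch (the helper name is illustrative, not from the codebase):
```python
def feature_drifted(current_mean: float, base_mean: float,
                    base_std: float, z_threshold: float = 3.0) -> bool:
    """True when the mean shift exceeds z_threshold baseline std deviations.

    base_std is floored at 1e-6 to avoid dividing by a near-zero std,
    mirroring the guard in _compute_drift_status above.
    """
    return abs(current_mean - base_mean) / max(base_std, 1e-6) >= z_threshold
```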

Benchmarking

Defines parameters for performance benchmarking and parity testing.
benchmarking:
  repeated_runs: 10
  batch_size: 256
  parity_abs_tolerance: 0.04
  parity_mean_tolerance: 0.01
benchmarking.repeated_runs (integer, required): Number of repeated runs for statistical latency measurements
benchmarking.batch_size (integer, required): Batch size for throughput benchmarking
benchmarking.parity_abs_tolerance (float, required): Maximum absolute difference allowed between PyTorch and ONNX predictions
benchmarking.parity_mean_tolerance (float, required): Maximum mean absolute difference allowed across all predictions
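No usage excerpt is shown for this section, so the sketch below illustrates how these four values would typically be consumed; the helper names (benchmark_latency, check_parity) are assumptions, not functions from the codebase.
```python
import time

import numpy as np


def benchmark_latency(predict_fn, batch, repeated_runs: int = 10) -> dict:
    """Time repeated predictions over one batch; report p50/p95 latency in ms."""
    timings_ms = []
    for _ in range(repeated_runs):
        start = time.perf_counter()
        predict_fn(batch)
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return {
        "p50_ms": float(np.percentile(timings_ms, 50)),
        "p95_ms": float(np.percentile(timings_ms, 95)),
    }


def check_parity(ref_probs, onnx_probs,
                 abs_tol: float = 0.04, mean_tol: float = 0.01) -> bool:
    """Parity passes when both the max and the mean absolute difference
    between the two runtimes' predictions stay within tolerance."""
    diffs = np.abs(np.asarray(ref_probs) - np.asarray(onnx_probs))
    return bool(diffs.max() <= abs_tol and diffs.mean() <= mean_tol)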

Loading Configuration

Configuration is loaded at runtime:
src/data.py
def load_config(path: str | Path = "config.yaml") -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
Usage pattern:
config = load_config()
seed = int(config["seed"])
test_size = float(config["data"]["test_size"])
max_iter = int(config["models"]["logistic_regression"]["max_iter"])

Configuration Validation

The system performs runtime type coercion but does not validate configuration schema upfront. Invalid configurations will cause runtime errors during training.
Best practices:
  • Coerce values explicitly (int(...), float(...)) when reading config values, as the usage excerpts above do
  • Validate critical parameters at module initialization
  • Document expected ranges in configuration comments
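Since no upfront schema validation exists, a small fail-fast check run right after load_config() can catch broken configurations before training starts. The validate_config helper below is a hypothetical sketch, not part of the codebase:
```python
def validate_config(config: dict) -> None:
    """Fail fast on missing sections or out-of-range critical values."""
    required = ["seed", "data", "models", "cv", "business", "artifacts"]
    missing = [key for key in required if key not in config]
    if missing:
        raise ValueError(f"Missing config sections: {missing}")
    if not 0.0 < float(config["data"]["test_size"]) < 1.0:
        raise ValueError("data.test_size must be in (0.0, 1.0)")
    if not 0.0 < float(config["business"]["target_precision"]) <= 1.0:
        raise ValueError("business.target_precision must be in (0.0, 1.0]")
```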

Environment-Specific Overrides

For testing or CI environments, use environment variables:
src/runtime_config.py
def test_mode_enabled() -> bool:
    return os.getenv("TEST_MODE", "false").lower() == "true"

def test_int(key: str, default: int) -> int:
    return int(os.getenv(key, str(default)))
Example:
TEST_MODE=true TEST_MAX_ROWS=100 python -m src.train

Configuration Versioning

1. Commit config.yaml: Always commit configuration changes alongside code.
2. Tag stable configurations: Use git tags to mark production-ready configurations.
3. Document changes: Explain parameter changes in commit messages.
4. Track lineage: The lineage.json file records the config SHA256 for each training run.
Additional configuration files in config/:
  • config/datasets.yaml - Dataset metadata and schemas
  • config/experiments.yaml - Experiment tracking configurations
  • config/reproducibility.yaml - Reproducibility test cases
  • config/integrations.yaml - External service integrations
  • config/artifacts.yaml - Extended artifact management
