Overview

The interpret module provides comprehensive model interpretation capabilities using SHAP values, permutation importance, and other feature importance metrics. It enables both global feature importance analysis and local (instance-level) explanations.

TreeInterpretation Class

Main class for interpreting tree-based models (e.g., RandomForest, GradientBoosting).

Constructor

from trifid.models.interpret import TreeInterpretation

interpreter = TreeInterpretation(
    model=trained_model,
    df=training_dataframe,
    features_col=feature_columns,
    target_col="label",
    random_state=123,
    test_size=0.25,
    preprocessing=None
)

Parameters

  • model (object, required): Trained scikit-learn model instance (e.g., RandomForestClassifier)
  • df (pandas.DataFrame, required): Training dataset as a pandas DataFrame
  • features_col (list, required): List of feature column names to use as independent variables
  • target_col (string, required): Name of the target column to use as the dependent variable
  • random_state (integer, default: 123): Random seed for reproducibility
  • test_size (float, default: 0.25): Proportion of the dataset to use for testing (0.0 to 1.0)
  • preprocessing (object, default: None): Optional preprocessing step to add to the model pipeline (e.g., StandardScaler)
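
For example, a scaler can be supplied through the preprocessing parameter; a minimal sketch (the pipeline wiring itself is handled internally by the class):

from sklearn.preprocessing import StandardScaler
from trifid.models.interpret import TreeInterpretation

# The scaler is added to the model pipeline before fitting
interpreter = TreeInterpretation(
    model=trained_model,
    df=training_dataframe,
    features_col=feature_columns,
    target_col="label",
    preprocessing=StandardScaler()
)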

Attributes

After initialization, the model is automatically fitted:
  • model (object): The fitted model instance
  • predictions (numpy.ndarray): Model predictions on the test set
  • probabilities (numpy.ndarray): Predicted probabilities for the positive class
  • train_features (pandas.DataFrame): Training features
  • test_features (pandas.DataFrame): Test features
  • train_target (pandas.Series): Training target values
  • test_target (pandas.Series): Test target values
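
These attributes make it straightforward to sanity-check the fit; a minimal sketch using scikit-learn metrics (assuming a binary classification target):

from sklearn.metrics import accuracy_score, roc_auc_score

# predictions and probabilities were computed on the held-out test split
print("Accuracy:", accuracy_score(interpreter.test_target, interpreter.predictions))
print("ROC AUC:", roc_auc_score(interpreter.test_target, interpreter.probabilities))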

Global Feature Importance Methods

The class provides multiple property methods for computing feature importance:

feature_importances

Standard scikit-learn feature importances (Gini/entropy-based).
df_importance = interpreter.feature_importances
Returns: DataFrame with columns:
  • feature: Feature name
  • feature_importances_sklearn: Importance score
Reference: scikit-learn documentation

cv_importances

Feature importances computed with k-fold cross-validation.
df_importance = interpreter.cv_importances
Returns: DataFrame with columns:
  • feature: Feature name
  • cv_feature_importances: Cross-validated importance score
Reference: random-forest-importances

dropcol_importances

Importance computed by dropping each feature, retraining the model, and measuring the decrease in performance.
df_importance = interpreter.dropcol_importances
Returns: DataFrame with columns:
  • feature: Feature name
  • dropcol_importances: Drop-column importance score
Reference: Explained.ai RF Importance
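
Conceptually, drop-column importance works as in the sketch below (an illustration of the technique, not the class's internal code):

from sklearn.base import clone

def dropcol_importance_sketch(model, X_train, y_train, X_test, y_test):
    # Baseline score with all features
    baseline = clone(model).fit(X_train, y_train).score(X_test, y_test)
    importances = {}
    for col in X_train.columns:
        # Retrain without the feature and measure the score decrease
        reduced = clone(model).fit(X_train.drop(columns=col), y_train)
        importances[col] = baseline - reduced.score(X_test.drop(columns=col), y_test)
    return importances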

oob_dropcol_importances

Out-of-bag drop-column importances (RandomForest only, since they rely on out-of-bag samples).
df_importance = interpreter.oob_dropcol_importances
Returns: DataFrame with columns:
  • feature: Feature name
  • oob_dropcol_importances: OOB drop-column importance score

permutation_importances

Permutation importance using stratified k-fold cross-validation.
df_importance = interpreter.permutation_importances
Returns: DataFrame with columns:
  • feature: Feature name
  • permutation_importance: Permutation importance score
Reference: ELI5 documentation
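
The underlying idea can also be reproduced with scikit-learn's model-agnostic implementation (an illustrative equivalent, not the ELI5-based internals):

from sklearn.inspection import permutation_importance

# Shuffle each feature on the test split and measure the drop in score
result = permutation_importance(
    interpreter.model, interpreter.test_features, interpreter.test_target,
    n_repeats=10, random_state=123
)
print(dict(zip(interpreter.test_features.columns, result.importances_mean)))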

feature_importance_permutation

Permutation importance computed with the mlxtend library.
df_importance = interpreter.feature_importance_permutation
Returns: DataFrame with columns:
  • feature: Feature name
  • feature_importance_permutation: Importance score from 10 rounds
Reference: mlxtend documentation

shap

SHAP (SHapley Additive exPlanations) values for feature importance.
df_importance = interpreter.shap
Returns: DataFrame with columns:
  • feature: Feature name
  • shap: Mean absolute SHAP value
Reference: SHAP documentation
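
For reference, the mean absolute SHAP value can be computed directly with the shap library (a sketch assuming no preprocessing step, so interpreter.model is a bare tree model; older shap versions return one array per class for classifiers):

import numpy as np
import shap

explainer = shap.TreeExplainer(interpreter.model)
shap_values = explainer.shap_values(interpreter.test_features)
if isinstance(shap_values, list):  # per-class output in older shap versions
    shap_values = shap_values[1]   # keep the positive class
mean_abs_shap = np.abs(shap_values).mean(axis=0)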

mutual_information

Mutual information between features and target.
df_importance = interpreter.mutual_information
Returns: DataFrame with columns:
  • feature: Feature name
  • mutual_information: MI score (0 = independent, higher = more dependent)
Reference: scikit-learn MI documentation
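
The same quantity can be computed directly with scikit-learn (illustrative sketch; whether the class uses the train split or the full dataset is an implementation detail):

from sklearn.feature_selection import mutual_info_classif

mi_scores = mutual_info_classif(
    interpreter.train_features, interpreter.train_target, random_state=123
)
print(dict(zip(interpreter.train_features.columns, mi_scores)))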

target_permutation

Feature importance assessed via target permutation, comparing Mean Decrease in Impurity (MDI) to Mean Decrease in Accuracy (MDA).
df_importance = interpreter.target_permutation
Returns: DataFrame with columns:
  • feature: Feature name
  • ratio_mdi-mda: Ratio of Mean Decrease in Impurity to Mean Decrease in Accuracy
Reference: Kaggle notebook by Olivier Grellier

merge_feature_importances

Combines all feature importance methods into a single DataFrame.
df_all_importance = interpreter.merge_feature_importances
Returns: DataFrame with all importance scores merged on feature name
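
A typical use is ranking features by one column of the merged table:

# Top features by mean absolute SHAP value
df_all_importance = interpreter.merge_feature_importances
print(df_all_importance.sort_values("shap", ascending=False).head(10))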

Local Explanation Methods

Explain individual predictions using SHAP values.

local_explanation()

Generate local SHAP explanations for a specific transcript or gene.
df_explanation = interpreter.local_explanation(
    df_features=features_dataframe,
    sample="ENST00000456328.2",
    waterfall=False
)

Parameters

  • df_features (pandas.DataFrame, required): DataFrame containing features for all samples, including gene_name and transcript_id columns
  • sample (string, required): Either an Ensembl transcript ID (e.g., "ENST00000456328.2") or a gene name (e.g., "DDX11L1")
  • waterfall (boolean, default: False): If True, displays a waterfall plot for the transcript

Returns

For a transcript ID, a DataFrame with:
  • Index: Feature names
  • shap: SHAP value for each feature
  • feature: Feature value

For a gene name, a DataFrame with:
  • Rows: Feature names
  • Columns: Transcript IDs within the gene
  • Values: SHAP values
  • Additional columns: std (standard deviation across transcripts) and sum (sum across transcripts)
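
For gene-level results, the std column helps spot features whose contribution varies across a gene's transcripts; a short usage sketch:

gene_df = interpreter.local_explanation(df_features=features_dataframe, sample="DDX11L1")
# Features with the most variable SHAP values across transcripts
print(gene_df.sort_values("std", ascending=False).head())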

waterfall_plot()

Generate a SHAP waterfall plot for a specific sample.
interpreter.waterfall_plot(
    model=trained_model,
    df_features=features_dataframe,
    sample="ENST00000456328.2"
)

Complete Example

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from trifid.models.interpret import TreeInterpretation

# Load data
df_training = pd.read_csv("data/training_set.tsv", sep="\t")
feature_cols = [col for col in df_training.columns 
                if col not in ["label", "gene_id", "transcript_id"]]

# Define the model (TreeInterpretation fits it internally)
model = RandomForestClassifier(
    n_estimators=400,
    max_features=7,
    min_samples_leaf=7,
    random_state=123
)

# Initialize interpreter
interpreter = TreeInterpretation(
    model=model,
    df=df_training,
    features_col=feature_cols,
    target_col="label",
    random_state=123,
    test_size=0.25
)

# Global feature importance
print("SHAP Importances:")
print(interpreter.shap.head(10))

print("\nMerged Importances:")
print(interpreter.merge_feature_importances.head(10))

# Local explanation for a specific transcript
transcript_explanation = interpreter.local_explanation(
    df_features=df_training,
    sample="ENST00000456328.2",
    waterfall=True
)
print("\nLocal Explanation:")
print(transcript_explanation)

# Local explanation for all transcripts in a gene
gene_explanation = interpreter.local_explanation(
    df_features=df_training,
    sample="DDX11L1",
    waterfall=False
)
print("\nGene-level Explanation:")
print(gene_explanation)

Feature Importance Comparison

import matplotlib.pyplot as plt

# Get multiple importance metrics
importances = interpreter.merge_feature_importances

# Plot comparison
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

importances.nlargest(10, 'shap').plot(
    x='feature', y='shap', kind='barh', ax=axes[0,0], title='SHAP'
)
importances.nlargest(10, 'permutation_importance').plot(
    x='feature', y='permutation_importance', kind='barh', 
    ax=axes[0,1], title='Permutation'
)
importances.nlargest(10, 'dropcol_importances').plot(
    x='feature', y='dropcol_importances', kind='barh', 
    ax=axes[1,0], title='Drop Column'
)
importances.nlargest(10, 'feature_importances_sklearn').plot(
    x='feature', y='feature_importances_sklearn', kind='barh', 
    ax=axes[1,1], title='Sklearn (Gini)'
)

plt.tight_layout()
plt.show()

Best Practices

  1. Use multiple importance methods: Different methods capture different aspects of feature importance
  2. SHAP for interpretation: SHAP values provide theoretically sound feature attributions
  3. Permutation for model-agnostic: Works with any model type
  4. Local explanations: Use for understanding individual predictions and debugging
  5. Cross-validation: CV-based methods provide more robust estimates

Related Modules

  • train: Model training
  • predict: Generate predictions
  • select: Model selection and evaluation
