Overview
The trifid.visualization.figures module provides plotting functions for visualizing TRIFID predictions, feature importance, and model performance. This is an optional module that requires visualization dependencies.
Install visualization dependencies with: pip install ".[extra]" (quoted so that shells like zsh do not expand the brackets)
Installation
# Install TRIFID from GitHub
pip install git+https://github.com/fpozoc/trifid.git

# Or, from a local clone, with visualization support
pip install ".[extra]"
Required packages: altair, matplotlib, seaborn, shap, yellowbrick
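Before importing the module, you can check whether these optional dependencies are installed. A small stdlib sketch (not part of the trifid API):

```python
import importlib.util

def missing_packages(packages):
    """Return the subset of packages that cannot be imported."""
    return [p for p in packages if importlib.util.find_spec(p) is None]

# the optional visualization dependencies listed above
viz_deps = ['altair', 'matplotlib', 'seaborn', 'shap', 'yellowbrick']
print(missing_packages(viz_deps))
```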
Prediction Visualization
explain_prediction()
Generate SHAP force plot explanation for a specific transcript prediction.
df: DataFrame containing features and predictions
model: trained TRIFID model object
features: list of feature names to use in the explanation
transcript_id: transcript ID to explain (e.g., "ENST00000356207")
Returns: pd.DataFrame with SHAP values
Example:
import pickle

import pandas as pd

from trifid.visualization.figures import explain_prediction

# Load model and data
with open('path/to/model.pkl', 'rb') as f:
    model = pickle.load(f)
df = pd.read_csv('predictions.tsv.gz', sep='\t')

# Explain the prediction for a specific transcript
explain_prediction(
    df=df,
    model=model,
    features=feature_list,  # list of feature names, defined elsewhere
    transcript_id='ENST00000356207',
)
This generates a SHAP force plot showing:
Base SHAP value (model’s expected value)
Feature contributions pushing prediction higher (orange)
Feature contributions pushing prediction lower (blue)
Final predicted score
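The force plot decomposition is additive: the base value plus every feature contribution equals the final predicted score. A toy numeric check of that property (invented numbers, not real model output):

```python
import numpy as np

base_value = 0.42                              # model's expected value
contributions = np.array([0.25, -0.10, 0.18])  # per-feature SHAP values

# positive contributions push the score up (orange), negative down (blue);
# by SHAP additivity they sum to the final prediction
predicted = base_value + contributions.sum()
```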
plot_trifid_appris()
Compare TRIFID scores against APPRIS annotations.
df: DataFrame with trifid_score and appris columns
Returns: Altair chart object
import pandas as pd

from trifid.visualization.figures import plot_trifid_appris

df = pd.read_csv('predictions.tsv.gz', sep='\t')
chart = plot_trifid_appris(df)
chart.save('trifid_vs_appris.html')
plot_appris_histogram()
Create histogram of TRIFID scores grouped by APPRIS labels.
df: DataFrame with predictions and APPRIS annotations
Returns: Altair chart object
from trifid.visualization.figures import plot_appris_histogram, cat_appris_order
# Prepare data with sorted APPRIS categories
df = cat_appris_order(df)
# Generate histogram
chart = plot_appris_histogram(df)
chart.save('appris_histogram.html')
plot_transcript_types_histogram()
Histogram of transcript type distribution.
df: DataFrame with transcript type flags
Returns: Altair chart object
from trifid.visualization.figures import plot_transcript_types_histogram, cat_transcript_type
# Categorize transcript types
df = cat_transcript_type(df)
# Plot distribution
chart = plot_transcript_types_histogram(df)
Feature Analysis
plot_feature_importances()
Visualize feature importance across multiple methods.
source: DataFrame with importance scores
xcol: column name for the x-axis (feature names)
ycol: column name for the y-axis (importance values)
facetcol: column used for faceting (importance method)
method: specific method to highlight
ntop: number of top features to display
Returns: Altair chart object
import pandas as pd

from trifid.visualization.figures import plot_feature_importances

# Create importance DataFrame (feature_names and shap_values defined elsewhere)
importance_df = pd.DataFrame({
    'feature': feature_names,
    'importance': shap_values,
    'method': 'SHAP',
})

chart = plot_feature_importances(
    source=importance_df,
    xcol='feature',
    ycol='importance',
    facetcol='method',
    method='SHAP',
    ntop=20,
)
plot_pulse_comparison()
Compare score distributions between two sets (e.g., TRIFID vs baseline).
df: DataFrame with scores and comparison labels
x_col (str, default: "trifid_score"): column name for score values
x_axis_name (str, default: "TRIFID score"): label for the x-axis
Returns: Altair chart object
from trifid.visualization.figures import plot_pulse_comparison
chart = plot_pulse_comparison(
    df=comparison_df,
    x_col='trifid_score',
    x_axis_name='Functional Score',
)
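If you need to assemble comparison_df yourself, one hypothetical construction looks like this (the 'set' grouping column name is an assumption; check the function signature in figures.py for the column it actually expects):

```python
import pandas as pd

# two score sets stacked into one frame, distinguished by a label column
comparison_df = pd.concat([
    pd.DataFrame({'trifid_score': [0.91, 0.72, 0.65], 'set': 'TRIFID'}),
    pd.DataFrame({'trifid_score': [0.55, 0.48, 0.40], 'set': 'baseline'}),
], ignore_index=True)
```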
Model Evaluation
plot_learning_curve()
Generate learning curve to assess model performance vs training size.
scoring (str, default: "matthews_corrcoef"): scoring metric
from trifid.visualization.figures import plot_learning_curve
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 6))
plot_learning_curve(
    model=rf_model,
    X=X_train,
    y=y_train,
    cv=5,
    scoring='matthews_corrcoef',
    n_jobs=-1,
    ax=ax,
    title='TRIFID Learning Curve',
)
plt.tight_layout()
plt.savefig('learning_curve.png', dpi=300)
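The default metric, the Matthews correlation coefficient, balances all four confusion-matrix cells, which makes it robust on imbalanced transcript labels. A minimal sketch of the binary-case formula (illustrative only; in practice use sklearn.metrics.matthews_corrcoef):

```python
import numpy as np

def mcc_binary(y_true, y_pred):
    """Matthews correlation coefficient for binary labels (toy sketch)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    # convention: return 0 when any confusion-matrix margin is empty
    return (tp * tn - fp * fn) / denom if denom else 0.0
```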
plot_validation_curve()
Plot validation curve for hyperparameter tuning.
param_range: range of parameter values to sweep
scoring (str, default: "matthews_corrcoef"): scoring metric
import matplotlib.pyplot as plt

from trifid.visualization.figures import plot_validation_curve

fig, ax = plt.subplots(figsize=(10, 6))
plot_validation_curve(
    model=rf_model,
    X=X_train,
    y=y_train,
    param_name='n_estimators',
    param_range=[50, 100, 200, 400, 600],
    cv=5,
    ax=ax,
    title='Validation Curve: n_estimators',
)
plot_prcurve()
Generate precision-recall curve with cross-validation.
import matplotlib.pyplot as plt

from trifid.visualization.figures import plot_prcurve

fig, ax = plt.subplots(figsize=(8, 8))
plot_prcurve(
    model=model,
    X=X_test,
    y=y_test,
    n_splits=10,  # number of CV folds (an integer, not the string '10')
    seed=42,
    ax=ax,
    title='Precision-Recall Curve',
)
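A precision-recall curve is built by sweeping a threshold over the predicted scores and recording precision and recall at each cut. A sketch of how a single point on that curve is computed (a toy helper, not the trifid implementation):

```python
import numpy as np

def pr_point(y_true, scores, threshold):
    """Precision and recall at one score threshold (illustrative sketch)."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(scores) >= threshold).astype(int)
    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall
```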
Utility Functions
cat_appris_order()
Sort APPRIS annotations in canonical order.
df: DataFrame with an appris column
Returns: DataFrame with appris_order column added
from trifid.visualization.figures import cat_appris_order
df = cat_appris_order(df)
# Now df has appris_order: 1-5 for PRINCIPAL, 6-7 for ALTERNATIVE, 8 for MINOR
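A hypothetical re-implementation of that ordering, assuming the standard APPRIS labels (PRINCIPAL:1-5, ALTERNATIVE:1-2, MINOR); the authoritative mapping lives in figures.py:

```python
import pandas as pd

# sketch of the canonical APPRIS ranking described above
APPRIS_ORDER = {
    'PRINCIPAL:1': 1, 'PRINCIPAL:2': 2, 'PRINCIPAL:3': 3,
    'PRINCIPAL:4': 4, 'PRINCIPAL:5': 5,
    'ALTERNATIVE:1': 6, 'ALTERNATIVE:2': 7,
    'MINOR': 8,
}

def cat_appris_order_sketch(df):
    """Add an appris_order column and sort by it (illustrative only)."""
    out = df.copy()
    out['appris_order'] = out['appris'].map(APPRIS_ORDER)
    return out.sort_values('appris_order')
```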
cat_transcript_type()
Categorize transcript type flags into readable labels.
df: DataFrame with a flags column
Returns: DataFrame with flags_mod column
from trifid.visualization.figures import cat_transcript_type
df = cat_transcript_type(df)
# Adds flags_mod: "Protein coding", "Nonsense mediated decay", etc.
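A sketch of how such a flag-to-label mapping might look (the exact flag values and labels here are assumptions; see figures.py for the real ones):

```python
import pandas as pd

# hypothetical mapping from raw flags to readable labels
FLAG_LABELS = {
    'protein_coding': 'Protein coding',
    'nonsense_mediated_decay': 'Nonsense mediated decay',
}

def cat_transcript_type_sketch(df):
    """Add a flags_mod column with readable labels (illustrative only)."""
    out = df.copy()
    out['flags_mod'] = out['flags'].map(FLAG_LABELS).fillna('Other')
    return out
```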
config_altair()
Apply standard Altair chart configuration.
Returns: Configured Altair chart
from trifid.visualization.figures import config_altair
import altair as alt
chart = alt.Chart(df).mark_bar().encode(x='feature', y='value')
chart = config_altair(chart, height=400, width=600)
create_categories()
Create feature categories for comparative visualization.
feature: feature name to categorize (see the example for the df_features, df_predictions, and cats arguments)
Returns: Categorized DataFrame
from trifid.visualization.figures import create_categories
categorized = create_categories(
    df_features=features,
    df_predictions=predictions,
    feature='RNA2sj',
    cats=[0, 0.25, 0.5, 0.75, 1.0],
)
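The cats list reads like bin edges. Binning a feature with pd.cut illustrates the idea (a sketch of the concept, not the actual create_categories implementation):

```python
import pandas as pd

# toy feature values binned into the four intervals defined by cats
scores = pd.Series([0.1, 0.3, 0.6, 0.9], name='RNA2sj')
bins = [0, 0.25, 0.5, 0.75, 1.0]
# include_lowest=True so a value of exactly 0 falls into the first bin
categories = pd.cut(scores, bins=bins, include_lowest=True)
```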
Complete Example
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from trifid.visualization.figures import (
explain_prediction,
plot_feature_importances,
plot_appris_histogram,
plot_learning_curve,
cat_appris_order
)
# Load data
predictions = pd.read_csv('trifid_predictions.tsv.gz', sep='\t')
with open('trifid_model.pkl', 'rb') as f:
    model = pickle.load(f)

# 1. Explain a specific prediction (feature_list defined elsewhere)
explain_prediction(
    df=predictions,
    model=model,
    features=feature_list,
    transcript_id='ENST00000356207',
)

# 2. Compare TRIFID vs APPRIS
predictions = cat_appris_order(predictions)
appris_chart = plot_appris_histogram(predictions)
appris_chart.save('appris_comparison.html')

# 3. Feature importance (importance_df built beforehand, as shown earlier)
importance_chart = plot_feature_importances(
    source=importance_df,
    xcol='feature',
    ycol='shap_value',
    facetcol='method',
    method='SHAP',
    ntop=15,
)
importance_chart.save('feature_importance.html')

# 4. Learning curve (X_train and y_train from model training)
fig, ax = plt.subplots(figsize=(10, 6))
plot_learning_curve(model, X_train, y_train, ax=ax)
plt.savefig('learning_curve.png', dpi=300)
Source Reference
Source: ~/workspace/source/trifid/visualization/figures.py
Related notebook: 02.figures.ipynb
See Also
Model Interpretation SHAP-based model interpretation functions
Interpreting Results Guide Step-by-step guide to result interpretation