
Overview

The trifid.visualization.figures module provides plotting functions for visualizing TRIFID predictions, feature importance, and model performance. This is an optional module that requires visualization dependencies.
Install visualization dependencies with: pip install .[extra]

Installation

# Install TRIFID from GitHub
pip install git+https://github.com/fpozoc/trifid.git

# Or, from a local checkout, install with visualization support
pip install .[extra]
Required packages: altair, matplotlib, seaborn, shap, yellowbrick
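Because this module is optional, import errors can surface deep inside a plotting call. A small helper (not part of TRIFID, just a sketch) can check up front which of the listed packages are importable:

```python
import importlib.util

def missing_viz_deps(packages=('altair', 'matplotlib', 'seaborn', 'shap', 'yellowbrick')):
    """Return the subset of visualization packages that are not importable."""
    return [pkg for pkg in packages if importlib.util.find_spec(pkg) is None]

# Warn early instead of failing inside a plotting function
missing = missing_viz_deps()
if missing:
    print(f"Missing visualization dependencies: {', '.join(missing)} "
          f"(install with: pip install .[extra])")
```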

Prediction Visualization

explain_prediction()

Generate SHAP force plot explanation for a specific transcript prediction.
Parameters:
  • df (pd.DataFrame, required): DataFrame containing features and predictions
  • model (object, required): Trained TRIFID model object
  • features (list, required): List of feature names to use in the explanation
  • transcript_id (str, required): Transcript ID to explain (e.g. 'ENST00000356207')
Returns: pd.DataFrame with SHAP values
Example:
import pandas as pd
import pickle
from trifid.visualization.figures import explain_prediction

# Load model and data
with open('path/to/model.pkl', 'rb') as f:
    model = pickle.load(f)
df = pd.read_csv('predictions.tsv.gz', sep='\t')

# Explain prediction for specific transcript
explain_prediction(
    df=df,
    model=model,
    features=feature_list,
    transcript_id='ENST00000356207'
)
This generates a SHAP force plot showing:
  • Base SHAP value (model’s expected value)
  • Feature contributions pushing prediction higher (orange)
  • Feature contributions pushing prediction lower (blue)
  • Final predicted score
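The force plot visualizes SHAP's additive property: the base value plus all per-feature contributions equals the final predicted score. A minimal numeric sketch (feature names and values are purely illustrative):

```python
# Hypothetical SHAP decomposition for one transcript (illustrative values only)
base_value = 0.42                 # model's expected prediction over the dataset
contributions = {
    'feature_a': +0.21,           # pushes the score higher (orange)
    'feature_b': -0.08,           # pushes the score lower (blue)
    'feature_c': +0.15,
}

# Additivity: base value + sum of contributions = final predicted score
predicted_score = base_value + sum(contributions.values())
print(round(predicted_score, 2))  # 0.7
```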

plot_trifid_appris()

Compare TRIFID scores against APPRIS annotations.
Parameters:
  • df (pd.DataFrame, required): DataFrame with trifid_score and appris columns
Returns: Altair chart object
import pandas as pd
from trifid.visualization.figures import plot_trifid_appris

df = pd.read_csv('predictions.tsv.gz', sep='\t')
chart = plot_trifid_appris(df)
chart.save('trifid_vs_appris.html')

plot_appris_histogram()

Create histogram of TRIFID scores grouped by APPRIS labels.
Parameters:
  • df (pd.DataFrame, required): DataFrame with predictions and APPRIS annotations
Returns: Altair chart object
from trifid.visualization.figures import plot_appris_histogram, cat_appris_order

# Prepare data with sorted APPRIS categories
df = cat_appris_order(df)

# Generate histogram
chart = plot_appris_histogram(df)
chart.save('appris_histogram.html')

plot_transcript_types_histogram()

Histogram of transcript type distribution.
Parameters:
  • df (pd.DataFrame, required): DataFrame with transcript type flags
Returns: Altair chart object
from trifid.visualization.figures import plot_transcript_types_histogram, cat_transcript_type

# Categorize transcript types
df = cat_transcript_type(df)

# Plot distribution
chart = plot_transcript_types_histogram(df)

Feature Analysis

plot_feature_importances()

Visualize feature importance across multiple methods.
Parameters:
  • source (pd.DataFrame, required): DataFrame with importance scores
  • xcol (str, required): Column name for the x-axis (feature names)
  • ycol (str, required): Column name for the y-axis (importance values)
  • facetcol (str, required): Column used for faceting (importance method)
  • method (str, required): Specific method to highlight
  • ntop (int, default 18): Number of top features to display
Returns: Altair chart object
import pandas as pd
from trifid.visualization.figures import plot_feature_importances

# Create importance DataFrame (feature_names and shap_values come from
# your own feature-importance analysis)
importance_df = pd.DataFrame({
    'feature': feature_names,
    'importance': shap_values,
    'method': 'SHAP'
})

chart = plot_feature_importances(
    source=importance_df,
    xcol='feature',
    ycol='importance',
    facetcol='method',
    method='SHAP',
    ntop=20
)

plot_pulse_comparison()

Compare score distributions between two sets (e.g., TRIFID vs baseline).
Parameters:
  • df (pd.DataFrame, required): DataFrame with scores and comparison labels
  • x_col (str, default "trifid_score"): Column name for score values
  • x_axis_name (str, default "TRIFID score"): Label for the x-axis
Returns: Altair chart object
from trifid.visualization.figures import plot_pulse_comparison

chart = plot_pulse_comparison(
    df=comparison_df,
    x_col='trifid_score',
    x_axis_name='Functional Score'
)

Model Evaluation

plot_learning_curve()

Generate learning curve to assess model performance vs training size.
Parameters:
  • model (object, required): Sklearn-compatible model
  • X (array-like, required): Feature matrix
  • y (array-like, required): Target labels
  • cv (int, default 5): Cross-validation folds
  • scoring (str, default "matthews_corrcoef"): Scoring metric
  • n_jobs (int, default -1): Number of parallel jobs
  • ax (matplotlib.Axes, required): Matplotlib axes object
  • title (str, default None): Plot title
from trifid.visualization.figures import plot_learning_curve
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))

plot_learning_curve(
    model=rf_model,
    X=X_train,
    y=y_train,
    cv=5,
    scoring='matthews_corrcoef',
    n_jobs=-1,
    ax=ax,
    title='TRIFID Learning Curve'
)

plt.tight_layout()
plt.savefig('learning_curve.png', dpi=300)
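Both evaluation helpers default to scoring='matthews_corrcoef'. For reference, the Matthews correlation coefficient computed directly from 2x2 confusion-matrix counts (a plain-Python sketch, equivalent to sklearn.metrics.matthews_corrcoef for binary labels):

```python
import math

def mcc(tp, fp, tn, fn):
    """Matthews correlation coefficient from confusion-matrix counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

print(mcc(tp=50, fp=0, tn=50, fn=0))    # 1.0 (perfect classifier)
print(mcc(tp=25, fp=25, tn=25, fn=25))  # 0.0 (no better than chance)
```

MCC is well suited to TRIFID's imbalanced functional/non-functional labels because it accounts for all four confusion-matrix cells.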

plot_validation_curve()

Plot validation curve for hyperparameter tuning.
Parameters:
  • model (object, required): Sklearn model
  • X (array-like, required): Feature matrix
  • y (array-like, required): Target labels
  • param_name (str, required): Parameter name to vary
  • param_range (array-like, required): Range of parameter values
  • cv (int, default 5): Cross-validation splits
  • scoring (str, default "matthews_corrcoef"): Scoring metric
  • n_jobs (int, default -1): Number of parallel jobs
  • ax (matplotlib.Axes, required): Matplotlib axes object
  • title (str, default None): Plot title
import matplotlib.pyplot as plt
from trifid.visualization.figures import plot_validation_curve

fig, ax = plt.subplots(figsize=(10, 6))

plot_validation_curve(
    model=rf_model,
    X=X_train,
    y=y_train,
    param_name='n_estimators',
    param_range=[50, 100, 200, 400, 600],
    cv=5,
    ax=ax,
    title='Validation Curve: n_estimators'
)

plot_prcurve()

Generate precision-recall curve with cross-validation.
Parameters:
  • model (object, required): Trained model
  • X (array-like, required): Feature matrix
  • y (array-like, required): True labels
  • n_splits (int, required): Number of CV splits
  • seed (int, required): Random seed
  • ax (matplotlib.Axes, required): Axes for plotting
  • title (str, default None): Plot title
import matplotlib.pyplot as plt
from trifid.visualization.figures import plot_prcurve

fig, ax = plt.subplots(figsize=(8, 8))

plot_prcurve(
    model=model,
    X=X_test,
    y=y_test,
    n_splits=10,
    seed=42,
    ax=ax,
    title='Precision-Recall Curve'
)
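As a reminder of what the curve traces, precision and recall can be computed directly from confusion-matrix counts; raising the score threshold typically trades recall for precision:

```python
def precision_recall(tp, fp, fn):
    """Precision and recall from confusion-matrix counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

# Illustrative counts: 80 true positives, 20 false positives, 40 false negatives
print(precision_recall(tp=80, fp=20, fn=40))
```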

Utility Functions

cat_appris_order()

Sort APPRIS annotations in canonical order.
Parameters:
  • df (pd.DataFrame, required): DataFrame with an appris column
Returns: DataFrame with an appris_order column added
from trifid.visualization.figures import cat_appris_order

df = cat_appris_order(df)
# Now df has appris_order: 1-5 for PRINCIPAL, 6-7 for ALTERNATIVE, 8 for MINOR
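Under the hood this amounts to mapping APPRIS labels onto an ordinal scale. A minimal sketch of the idea (the exact label strings are assumptions based on the ordering described above, not TRIFID's implementation):

```python
import pandas as pd

# Illustrative ordering: PRINCIPAL isoforms first, then ALTERNATIVE, then MINOR
appris_rank = {
    'PRINCIPAL:1': 1, 'PRINCIPAL:2': 2, 'PRINCIPAL:3': 3,
    'PRINCIPAL:4': 4, 'PRINCIPAL:5': 5,
    'ALTERNATIVE:1': 6, 'ALTERNATIVE:2': 7,
    'MINOR': 8,
}

df = pd.DataFrame({'appris': ['MINOR', 'PRINCIPAL:1', 'ALTERNATIVE:2']})
df['appris_order'] = df['appris'].map(appris_rank)
print(df.sort_values('appris_order')['appris'].tolist())
# ['PRINCIPAL:1', 'ALTERNATIVE:2', 'MINOR']
```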

cat_transcript_type()

Categorize transcript type flags into readable labels.
Parameters:
  • df (pd.DataFrame, required): DataFrame with a flags column
Returns: DataFrame with a flags_mod column added
from trifid.visualization.figures import cat_transcript_type

df = cat_transcript_type(df)
# Adds flags_mod: "Protein coding", "Nonsense mediated decay", etc.
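Conceptually this is a lookup from raw biotype flags to readable labels. A sketch under the assumption that flags holds GENCODE-style biotype strings (the exact flag values are assumptions):

```python
import pandas as pd

# Hypothetical flag-to-label mapping for illustration
flag_labels = {
    'protein_coding': 'Protein coding',
    'nonsense_mediated_decay': 'Nonsense mediated decay',
}

df = pd.DataFrame({'flags': ['protein_coding', 'nonsense_mediated_decay']})
df['flags_mod'] = df['flags'].map(flag_labels).fillna('Other')
print(df['flags_mod'].tolist())  # ['Protein coding', 'Nonsense mediated decay']
```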

config_altair()

Apply standard Altair chart configuration.
Parameters:
  • chart (altair.Chart, required): Altair chart object
  • height (int, required): Chart height in pixels
  • width (int, required): Chart width in pixels
Returns: Configured Altair chart
from trifid.visualization.figures import config_altair
import altair as alt

chart = alt.Chart(df).mark_bar().encode(x='feature', y='value')
chart = config_altair(chart, height=400, width=600)

create_categories()

Create feature categories for comparative visualization.
Parameters:
  • df_features (pd.DataFrame, required): Features DataFrame
  • df_predictions (pd.DataFrame, required): Predictions DataFrame
  • feature (str, required): Feature name to categorize
  • cats (list, required): Category boundaries
Returns: Categorized DataFrame
from trifid.visualization.figures import create_categories

categorized = create_categories(
    df_features=features,
    df_predictions=predictions,
    feature='RNA2sj',
    cats=[0, 0.25, 0.5, 0.75, 1.0]
)
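With boundaries like [0, 0.25, 0.5, 0.75, 1.0], categorization amounts to binning the feature values into intervals. The equivalent standalone pandas operation (a sketch of the idea, not TRIFID's implementation; the label names are made up):

```python
import pandas as pd

# Bin feature values into the four intervals defined by the boundaries
scores = pd.Series([0.1, 0.3, 0.6, 0.9], name='RNA2sj')
bins = pd.cut(scores,
              bins=[0, 0.25, 0.5, 0.75, 1.0],
              labels=['low', 'mid-low', 'mid-high', 'high'],
              include_lowest=True)
print(bins.tolist())  # ['low', 'mid-low', 'mid-high', 'high']
```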

Complete Example

import pandas as pd
import pickle
import matplotlib.pyplot as plt
from trifid.visualization.figures import (
    explain_prediction,
    plot_feature_importances,
    plot_appris_histogram,
    plot_learning_curve,
    cat_appris_order
)

# Load data (feature_list, importance_df, X_train and y_train are assumed
# to be defined, e.g. from your training pipeline)
predictions = pd.read_csv('trifid_predictions.tsv.gz', sep='\t')
with open('trifid_model.pkl', 'rb') as f:
    model = pickle.load(f)

# 1. Explain a specific prediction
explain_prediction(
    df=predictions,
    model=model,
    features=feature_list,
    transcript_id='ENST00000356207'
)

# 2. Compare TRIFID vs APPRIS
predictions = cat_appris_order(predictions)
appris_chart = plot_appris_histogram(predictions)
appris_chart.save('appris_comparison.html')

# 3. Feature importance
importance_chart = plot_feature_importances(
    source=importance_df,
    xcol='feature',
    ycol='shap_value',
    facetcol='method',
    method='SHAP',
    ntop=15
)
importance_chart.save('feature_importance.html')

# 4. Learning curve
fig, ax = plt.subplots(figsize=(10, 6))
plot_learning_curve(model, X_train, y_train, ax=ax)
plt.savefig('learning_curve.png', dpi=300)

Source Reference

Source: ~/workspace/source/trifid/visualization/figures.py
Related notebook: 02.figures.ipynb

See Also

Model Interpretation

SHAP-based model interpretation functions

Interpreting Results Guide

Step-by-step guide to result interpretation
