Overview
The model.py module defines the core AlphaFold 3 model architecture, including the trunk (Evoformer), diffusion head, confidence head, and distogram head. It runs the complete forward pass and generates structure predictions.
Model Class
Main model class that orchestrates the full prediction pipeline.
Model configuration including Evoformer, heads, and global settings.
Module name for Haiku.
Configuration
Model.Config
Configuration for the Evoformer trunk (embedding module).
Global configuration including precision and attention settings.
Configuration for diffusion, confidence, and distogram heads.
Number of recycling iterations through the trunk.
Whether to return single and pair embeddings in output.
Whether to return distogram in output.
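The recycling count and the two return flags above decide which optional entries appear in the output dictionary. The sketch below is a hypothetical stand-in, not code from model.py. ToyModelConfig and assemble_output are illustrative names, the field name return_distogram and all output key names are assumptions, and the distogram bin count is arbitrary. It only shows how the flags can gate the extra outputs.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyModelConfig:
    """Hypothetical mirror of the scalar options in Model.Config."""

    num_recycles: int
    return_embeddings: bool = False
    return_distogram: bool = False  # field name assumed


def assemble_output(config, diffusion_samples, confidence, single, pair, distogram):
    """Build the result dict, adding optional entries only when requested."""
    output = {"diffusion_samples": diffusion_samples, **confidence}
    if config.return_embeddings:
        output["single_embeddings"] = single  # [num_tokens, 384]
        output["pair_embeddings"] = pair  # [num_tokens, num_tokens, 128]
    if config.return_distogram:
        output["distogram"] = distogram
    return output


num_tokens, num_samples, num_atoms = 4, 5, 30
single = np.zeros((num_tokens, 384), dtype=np.float32)
pair = np.zeros((num_tokens, num_tokens, 128), dtype=np.float32)
distogram = np.zeros((num_tokens, num_tokens, 64), dtype=np.float32)  # bin count arbitrary
samples = np.zeros((num_samples, num_atoms, 3), dtype=np.float32)
confidence = {"plddt": np.zeros((num_samples, num_atoms), dtype=np.float32)}

cfg = ToyModelConfig(num_recycles=3, return_embeddings=True)
out = assemble_output(cfg, samples, confidence, single, pair, distogram)
print(sorted(out))
# ['diffusion_samples', 'pair_embeddings', 'plddt', 'single_embeddings']
```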
Model.HeadsConfig
Forward Pass
Input batch containing featurized sequences, MSAs, and templates.
Random key for JAX operations. If None, uses hk.next_rng_key().
Dictionary containing diffusion samples, confidence metrics, and optionally embeddings/distogram.
ModelResult Structure
Contains the atom_positions array with predicted atom coordinates.
Distance histogram predictions and contact probabilities.
Per-atom predicted local distance difference test (pLDDT) scores.
Full predicted aligned error (PAE) matrix [num_samples, num_tokens, num_tokens].
Full predicted distance error (PDE) matrix.
TM-score adjusted PAE for global structure assessment.
TM-score adjusted PAE for interface assessment.
Single embeddings if return_embeddings=True [num_tokens, 384].
Pair embeddings if return_embeddings=True [num_tokens, num_tokens, 128].
Class Methods
get_inference_result
Data batch used for model inference.
Output dict from the model’s forward pass.
Target name to be saved within structure.
Yields one InferenceResult per diffusion sample containing predicted structure and confidence metrics.
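get_inference_result turns one batched output into one result per diffusion sample. Below is a minimal sketch of that pattern, which slices the leading sample axis inside a generator. ToySampleResult and iter_sample_results are hypothetical names. The real InferenceResult holds a structure object and metadata rather than these raw arrays.

```python
from dataclasses import dataclass
from typing import Iterator

import numpy as np


@dataclass
class ToySampleResult:
    """Hypothetical stand-in for InferenceResult: one diffusion sample."""

    sample_index: int
    atom_positions: np.ndarray  # [num_atoms, 3]
    plddt: np.ndarray  # [num_atoms]
    pae: np.ndarray  # [num_tokens, num_tokens]


def iter_sample_results(result: dict) -> Iterator[ToySampleResult]:
    """Yield one result per diffusion sample by slicing the leading axis."""
    num_samples = result["atom_positions"].shape[0]
    for i in range(num_samples):
        yield ToySampleResult(
            sample_index=i,
            atom_positions=result["atom_positions"][i],
            plddt=result["plddt"][i],
            pae=result["pae"][i],
        )


rng = np.random.default_rng(0)
toy_result = {  # hypothetical batched output: 5 samples, 20 atoms, 8 tokens
    "atom_positions": rng.normal(size=(5, 20, 3)),
    "plddt": rng.uniform(0, 100, size=(5, 20)),
    "pae": rng.uniform(0, 30, size=(5, 8, 8)),
}
for sample in iter_sample_results(toy_result):
    print(sample.sample_index, sample.atom_positions.shape, round(float(sample.plddt.mean()), 1))
```

A generator keeps only one sample's views alive at a time, which matches the "yields" behaviour described above.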
InferenceResult Class
Postprocessed model result containing the predicted structure and all confidence metrics.
Predicted protein structure with atom coordinates and B-factors.
Large numerical arrays like full PAE, PDE, and contact probabilities.
Confidence metrics and summary statistics (see Metadata Fields below).
Additional debugging information.
Model identifier from parameters.
Metadata Fields
The metadata dictionary contains the following confidence scores:
Primary ranking score combining pTM, ipTM, disorder, and clash penalties.
Predicted TM-score (pTM) measuring overall structure quality.
Interface predicted TM-score (ipTM) for multi-chain complexes.
Weighted average: 0.8 * ipTM + 0.2 * pTM.
Ranking confidence (equals ipTM for multi-chain, pTM for single chain).
Alternative ranking metric based on PAE.
Average predicted distance error across the structure.
Fraction of the structure predicted to be disordered.
Whether the structure has atomic clashes.
Mean PDE between chain pairs [num_chains, num_chains].
Minimum PDE between chain pairs [num_chains, num_chains].
Minimum PAE between chain pairs [num_chains, num_chains].
Interface pTM between chain pairs [num_chains, num_chains].
Average PDE within chains (intra-chain contacts).
Average PDE between chains (inter-chain contacts).
Per-chain PAE scores [num_chains].
Cross-chain PAE scores [num_chains].
Per-chain ipTM scores [num_chains].
Cross-chain ipTM scores [num_chains].
Chain IDs for each token.
Residue IDs for each token.
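The primary ranking score and the 0.8 * ipTM + 0.2 * pTM average listed above can be combined as in the sketch below. It rests on an assumption: the disorder weight (0.5) and clash penalty (100) are the values reported in the AlphaFold 3 paper and were not read from this module. ranking_score is a hypothetical helper. The single-chain fallback to pTM mirrors the ranking-confidence rule above.

```python
def ranking_score(ptm, iptm, fraction_disordered, has_clash, num_chains):
    """AF3-style ranking score; the weights are assumptions, not read from model.py."""
    # A single chain has no interface, so use pTM in place of ipTM.
    interface = iptm if num_chains > 1 else ptm
    confidence = 0.8 * interface + 0.2 * ptm  # the 0.8 * ipTM + 0.2 * pTM average
    # Add the disorder bonus and a large penalty for clashing structures.
    return confidence + 0.5 * fraction_disordered - 100.0 * float(has_clash)


print(round(ranking_score(0.7, 0.6, 0.1, False, num_chains=2), 4))  # 0.67
print(round(ranking_score(0.7, 0.6, 0.1, True, num_chains=2), 4))  # -99.33
```

The clash penalty is large enough that any clashing sample ranks below every non-clashing one.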
Numerical Data Fields
The numerical_data dictionary contains large arrays:
Full predicted distance error matrix [num_tokens, num_tokens].
Full predicted aligned error matrix [num_tokens, num_tokens].
Contact probability matrix [num_tokens, num_tokens].
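Chain-pair summaries such as the minimum PAE between chain pairs can be derived from the full PAE matrix and the per-token chain IDs. A minimal sketch, assuming the token chain IDs are aligned with the PAE rows and columns. chain_pair_pae_min is a hypothetical helper, not the module's implementation.

```python
import numpy as np


def chain_pair_pae_min(pae, token_chain_ids):
    """Minimum PAE over token pairs (i in chain a, j in chain b)."""
    chains = list(dict.fromkeys(token_chain_ids))  # unique IDs, first-seen order
    ids = np.asarray(token_chain_ids)
    out = np.zeros((len(chains), len(chains)), dtype=pae.dtype)
    for a, chain_a in enumerate(chains):
        rows = ids == chain_a
        for b, chain_b in enumerate(chains):
            cols = ids == chain_b
            # np.ix_ turns the two boolean masks into a rectangular sub-block.
            out[a, b] = pae[np.ix_(rows, cols)].min()
    return chains, out


pae = np.arange(16, dtype=float).reshape(4, 4)  # toy [num_tokens, num_tokens] matrix
chains, pae_min = chain_pair_pae_min(pae, ["A", "A", "B", "B"])
print(chains)  # ['A', 'B']
print(pae_min.tolist())  # [[0.0, 2.0], [8.0, 10.0]]
```

PAE is not symmetric, so entries [a, b] and [b, a] of the result can differ, as they do here.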
Helper Functions
get_predicted_structure
Model output in model-specific layout.
Model input batch for layout conversion.
Predicted structure with atom coordinates and B-factors.
create_target_feat_embedding
Input batch data.
Evoformer configuration.
Global model configuration.
Target feature embeddings [num_tokens, feature_dim].
Usage Examples
Basic Model Inference
Processing Results
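Processing results often includes saving each sample's metadata to disk. This sketch assumes the metadata values are NumPy scalars and arrays. It passes a default= hook to json.dumps so they serialize cleanly. The key names and values are illustrative only.

```python
import json

import numpy as np


def to_jsonable(value):
    """json.dumps fallback: convert NumPy arrays and scalars to plain Python."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Cannot serialize {type(value).__name__}")


# Hypothetical per-sample metadata; key names and values are illustrative only.
metadata = {
    "ptm": np.float32(0.81),
    "has_clash": np.bool_(False),
    "chain_pair_pae_min": np.array([[0.8, 3.2], [2.9, 0.7]]),
    "token_chain_ids": ["A", "A", "B"],
}
text = json.dumps(metadata, default=to_jsonable, indent=2)
print(text)
```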
Accessing Confidence Metrics
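The pTM and ipTM values in metadata summarize the PAE distribution through a TM-score transform. This is a minimal sketch. It assumes the pTM definition published for AlphaFold 2 also applies here: d0 = 1.24 (N - 15)^(1/3) - 1.8, with N clipped to at least 19. predicted_tm, tm_d0, and the 64-bin layout are hypothetical and are not the module's functions.

```python
import numpy as np


def tm_d0(num_tokens):
    """TM-score normalisation distance d0(N), with N clipped to at least 19."""
    n = max(num_tokens, 19)
    return 1.24 * (n - 15) ** (1.0 / 3.0) - 1.8


def predicted_tm(pae_logits, bin_centers, chain_ids=None):
    """pTM (or ipTM when chain_ids is given) from PAE logits [N, N, num_bins]."""
    n = pae_logits.shape[0]
    # Numerically stable softmax over the error bins.
    shifted = pae_logits - pae_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Expected TM term per (i, j): E_b[1 / (1 + (e_b / d0)^2)], a TM-adjusted PAE.
    tm_per_bin = 1.0 / (1.0 + (bin_centers / tm_d0(n)) ** 2)
    tm_adjusted = (probs * tm_per_bin).sum(axis=-1)  # [N, N]
    if chain_ids is None:
        mask = np.ones((n, n), dtype=bool)
    else:
        ids = np.asarray(chain_ids)
        mask = ids[:, None] != ids[None, :]  # keep only cross-chain pairs
    # Average over j for each aligned token i, then take the best i.
    per_row = (tm_adjusted * mask).sum(axis=-1) / np.maximum(mask.sum(axis=-1), 1)
    return float(per_row.max())


rng = np.random.default_rng(0)
logits = rng.normal(size=(12, 12, 64))
centers = np.linspace(0.25, 31.75, 64)  # assumed 0.5 Å bins up to 32 Å
chains = ["A"] * 6 + ["B"] * 6
print("pTM:", round(predicted_tm(logits, centers), 3))
print("ipTM:", round(predicted_tm(logits, centers, chains), 3))
```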
Multi-Sample Analysis
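Because get_inference_result yields one result per diffusion sample, choosing among samples reduces to sorting them by ranking score. Below is a minimal sketch with made-up scores. rank_samples is an illustrative helper, not part of model.py.

```python
import numpy as np


def rank_samples(scores):
    """Return sample indices sorted from highest to lowest ranking score."""
    scores = np.asarray(scores, dtype=float)
    # Negate for descending order; a stable sort keeps ties in sample order.
    return np.argsort(-scores, kind="stable").tolist()


scores = [0.71, 0.84, 0.66, 0.84, 0.79]  # hypothetical scores for five samples
order = rank_samples(scores)
print("ranking:", order)  # ranking: [1, 3, 4, 0, 2]
print("best sample:", order[0])  # best sample: 1
```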
With Embeddings
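With return_embeddings=True, the single embeddings have shape [num_tokens, 384] and the pair embeddings have shape [num_tokens, num_tokens, 128]. A common downstream step is reducing them to one fixed-size vector per structure. The masked mean pooling below is one illustrative choice, not something the module provides. The token mask, for example one that excludes padding, is an assumption.

```python
import numpy as np


def pool_single_embeddings(single, token_mask=None):
    """Mean-pool [num_tokens, 384] single embeddings into one 384-d vector."""
    if token_mask is None:
        token_mask = np.ones(single.shape[0], dtype=bool)
    weights = token_mask.astype(single.dtype)[:, None]  # [num_tokens, 1]
    # Sum only the unmasked tokens and divide by their count (at least 1).
    return (single * weights).sum(axis=0) / max(weights.sum(), 1.0)


rng = np.random.default_rng(0)
single = rng.normal(size=(10, 384)).astype(np.float32)
mask = np.arange(10) < 8  # hypothetical: treat the last two tokens as padding
vector = pool_single_embeddings(single, mask)
print(vector.shape)  # (384,)
```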
Architecture Overview
The Model consists of:
- Evoformer (Trunk): Processes MSA and creates single/pair embeddings through multiple recycling iterations
- Diffusion Head: Generates atom coordinates through denoising diffusion process
- Confidence Head: Predicts pLDDT, PAE, and PDE confidence metrics
- Distogram Head: Predicts distance histograms and contact probabilities
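Contact probabilities can be read off the distogram by summing the probability mass in bins below a distance cutoff. This sketch assumes an 8 Å cutoff, a common contact definition, and evenly spaced bin edges. The module's actual bins and cutoff may differ, and contact_probs_from_distogram is a hypothetical helper.

```python
import numpy as np


def contact_probs_from_distogram(logits, bin_edges, cutoff=8.0):
    """P(distance < cutoff) per token pair from distogram logits [N, N, num_bins].

    bin_edges has num_bins + 1 entries; a bin counts toward a contact when
    its upper edge is at or below the cutoff.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)  # softmax over bins
    in_contact = bin_edges[1:] <= cutoff  # [num_bins] boolean
    return (probs * in_contact).sum(axis=-1)  # [N, N]


edges = np.linspace(2.0, 22.0, 11)  # assumed: 10 bins, 2 Å wide
logits = np.random.default_rng(0).normal(size=(5, 5, 10))
contacts = contact_probs_from_distogram(logits, edges)
print(contacts.shape)  # (5, 5)
```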
Forward Pass Flow
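The flow implied by the architecture overview is a recycled trunk followed by the diffusion, confidence, and distogram heads. It can be sketched with toy stand-ins. Everything here is hypothetical: the layer maths are placeholders with tiny channel sizes. The loop also assumes that num_recycles counts extra trunk passes after the first, so the trunk runs num_recycles + 1 times.

```python
import numpy as np

NUM_TOKENS, C_SINGLE, C_PAIR = 8, 4, 2  # toy sizes, not the real 384 / 128


def toy_trunk(features, prev_single, prev_pair):
    """Stand-in for one Evoformer pass that also consumes recycled embeddings."""
    single = np.tanh(features + 0.5 * prev_single)
    outer = single[:, None, :C_PAIR] + single[None, :, :C_PAIR]  # [N, N, C_PAIR]
    pair = np.tanh(outer + 0.5 * prev_pair)
    return single, pair


def run_forward(features, num_recycles, num_samples, seed=0):
    rng = np.random.default_rng(seed)
    single = np.zeros((NUM_TOKENS, C_SINGLE))
    pair = np.zeros((NUM_TOKENS, NUM_TOKENS, C_PAIR))
    trunk_passes = 0
    # 1) Trunk: one initial pass plus num_recycles recycling passes.
    for _ in range(num_recycles + 1):
        single, pair = toy_trunk(features, single, pair)
        trunk_passes += 1
    # 2) Diffusion head stand-in: one coordinate set per sample.
    coords = rng.normal(size=(num_samples, NUM_TOKENS, 3)) + single[None, :, :3]
    # 3) Confidence head stand-in: a per-token score computed from each
    #    sample's coordinates (distance to that sample's centroid).
    spread = np.linalg.norm(coords - coords.mean(axis=1, keepdims=True), axis=-1)
    plddt = 100.0 / (1.0 + spread)
    # 4) Distogram head stand-in: project pair features onto 16 bins.
    distogram_logits = pair @ np.ones((C_PAIR, 16))
    return {
        "atom_positions": coords,
        "plddt": plddt,
        "distogram_logits": distogram_logits,
        "trunk_passes": trunk_passes,
    }


features = np.random.default_rng(1).normal(size=(NUM_TOKENS, C_SINGLE))
out = run_forward(features, num_recycles=3, num_samples=5)
print(out["trunk_passes"], out["atom_positions"].shape, out["distogram_logits"].shape)
# 4 (5, 8, 3) (8, 8, 16)
```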
Performance Considerations
- Recycling: More recycles (10-20) improve quality but increase compute time
- Diffusion Samples: More samples (5-10) provide better coverage but are slower
- Embeddings: Enabling embeddings significantly increases memory usage
- Flash Attention: Use triton or cudnn for best performance on NVIDIA Ampere or newer GPUs
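The memory cost of embeddings follows from the shapes listed under ModelResult: the pair tensor grows with the square of the token count. A back-of-the-envelope sketch, assuming float32 storage (4 bytes per value). It counts only the returned single and pair arrays, not intermediate activations. embedding_bytes is a hypothetical helper.

```python
def embedding_bytes(num_tokens, c_single=384, c_pair=128, bytes_per_value=4):
    """Bytes to hold single [N, 384] plus pair [N, N, 128] embeddings."""
    single = num_tokens * c_single * bytes_per_value
    pair = num_tokens * num_tokens * c_pair * bytes_per_value
    return single + pair


for n in (500, 2000, 5000):
    print(f"{n:>5} tokens: {embedding_bytes(n) / 1e9:.2f} GB (float32)")
#   500 tokens: 0.13 GB (float32)
#  2000 tokens: 2.05 GB (float32)
#  5000 tokens: 12.81 GB (float32)
```

Halving bytes_per_value (for example, for a 16-bit dtype) halves these figures, but the quadratic growth in the pair term remains.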
See Also
- run_alphafold.py - Main prediction script
- Input Dataclass - Input format specification
- DataPipeline - MSA and template processing