Architecture Overview
AlphaFold 3 employs a sophisticated neural network architecture that combines evolutionary reasoning, structural attention mechanisms, and diffusion-based coordinate generation. The architecture consists of five major modules operating in sequence. It is implemented in JAX and Haiku, enabling efficient GPU computation and automatic differentiation.
High-Level Architecture
Module 1: Input Embedding
Purpose
Transform raw features into learned representations suitable for the Evoformer trunk.
Token Features
Input features are first embedded into continuous representations.
Residue Type Embedding
Each token’s residue/ligand type is embedded. Handles:
- 20 standard amino acids + unknown
- 4 RNA bases (A, C, G, U)
- 4 DNA bases (A, C, G, T)
- Ligand type indicators
- Modified residues via CCD codes
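The residue type embedding is, at its core, a table lookup. The sketch below illustrates this with a hypothetical vocabulary and randomly initialized table; the actual vocabulary layout, token names, and embedding width in AlphaFold 3 may differ.

```python
import numpy as np

# Hypothetical vocabulary: 20 amino acids + unknown, 4 RNA bases, 4 DNA bases,
# plus a generic ligand indicator. Names and ordering here are illustrative.
VOCAB = (
    list("ACDEFGHIKLMNPQRSTVWY") + ["UNK"]   # amino acids + unknown
    + ["rA", "rC", "rG", "rU"]               # RNA bases
    + ["dA", "dC", "dG", "dT"]               # DNA bases
    + ["LIG"]                                # generic ligand token
)
EMBED_DIM = 384  # single-representation width used elsewhere in this document

rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(len(VOCAB), EMBED_DIM)).astype(np.float32)

def embed_tokens(tokens):
    """Map token type strings to learned vectors via table lookup."""
    idx = [VOCAB.index(t) for t in tokens]
    return embedding_table[idx]  # shape [num_tokens, EMBED_DIM]

single = embed_tokens(["M", "K", "rA", "LIG"])
```

In the trained model the table entries are learned parameters rather than random draws; modified residues resolve to their CCD codes before lookup.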
Positional Encoding
Encodes token position and chain information. Features:
- Relative token distance (clipped to ±32)
- Same chain vs different chain
- Chain separation distance
- N-terminal/C-terminal indicators
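The first two features above can be sketched as follows: a one-hot encoding of the clipped relative distance plus a same-chain flag. The cross-chain bin and exact feature layout are illustrative assumptions; terminal indicators and chain separation are omitted.

```python
import numpy as np

def relative_position_features(residue_index, chain_id, clip=32):
    """One-hot relative token distance (clipped to +/-clip) plus a same-chain flag."""
    ri = np.asarray(residue_index)
    cid = np.asarray(chain_id)
    offset = ri[None, :] - ri[:, None]            # signed token distance
    offset = np.clip(offset, -clip, clip) + clip  # map to [0, 2*clip]
    same_chain = cid[:, None] == cid[None, :]
    # Cross-chain pairs get a dedicated "no relative position" bin (index 2*clip+1).
    offset = np.where(same_chain, offset, 2 * clip + 1)
    onehot = np.eye(2 * clip + 2)[offset]         # [n, n, 2*clip+2]
    return np.concatenate([onehot, same_chain[..., None].astype(float)], axis=-1)

feats = relative_position_features([0, 1, 2, 0, 1], ["A", "A", "A", "B", "B"])
```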
Chemical Features
For ligands and modified residues, chemical features are derived from:
- Chemical Component Dictionary (CCD)
- RDKit molecular descriptors
- Custom chemical annotations
Pair Features
Pairwise relationships are initialized from token features into a [num_tokens, num_tokens, 128] pair representation, built from:
- Outer product of single representations
- Relative positional encodings
- Distance and orientation features (from templates)
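The outer-product initialization can be sketched as below: each token is projected to "left" and "right" vectors, and pair (i, j) combines token i's left projection with token j's right projection. The projection weights here are random stand-ins for learned parameters, and the additive combination is one common variant.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, single_dim, pair_dim = 6, 384, 128

single = rng.normal(size=(num_tokens, single_dim)).astype(np.float32)

# Hypothetical learned projections; in the real model these are trained weights.
w_left = rng.normal(size=(single_dim, pair_dim)).astype(np.float32) * 0.02
w_right = rng.normal(size=(single_dim, pair_dim)).astype(np.float32) * 0.02

left = single @ w_left    # [num_tokens, pair_dim]
right = single @ w_right  # [num_tokens, pair_dim]

# pair[i, j] combines token i's left projection with token j's right projection.
pair = left[:, None, :] + right[None, :, :]  # [num_tokens, num_tokens, 128]
```

Relative positional encodings and template-derived features are then added to this tensor channel-wise.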
Module 2: Evoformer Trunk
Purpose
Process MSA and template information to build rich single and pair representations capturing evolutionary and structural patterns.
Architecture Configuration
MSA Stack
Processes multiple sequence alignments through attention layers:
MSA Row Attention
Attention across positions within each sequence. Captures positional dependencies and co-evolution.
MSA Column Attention
Attention across sequences at each position. Learns which sequences are informative for each position.
The MSA stack typically runs for 4 blocks, extracting progressively refined evolutionary features.
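The two attention axes can be sketched as follows. To stay neutral on row/column naming, the functions are named by the axis they attend over; learned projections, multiple heads, gating, and pair bias are all omitted, so this only shows which dimension of the MSA tensor each variant mixes.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend_within_sequence(msa):
    """Attend across positions within each sequence of the MSA."""
    # msa: [num_seqs, num_pos, dim]; features act as their own q/k/v here.
    logits = np.einsum("spd,sqd->spq", msa, msa) / np.sqrt(msa.shape[-1])
    return np.einsum("spq,sqd->spd", softmax(logits), msa)

def attend_across_sequences(msa):
    """Attend across sequences at each alignment position."""
    logits = np.einsum("spd,tpd->pst", msa, msa) / np.sqrt(msa.shape[-1])
    return np.einsum("pst,tpd->spd", softmax(logits), msa)

rng = np.random.default_rng(0)
msa = rng.normal(size=(8, 10, 16))  # 8 sequences, 10 positions, dim 16
within_out = attend_within_sequence(msa)
across_out = attend_across_sequences(msa)
```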
Template Processing
Structural templates are integrated via specialized modules:
Template Pair Features
Generate pairwise features from template structures. Features computed from template coordinates:
- Distances: Cα-Cα distances between residues
- Orientations: Relative backbone orientations
- Backbone angles: ψ, φ, ω dihedral angles
- Unit vectors: Normalized direction vectors
Output shape: [num_templates, num_tokens, num_tokens, template_pair_dim]
Template Point Attention
Aggregate template information via an attention mechanism. The attention weights determine which templates are most relevant for each token pair.
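The distance component of the template pair features above can be sketched as a binned distogram over Cα-Cα distances. The bin edges below are illustrative, not AlphaFold 3's actual discretization.

```python
import numpy as np

def template_distance_features(ca_coords, bins=np.linspace(2.0, 22.0, 15)):
    """Bin Ca-Ca distances from one template into one-hot distogram features.

    ca_coords: [num_tokens, 3] template Ca positions (angstroms).
    """
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)   # [n, n] pairwise distances
    bin_idx = np.searchsorted(bins, dist)  # which distance bin each pair falls in
    return np.eye(len(bins) + 1)[bin_idx]  # [n, n, num_bins + 1] one-hot

rng = np.random.default_rng(0)
coords = rng.normal(size=(5, 3)) * 10.0
feats = template_distance_features(coords)
```

Orientation and dihedral features are computed analogously from the backbone atoms and concatenated along the channel axis.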
Single and Pair Updates
The single representation is maintained alongside the pair representation:
Triangle Multiplicative Update
Propagates information along triangle edges. This captures transitive relationships: if i-k and k-j are related, then i-j is likely related.
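The core of the update can be sketched as below for the "outgoing" edge direction: edge (i, j) is rebuilt from products of projected edges (i, k) and (j, k), summed over all intermediate tokens k. The gating, layer norms, and the symmetric "incoming" variant of the real module are omitted, and the weights are random stand-ins.

```python
import numpy as np

def triangle_multiply_outgoing(pair, w_a, w_b):
    """Sketch of the triangle multiplicative update (outgoing edges).

    pair: [n, n, c]. Edge (i, j) is updated from products of projected
    edges (i, k) and (j, k), summed over intermediate tokens k.
    """
    a = pair @ w_a  # [n, n, c] projected "left" edges
    b = pair @ w_b  # [n, n, c] projected "right" edges
    return np.einsum("ikc,jkc->ijc", a, b)

rng = np.random.default_rng(0)
n, c = 4, 8
pair = rng.normal(size=(n, n, c))
w_a = rng.normal(size=(c, c)) * 0.1
w_b = rng.normal(size=(c, c)) * 0.1
updated = triangle_multiply_outgoing(pair, w_a, w_b)
```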
Triangle Self-Attention
Attention along triangle edges. Aggregates information from all triangles involving edge (i, j).
Module 3: Pairformer
Purpose
Deep refinement of single and pair representations through 48 transformer-like blocks, without MSA.
Architecture
Pairformer Iteration
Each of the 48 blocks performs:
Single Representation Update
Self-attention over tokens with pair bias. The pair representation biases the attention weights, incorporating pairwise information.
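Pair-biased attention can be sketched as standard dot-product attention with an extra per-pair term added to the logits. This single-head version uses the features as their own queries/keys/values; in the real model the bias comes from a learned projection of the pair representation, per head.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_with_pair_bias(single, pair_bias):
    """Self-attention over tokens where the pair representation biases the logits.

    single: [n, d]; pair_bias: [n, n] (one scalar per token pair -- a
    single-head sketch of the projected pair representation).
    """
    logits = single @ single.T / np.sqrt(single.shape[-1])  # [n, n]
    logits = logits + pair_bias        # pair repr shifts attention weights
    weights = softmax(logits, axis=-1)
    return weights @ single

rng = np.random.default_rng(0)
single = rng.normal(size=(6, 16))
pair_bias = rng.normal(size=(6, 6))
out = attention_with_pair_bias(single, pair_bias)
```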
Computational Cost: The 48 Pairformer layers are the most computationally intensive part of the network. For large proteins, these layers dominate inference time.
Memory optimization: Block rematerialization (block_remat=True) recomputes activations during the backward pass to save memory.
Per-Atom Conditioning
After the Pairformer, per-atom conditioning refines representations:
- Expands token representations to atom-level
- Runs transformer over atoms within each token
- Cross-attends between atoms and tokens
- Prepares atom-level features for diffusion module
Module 4: Diffusion Module
Purpose
Generate 3D atomic coordinates through a learned reverse diffusion process.
Diffusion Framework
AlphaFold 3 uses a score-based diffusion model.
Sampling Process
Iterative Denoising
Apply learned denoising steps. Default: 200 denoising steps from σ=160 to σ=0.0004.
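The sampling loop can be sketched as follows. The σ endpoints match the defaults quoted above, but the geometric spacing and the simple Euler-style update are illustrative assumptions, not AlphaFold 3's exact sampler; the toy denoiser stands in for the trained network.

```python
import numpy as np

def noise_schedule(num_steps=200, sigma_max=160.0, sigma_min=4e-4):
    """Geometric noise schedule from sigma_max down to sigma_min."""
    return np.geomspace(sigma_max, sigma_min, num_steps)

def sample(denoise_fn, num_atoms, num_steps=200, seed=0):
    """Iteratively denoise random coordinates toward a structure."""
    rng = np.random.default_rng(seed)
    sigmas = noise_schedule(num_steps)
    x = rng.normal(size=(num_atoms, 3)) * sigmas[0]  # start from pure noise
    for sigma_t, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x_denoised = denoise_fn(x, sigma_t)          # network's clean estimate
        # Move partway toward the denoised estimate (Euler-style step).
        x = x_denoised + (sigma_next / sigma_t) * (x - x_denoised)
    return x

# Toy "denoiser" that pulls coordinates toward the origin.
coords = sample(lambda x, sigma: x * 0.5, num_atoms=10)
```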
Diffusion Transformer
The core denoising network:
Adaptive Layer Normalization
Conditioning mechanism using single representations. Allows the network to modulate processing based on sequence and evolutionary context.
Transformer Blocks
Each block contains Gated Linear Units (GLU). GLU improves gradient flow and model capacity.
Noise Level Conditioning
Diffusion time embedding. Informs the network about the current noise level, guiding denoising strength.
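A common way to embed the noise level is a sinusoidal encoding of log σ, sketched below; the exact featurization AlphaFold 3 uses may differ, and the frequency range here is an illustrative choice.

```python
import numpy as np

def noise_level_embedding(sigma, dim=32):
    """Sinusoidal embedding of the log noise level (diffusion-time encoding)."""
    freqs = np.exp(np.linspace(0.0, 4.0, dim // 2))  # geometric frequency ladder
    angles = np.log(sigma) * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

emb_high = noise_level_embedding(160.0)   # start of sampling
emb_low = noise_level_embedding(4e-4)     # end of sampling
```

The embedding is then mixed into the transformer's conditioning signal, so early (noisy) and late (nearly clean) steps get distinct treatment.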
Random Augmentation
During training, random rigid-body transformations are applied.
Diffusion Hyperparameters
- More steps → better quality, slower
- More samples → more diversity, slower
- Higher noise_scale → more stochastic, more diverse
Module 5: Confidence Head
Purpose
Predict confidence metrics for the structure without generating coordinates.
Architecture
Predicted Metrics
pLDDT (predicted LDDT)
Per-atom local distance difference test. Binned prediction:
- 50 bins from 0 to 100
- Softmax over bins
- Expected value = predicted pLDDT
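The expected-value computation over the 50 bins can be sketched directly; the bin-center convention below is an assumption, but the softmax-then-expectation structure follows the description above.

```python
import numpy as np

def plddt_from_logits(logits):
    """Expected pLDDT from binned logits: softmax over 50 bins spanning 0-100,
    then the probability-weighted average of the bin centers."""
    num_bins = logits.shape[-1]
    centers = (np.arange(num_bins) + 0.5) * (100.0 / num_bins)  # bin midpoints
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)                   # softmax
    return probs @ centers                                      # per-atom expectation

rng = np.random.default_rng(0)
logits = rng.normal(size=(7, 50))  # 7 atoms, 50 bins
plddt = plddt_from_logits(logits)
```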
PAE (Predicted Aligned Error)
Pairwise error estimation. Interpretation:
- PAE[i,j] = predicted error in position of token j when aligned on token i
- Lower values = higher confidence in relative positioning
- Used to identify domain boundaries and reliable interactions
Contact Probabilities
Probability of tokens being in contact. Used for:
- Interface prediction
- Contact map visualization
- Interaction analysis
Derived Metrics
From base predictions, aggregate metrics are computed.
Memory and Compute Optimizations
Precision
- bfloat16, a 16-bit floating point format
- Reduces memory by 2× vs float32
- Maintains float32 dynamic range
- Minimal accuracy loss
Gradient Checkpointing
- Don’t store all intermediate activations
- Recompute activations during backward pass
- Trade compute for memory
- Essential for large proteins
Attention Optimization
Optimized implementations of:
- Attention computation
- Gated linear units
- Layer normalization
Implementation Details
Haiku Modules
AlphaFold 3 uses Haiku (a JAX neural network library):
- Clean separation of state and logic
- Easy parameter management
- Composable modules
- JAX transformations (jit, grad, vmap)
JAX Transformations
Model Parameters
Loading Checkpoints
Checkpoints contain:
- All neural network weights
- Layer normalization parameters
- Embedding matrices
- ~5GB total size (fp32)
Parameter Initialization
Training initialization (reference only; inference uses pretrained weights).
Scalability Considerations
Token Scaling
Memory complexity by component:
| Component | Complexity | Example (1000 tokens) |
|---|---|---|
| Single repr | O(N) | 1000 × 384 = 0.4M |
| Pair repr | O(N²) | 1000² × 128 = 128M |
| Pairformer attention | O(N²) | Quadratic |
| Diffusion per step | O(A) | Linear in atoms |
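The quadratic pair-representation row in the table can be checked with simple arithmetic: 1000² tokens × 128 channels gives 128M values, which at 2 bytes each (bfloat16) is 256 MB for the activations alone, and quadruples at 2000 tokens.

```python
def pair_repr_megabytes(num_tokens, pair_dim=128, bytes_per_value=2):
    """Memory for the pair representation: N^2 * channels * bytes per value.

    bytes_per_value=2 assumes bfloat16 storage; use 4 for float32.
    """
    return num_tokens ** 2 * pair_dim * bytes_per_value / 1e6

mb_1000 = pair_repr_megabytes(1000)  # 256 MB in bfloat16
mb_2000 = pair_repr_megabytes(2000)  # 1024 MB: quadratic growth
```

This is only the raw tensor; attention logits and intermediate activations multiply the practical footprint, which is why the ~2000-token limit below is reached well before 40 GB of raw pair storage.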
Practical limits (on 40GB A100):
- ~2000 tokens with full Pairformer
- ~5000 tokens with sparse attention approximations
- Ligands contribute fewer tokens than protein chains
Strategies for Large Systems
- Chain cropping: Process subsets of chains
- Sparse attention: Approximate long-range interactions
- Mixed precision: bfloat16 throughout
- Gradient checkpointing: Reduce activation memory
Next Steps
Overview
Return to AlphaFold 3 overview
Inference Pipeline
Learn how to run inference