
Overview

The model.py module defines the core AlphaFold 3 model architecture, including the trunk (Evoformer), diffusion head, confidence head, and distogram head. It handles the complete forward pass and generates structure predictions.

Model Class

Main model class that orchestrates the full prediction pipeline.
class Model(hk.Module):
    def __init__(self, config: Config, name: str = 'diffuser')
config (Model.Config, required): Model configuration including Evoformer, heads, and global settings.
name (str, default 'diffuser'): Module name for Haiku.

Configuration

Model.Config

class Config(base_config.BaseConfig):
    evoformer: evoformer_network.Evoformer.Config = base_config.autocreate()
    global_config: model_config.GlobalConfig = base_config.autocreate()
    heads: 'Model.HeadsConfig' = base_config.autocreate()
    num_recycles: int = 10
    return_embeddings: bool = False
    return_distogram: bool = False
evoformer (evoformer_network.Evoformer.Config): Configuration for the Evoformer trunk (embedding module).
global_config (model_config.GlobalConfig): Global configuration including precision and attention settings.
heads (Model.HeadsConfig): Configuration for the diffusion, confidence, and distogram heads.
num_recycles (int, default 10): Number of recycling iterations through the trunk.
return_embeddings (bool, default False): Whether to return single and pair embeddings in the output.
return_distogram (bool, default False): Whether to return the distogram in the output.

Model.HeadsConfig

class HeadsConfig(base_config.BaseConfig):
    diffusion: diffusion_head.DiffusionHead.Config = base_config.autocreate()
    confidence: confidence_head.ConfidenceHead.Config = base_config.autocreate()
    distogram: distogram_head.DistogramHead.Config = base_config.autocreate()

Forward Pass

def __call__(
    self, 
    batch: features.BatchDict, 
    key: jax.Array | None = None
) -> ModelResult
batch (features.BatchDict, required): Input batch containing featurized sequences, MSAs, and templates.
key (jax.Array | None): Random key for JAX operations. If None, uses hk.next_rng_key().
Returns (ModelResult): Dictionary containing diffusion samples, confidence metrics, and optionally embeddings/distogram.

ModelResult Structure

ModelResult: TypeAlias = Mapping[str, Any]
The model returns a dictionary with the following keys:
diffusion_samples (dict): Contains an atom_positions array with predicted atom coordinates.
distogram (dict): Distance histogram predictions and contact probabilities.
predicted_lddt (np.ndarray): Per-atom predicted local distance difference test (pLDDT) scores.
full_pae (np.ndarray): Full predicted aligned error (PAE) matrix [num_samples, num_tokens, num_tokens].
full_pde (np.ndarray): Full predicted distance error (PDE) matrix.
tmscore_adjusted_pae_global (np.ndarray): TM-score-adjusted PAE for global structure assessment.
tmscore_adjusted_pae_interface (np.ndarray): TM-score-adjusted PAE for interface assessment.
single_embeddings (jnp.ndarray): Single embeddings if return_embeddings=True [num_tokens, 384].
pair_embeddings (jnp.ndarray): Pair embeddings if return_embeddings=True [num_tokens, num_tokens, 128].
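Because ModelResult is a plain mapping, its contents can be inspected generically before post-processing. The sketch below uses a hypothetical summarize_result helper and an illustrative fake result (5 samples, 100 atoms, 24 tokens), not real model output:

```python
import numpy as np

def summarize_result(result):
    """Return a {key: shape, key list, or type name} summary of a ModelResult-like mapping."""
    summary = {}
    for key, value in result.items():
        if hasattr(value, 'shape'):
            summary[key] = tuple(value.shape)      # array-valued entry
        elif isinstance(value, dict):
            summary[key] = sorted(value.keys())    # nested dict entry
        else:
            summary[key] = type(value).__name__
    return summary

# Illustrative fake result; shapes follow the field list above.
fake = {
    'predicted_lddt': np.zeros((5, 100)),
    'full_pae': np.zeros((5, 24, 24)),
    'diffusion_samples': {'atom_positions': np.zeros((5, 100, 3))},
}
print(summarize_result(fake))
```

This is useful for a quick sanity check of array shapes before handing the result to get_inference_result.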

Class Methods

get_inference_result

@classmethod
def get_inference_result(
    cls,
    batch: features.BatchDict,
    result: ModelResult,
    target_name: str = '',
) -> Iterable[InferenceResult]
Processes model outputs and computes inference-time metrics.
batch (features.BatchDict, required): Data batch used for model inference.
result (ModelResult, required): Output dict from the model's forward pass.
target_name (str, default ''): Target name to be saved within the structure.
Yields (InferenceResult): One InferenceResult per diffusion sample, containing the predicted structure and confidence metrics.

InferenceResult Class

Postprocessed model result containing the predicted structure and all confidence metrics.
@dataclasses.dataclass(frozen=True, kw_only=True)
class InferenceResult:
    predicted_structure: structure.Structure
    numerical_data: Mapping[str, float | int | np.ndarray] = dataclasses.field(
        default_factory=dict
    )
    metadata: Mapping[str, float | int | np.ndarray] = dataclasses.field(
        default_factory=dict
    )
    debug_outputs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    model_id: bytes = b''
predicted_structure (structure.Structure): Predicted protein structure with atom coordinates and B-factors.
numerical_data (Mapping[str, float | int | np.ndarray]): Large numerical arrays such as the full PAE, PDE, and contact probabilities.
metadata (Mapping[str, float | int | np.ndarray]): Confidence metrics and summary statistics (see Metadata Fields below).
debug_outputs (Mapping[str, Any]): Additional debugging information.
model_id (bytes): Model identifier from the parameters.

Metadata Fields

The metadata dictionary contains the following confidence scores:
ranking_score (float): Primary ranking score combining pTM, ipTM, disorder, and clash penalties.
predicted_tm_score (float): Predicted TM-score (pTM) measuring overall structure quality.
interface_predicted_tm_score (float): Interface predicted TM-score (ipTM) for multi-chain complexes.
ptm_iptm_average (float): Weighted average: 0.8 * ipTM + 0.2 * pTM.
ranking_confidence (float): Ranking confidence (equals ipTM for multi-chain inputs, pTM for single chains).
ranking_confidence_pae (float): Alternative ranking metric based on PAE.
predicted_distance_error (float): Average predicted distance error across the structure.
fraction_disordered (float): Fraction of the structure predicted to be disordered.
has_clash (bool): Whether the structure has atomic clashes.
chain_pair_pde_mean (np.ndarray): Mean PDE between chain pairs [num_chains, num_chains].
chain_pair_pde_min (np.ndarray): Minimum PDE between chain pairs [num_chains, num_chains].
chain_pair_pae_min (np.ndarray): Minimum PAE between chain pairs [num_chains, num_chains].
chain_pair_iptm (np.ndarray): Interface pTM between chain pairs [num_chains, num_chains].
intra_chain_single_pde (float): Average PDE within chains (intra-chain contacts).
cross_chain_single_pde (float): Average PDE between chains (inter-chain contacts).
pae_ichain (np.ndarray): Per-chain PAE scores [num_chains].
pae_xchain (np.ndarray): Cross-chain PAE scores [num_chains].
iptm_ichain (np.ndarray): Per-chain ipTM scores [num_chains].
iptm_xchain (np.ndarray): Cross-chain ipTM scores [num_chains].
token_chain_ids (list[str]): Chain IDs for each token.
token_res_ids (np.ndarray): Residue IDs for each token.
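The ptm_iptm_average formula above can be checked directly against the pTM and ipTM fields; the helper and scores below are illustrative, not part of the API:

```python
def ptm_iptm_average(ptm: float, iptm: float) -> float:
    # Weighted average as documented above: 0.8 * ipTM + 0.2 * pTM.
    return 0.8 * iptm + 0.2 * ptm

# Illustrative scores for a multi-chain complex.
print(ptm_iptm_average(ptm=0.90, iptm=0.75))  # ~0.78
```

Note the heavier weight on ipTM: for complexes, interface quality dominates this average.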

Numerical Data Fields

The numerical_data dictionary contains large arrays:
full_pde (np.ndarray): Full predicted distance error matrix [num_tokens, num_tokens].
full_pae (np.ndarray): Full predicted aligned error matrix [num_tokens, num_tokens].
contact_probs (np.ndarray): Contact probability matrix [num_tokens, num_tokens].
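A common use of these arrays is slicing the inter-chain block of full_pae using token_chain_ids from the metadata. The interchain_pae helper below is a hypothetical sketch with illustrative shapes:

```python
import numpy as np

def interchain_pae(full_pae, token_chain_ids, chain_a, chain_b):
    """Slice the [tokens_in_a, tokens_in_b] PAE block between two chains."""
    ids = np.asarray(token_chain_ids)
    mask_a = ids == chain_a
    mask_b = ids == chain_b
    # np.ix_ builds the outer product of the two index sets.
    return full_pae[np.ix_(mask_a, mask_b)]

# Illustrative 6-token complex: chain A has 4 tokens, chain B has 2.
pae = np.arange(36, dtype=float).reshape(6, 6)
ids = ['A', 'A', 'A', 'A', 'B', 'B']
block = interchain_pae(pae, ids, 'A', 'B')
print(block.shape)  # (4, 2)
```

Low values in this block indicate a confidently predicted A-B interface; the same slicing works for full_pde and contact_probs.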

Helper Functions

get_predicted_structure

def get_predicted_structure(
    result: ModelResult, 
    batch: feat_batch.Batch
) -> structure.Structure
Creates the predicted structure from model outputs.
result (ModelResult, required): Model output in model-specific layout.
batch (feat_batch.Batch, required): Model input batch used for layout conversion.
Returns (structure.Structure): Predicted structure with atom coordinates and B-factors.

create_target_feat_embedding

def create_target_feat_embedding(
    batch: feat_batch.Batch,
    config: evoformer_network.Evoformer.Config,
    global_config: model_config.GlobalConfig,
) -> jnp.ndarray
Creates target feature embedding for the Evoformer.
batch (feat_batch.Batch, required): Input batch data.
config (evoformer_network.Evoformer.Config, required): Evoformer configuration.
global_config (model_config.GlobalConfig, required): Global model configuration.
Returns (jnp.ndarray): Target feature embeddings [num_tokens, feature_dim].

Usage Examples

Basic Model Inference

import jax
from alphafold3.model import model, features
from alphafold3.model import params
import haiku as hk

# Load model configuration
config = model.Model.Config()
config.num_recycles = 10
config.heads.diffusion.eval.num_samples = 5

# Create model
def forward_fn(batch):
    return model.Model(config)(batch)

model_fn = hk.transform(forward_fn)

# Load parameters
model_params = params.get_model_haiku_params(model_dir='/path/to/models')

# Run inference (`batch` is a featurized features.BatchDict prepared upstream)
rng_key = jax.random.PRNGKey(42)
result = model_fn.apply(model_params, rng_key, batch)

Processing Results

# Extract inference results
inference_results = list(
    model.Model.get_inference_result(
        batch=batch,
        result=result,
        target_name='my_protein'
    )
)

# Access first sample
first_result = inference_results[0]
print(f"Ranking score: {first_result.metadata['ranking_score']}")
print(f"pTM: {first_result.metadata['predicted_tm_score']}")
print(f"Has clash: {first_result.metadata['has_clash']}")

# Save structure (Structure.to_mmcif returns the mmCIF text)
with open('output.cif', 'w') as f:
    f.write(first_result.predicted_structure.to_mmcif())

Accessing Confidence Metrics

# Get confidence matrices
pae = first_result.numerical_data['full_pae']
pde = first_result.numerical_data['full_pde']
contacts = first_result.numerical_data['contact_probs']

# Chain-level metrics
chain_pair_iptm = first_result.metadata['chain_pair_iptm']
print(f"Chain pair interface scores:\n{chain_pair_iptm}")

# Overall quality
ranking_score = first_result.metadata['ranking_score']
fraction_disordered = first_result.metadata['fraction_disordered']
print(f"Quality: {ranking_score:.3f}, Disorder: {fraction_disordered:.3f}")

Multi-Sample Analysis

# Compare all samples by ranking score
samples = sorted(
    inference_results,
    key=lambda x: x.metadata['ranking_score'],
    reverse=True
)

best_sample = samples[0]
print(f"Best sample ranking score: {best_sample.metadata['ranking_score']}")

# Ensemble metrics
import numpy as np

ptm_scores = [s.metadata['predicted_tm_score'] for s in samples]
print(f"pTM mean: {np.mean(ptm_scores):.3f}, std: {np.std(ptm_scores):.3f}")

With Embeddings

# Configure model to return embeddings
config.return_embeddings = True
config.return_distogram = True

# Run inference
result = model_fn.apply(model_params, rng_key, batch)

# Access embeddings
single_emb = result['single_embeddings']  # [num_tokens, 384]
pair_emb = result['pair_embeddings']      # [num_tokens, num_tokens, 128]
distogram = result['distogram']['distogram']  # [num_tokens, num_tokens, 64]

print(f"Single embedding shape: {single_emb.shape}")
print(f"Pair embedding shape: {pair_emb.shape}")

Architecture Overview

The Model consists of:
  1. Evoformer (Trunk): Processes MSA and creates single/pair embeddings through multiple recycling iterations
  2. Diffusion Head: Generates atom coordinates through denoising diffusion process
  3. Confidence Head: Predicts pLDDT, PAE, and PDE confidence metrics
  4. Distogram Head: Predicts distance histograms and contact probabilities

Forward Pass Flow

# Simplified forward pass
def __call__(batch):
    # 1. Create target features
    target_feat = create_target_feat_embedding(batch, config, global_config)
    
    # 2. Run recycling through the Evoformer trunk; recycled embeddings
    #    start from zeros and are fed back in on each iteration
    embeddings = init_recycling_embeddings(batch)  # pseudocode: zero-initialized
    for _ in range(num_recycles + 1):
        embeddings = evoformer(batch, embeddings, target_feat)
    
    # 3. Sample diffusion (generates atom coordinates)
    samples = diffusion_head(batch, embeddings)
    
    # 4. Compute confidence metrics
    confidence_output = confidence_head(samples, embeddings)
    
    # 5. Compute distogram
    distogram = distogram_head(batch, embeddings)
    
    return {
        'diffusion_samples': samples,
        'distogram': distogram,
        **confidence_output
    }

Performance Considerations

  • Recycling: More recycles (10-20) improve quality but increase compute time
  • Diffusion Samples: More samples (5-10) provide better coverage but are slower
  • Embeddings: Enabling embeddings significantly increases memory usage
  • Flash Attention: Use triton or cudnn for best performance on Ampere+ GPUs
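The trade-offs above map onto the configuration directly. A sketch, not a tuned recipe; the flash_attention_implementation field name is assumed from the AlphaFold 3 run script's --flash_attention_implementation flag (backends 'triton', 'cudnn', and 'xla'):

```python
# Hypothetical tuning sketch; field names assume the AlphaFold 3 configs.
config = model.Model.Config()
config.num_recycles = 10                      # more recycles: better quality, slower
config.heads.diffusion.eval.num_samples = 5   # more samples: better coverage, slower
config.return_embeddings = False              # keep off to reduce memory usage
config.global_config.flash_attention_implementation = 'triton'  # or 'cudnn'; 'xla' on older GPUs
```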
