
Overview

The inference module provides the primary interface for running AlphaFold 3 predictions and handling model outputs. It includes the main Model class for forward passes and the InferenceResult dataclass for storing predictions.

InferenceResult

Dataclass storing postprocessed model predictions and associated metadata.
@dataclasses.dataclass(frozen=True, kw_only=True)
class InferenceResult:
    predicted_structure: structure.Structure
    numerical_data: Mapping[str, float | int | np.ndarray] = dataclasses.field(default_factory=dict)
    metadata: Mapping[str, float | int | np.ndarray] = dataclasses.field(default_factory=dict)
    debug_outputs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    model_id: bytes = b''

Attributes

predicted_structure
structure.Structure
required
The predicted structure (which may contain protein, nucleic acid, and ligand chains), including atomic coordinates and metadata.
numerical_data
Mapping[str, float | int | np.ndarray]
Useful numerical data (scalars or arrays) to be saved at inference time. Commonly includes:
  • full_pde: Full predicted distance error matrix
  • full_pae: Full predicted aligned error matrix
  • contact_probs: Contact probability matrix
metadata
Mapping[str, float | int | np.ndarray]
Smaller numerical data (usually scalar) to be saved as inference metadata. Includes confidence metrics:
  • predicted_tm_score: Predicted TM-score (pTM)
  • interface_predicted_tm_score: Interface pTM (ipTM)
  • ranking_score: Overall ranking confidence
  • fraction_disordered: Fraction of disordered residues
  • has_clash: Boolean indicating structural clashes
  • chain_pair_pae_min: Minimum chain pair PAE
  • chain_pair_pde_mean: Mean chain pair PDE
debug_outputs
Mapping[str, Any]
Additional dictionary for debugging, e.g., raw outputs of a model forward pass.
model_id
bytes
Model identifier used to generate this prediction.
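The chain-pair summaries in `metadata` are reductions of token-level matrices such as `full_pae` in `numerical_data`. As an illustration only (not the library's implementation, which may mask or weight tokens differently), a minimum-PAE table per chain pair can be derived from a full PAE matrix and a hypothetical per-token chain index array, `token_chain_ids`:

```python
import numpy as np


def chain_pair_min(full_pae: np.ndarray, token_chain_ids: np.ndarray) -> np.ndarray:
    """Minimum PAE over all token pairs (i in chain a, j in chain b).

    Illustrative sketch; `token_chain_ids` is a hypothetical integer array
    assigning each token to a chain index 0..num_chains-1.
    """
    num_chains = int(token_chain_ids.max()) + 1
    out = np.zeros((num_chains, num_chains), dtype=full_pae.dtype)
    for a in range(num_chains):
        rows = token_chain_ids == a
        for b in range(num_chains):
            cols = token_chain_ids == b
            # Select the (chain a, chain b) block and take its minimum.
            out[a, b] = full_pae[np.ix_(rows, cols)].min()
    return out


# Toy example: 4 tokens, the first two in chain 0, the last two in chain 1.
pae = np.array([
    [0.5, 1.0, 8.0, 9.0],
    [1.0, 0.5, 7.0, 6.0],
    [9.5, 8.5, 0.5, 2.0],
    [6.5, 7.5, 2.0, 0.5],
])
chains = np.array([0, 0, 1, 1])
print(chain_pair_min(pae, chains))
```

PAE is not symmetric in general, so the (0, 1) and (1, 0) entries can differ (here 6.0 versus 6.5).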

Model Class

Full AlphaFold 3 model implementation using Haiku modules.
class Model(hk.Module):
    def __init__(self, config: Config, name: str = 'diffuser'):
        super().__init__(name=name)
        self.config = config
        self.global_config = config.global_config
        self.diffusion_module = diffusion_head.DiffusionHead(
            self.config.heads.diffusion, self.global_config
        )

Configuration

config.evoformer
evoformer_network.Evoformer.Config
Configuration for the Evoformer trunk network.
config.global_config
model_config.GlobalConfig
Global model configuration including dtype settings.
config.heads
Model.HeadsConfig
Configuration for model heads (diffusion, confidence, distogram).
config.num_recycles
int
default: 10
Number of recycling iterations through the trunk network.
config.return_embeddings
bool
default: False
Whether to return single and pair embeddings in output.
config.return_distogram
bool
default: False
Whether to compute and return distogram predictions.
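`num_recycles` controls how many times the trunk's outputs are fed back into the trunk. The loop below is a toy sketch of that recycling pattern, assuming the trunk runs once and then `num_recycles` further times, with each pass receiving the previous pass's embeddings. `toy_trunk` and `run_with_recycling` are hypothetical names; the actual model operates on JAX arrays inside a Haiku module rather than on NumPy arrays in a plain Python loop.

```python
from typing import Callable, Dict

import numpy as np

# Hypothetical embedding container: name -> array.
Embeddings = Dict[str, np.ndarray]


def run_with_recycling(
    trunk: Callable[[Embeddings], Embeddings],
    init: Embeddings,
    num_recycles: int,
) -> Embeddings:
    """Illustrative recycling loop (not the AlphaFold 3 implementation).

    Assumes one initial pass plus num_recycles additional passes, each
    consuming the previous pass's embeddings.
    """
    embeddings = init
    for _ in range(num_recycles + 1):
        embeddings = trunk(embeddings)
    return embeddings


calls = []


def toy_trunk(prev: Embeddings) -> Embeddings:
    calls.append(1)
    # Toy update: add one to every entry so the passes are countable.
    return {k: v + 1.0 for k, v in prev.items()}


init = {'single': np.zeros((4, 8)), 'pair': np.zeros((4, 4, 16))}
out = run_with_recycling(toy_trunk, init, num_recycles=3)
print(len(calls), out['single'][0, 0])  # 4 4.0
```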

Forward Pass

def __call__(
    self, 
    batch: features.BatchDict, 
    key: jax.Array | None = None
) -> ModelResult:
    """Run forward pass through the model.
    
    Args:
        batch: Input feature dictionary
        key: Random key for sampling (uses hk.next_rng_key() if None)
        
    Returns:
        Dictionary containing:
        - diffusion_samples: Sampled atomic positions
        - distogram: Distance distribution predictions
        - Confidence head outputs (pLDDT, PAE, PDE)
    """
batch
features.BatchDict
required
Dictionary of input features including MSA, templates, and token features.
key
jax.Array | None
JAX random key for stochastic sampling. If None, uses hk.next_rng_key().

Returns

diffusion_samples
dict
Sampled structure predictions from the diffusion head.
  • atom_positions: Predicted atomic coordinates
distogram
dict
Distance distribution (distogram) predictions between token pairs.
predicted_lddt
np.ndarray
Predicted local distance difference test (pLDDT) scores per atom.
full_pae
np.ndarray
Full predicted aligned error matrix between all token pairs.
full_pde
np.ndarray
Full predicted distance error matrix.
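Per-atom outputs such as `predicted_lddt` are typically summarised before reporting. The sketch below averages pLDDT per sample over real atoms only. The shapes (a padded `(num_samples, num_tokens, max_atoms_per_token)` layout with a matching atom mask) and the function name `mean_plddt` are assumptions for illustration, not the documented output layout.

```python
import numpy as np


def mean_plddt(predicted_lddt: np.ndarray, atom_mask: np.ndarray) -> np.ndarray:
    """Mean pLDDT per sample over valid atoms only.

    Assumed (hypothetical) shapes:
      predicted_lddt: (num_samples, num_tokens, max_atoms_per_token)
      atom_mask:      (num_tokens, max_atoms_per_token), 1 for real atoms,
                      0 for padding slots.
    """
    mask = atom_mask.astype(predicted_lddt.dtype)
    # Broadcast the mask over the sample axis and sum only real atoms.
    total = (predicted_lddt * mask).sum(axis=(-2, -1))
    count = np.maximum(mask.sum(), 1.0)
    return total / count


plddt = np.array([
    [[90.0, 80.0], [70.0, 0.0]],   # sample 0; last slot is padding
    [[60.0, 50.0], [40.0, 99.0]],  # sample 1; padded value must be ignored
])
mask = np.array([[1, 1], [1, 0]])
print(mean_plddt(plddt, mask))  # [80. 50.]
```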

Core Functions

get_predicted_structure

Converts model output to a Structure object with predicted coordinates.
def get_predicted_structure(
    result: ModelResult, 
    batch: feat_batch.Batch
) -> structure.Structure:
    """Creates the predicted structure from model output.
    
    Args:
        result: Model output in model-specific layout
        batch: Model input batch
        
    Returns:
        Predicted structure with atomic coordinates
    """
result
ModelResult
required
Dictionary containing model outputs including diffusion_samples with atom_positions.
batch
feat_batch.Batch
required
Input batch containing layout conversion information.
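`get_predicted_structure` uses the layout-conversion information in the batch to map coordinates from the model's padded per-token layout to the structure's atom ordering. The sketch below shows only the general gather idea, using a hypothetical lookup table (`token_index`, `atom_slot`) rather than AlphaFold 3's actual layout classes.

```python
import numpy as np


def to_flat_atoms(
    atom_positions: np.ndarray, token_index: np.ndarray, atom_slot: np.ndarray
) -> np.ndarray:
    """Gather coordinates from a (num_tokens, max_atoms, 3) layout.

    token_index[k] and atom_slot[k] locate output atom k in the model's
    padded per-token layout (hypothetical gather table).
    """
    return atom_positions[token_index, atom_slot]


# Two tokens with up to 3 atom slots each; token 1 uses only 2 slots.
positions = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)
token_index = np.array([0, 0, 0, 1, 1])
atom_slot = np.array([0, 1, 2, 0, 1])
flat = to_flat_atoms(positions, token_index, atom_slot)
print(flat.shape)  # (5, 3)
```

The padding slot `positions[1, 2]` is simply never gathered, which is how a gather table drops padded atoms.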

get_inference_result

Class method to compute full inference results including confidence metrics.
@classmethod
def get_inference_result(
    cls,
    batch: features.BatchDict,
    result: ModelResult,
    target_name: str = '',
) -> Iterable[InferenceResult]:
    """Get predicted structure, scalars, and arrays for inference.
    
    Computes inference-time quantities not calculated during forward pass,
    including additional confidence scores.
    
    Args:
        batch: Data batch used for model inference
        result: Output dict from model's forward pass
        target_name: Target name to be saved within structure
        
    Yields:
        InferenceResult: Contains predicted structure, confidence metrics,
                        numerical data, and metadata
    """
batch
features.BatchDict
required
Input feature dictionary including token features and atom layouts.
result
ModelResult
required
Raw model output from forward pass.
target_name
str
default: ''
Name of the prediction target, saved within the predicted structure.
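Since `get_inference_result` returns an `Iterable` and is documented with a `Yields:` section, its results can be consumed only once. A common pattern is to materialise and rank them by `ranking_score`. In the sketch below, `ToyResult` and `toy_inference_results` are hypothetical stand-ins, and the generator is assumed to yield one result per sample.

```python
import dataclasses
from typing import Any, Iterable, Iterator, List, Mapping


# Hypothetical stand-in exposing only the field this sketch reads.
@dataclasses.dataclass(frozen=True)
class ToyResult:
    metadata: Mapping[str, Any]


def rank_results(results: Iterable[ToyResult]) -> List[ToyResult]:
    # sorted() consumes the generator exactly once and returns a reusable
    # list, highest ranking_score first.
    return sorted(
        results,
        key=lambda r: float(r.metadata['ranking_score']),
        reverse=True,
    )


def toy_inference_results() -> Iterator[ToyResult]:
    # Stand-in for Model.get_inference_result, assumed here to yield one
    # result per sample.
    for score in (0.61, 0.87, 0.74):
        yield ToyResult(metadata={'ranking_score': score})


ranked = rank_results(toy_inference_results())
print([r.metadata['ranking_score'] for r in ranked])  # [0.87, 0.74, 0.61]
```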

create_target_feat_embedding

Creates target feature embeddings for the Evoformer module.
def create_target_feat_embedding(
    batch: feat_batch.Batch,
    config: evoformer_network.Evoformer.Config,
    global_config: model_config.GlobalConfig,
) -> jnp.ndarray:
    """Create target feature embedding.
    
    Args:
        batch: Input batch with token and atom features
        config: Evoformer configuration
        global_config: Global model configuration
        
    Returns:
        Target feature tensor with atom cross-attention encoding
    """

Usage Example

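import haiku as hk
import jax
import numpy as np
from alphafold3.model import features
from alphafold3.model import model

# `batch` is assumed to be a features.BatchDict produced by the
# AlphaFold 3 featurisation pipeline (not shown here).

# Configure the model
config = model.Model.Config(
    num_recycles=3,
    return_embeddings=True,
)

# Create Haiku transform
def forward(batch):
    model_instance = model.Model(config)
    return model_instance(batch)

model_fn = hk.transform(forward)

# Initialize parameters (illustration only; real inference loads the
# released AlphaFold 3 weights instead)
rng = jax.random.PRNGKey(42)
params = model_fn.init(rng, batch)

# Run inference, then convert JAX arrays to NumPy for post-processing
model_output = model_fn.apply(params, rng, batch)
model_output = jax.tree.map(np.asarray, model_output)

# Get inference results with confidence metrics
for i, inference_result in enumerate(model.Model.get_inference_result(
    batch=batch,
    result=model_output,
)):
    print(f"Ranking score: {inference_result.metadata['ranking_score']}")
    print(f"pTM: {inference_result.metadata['predicted_tm_score']}")
    print(f"ipTM: {inference_result.metadata['interface_predicted_tm_score']}")

    # Save each result's structure to its own mmCIF file
    with open(f"prediction_{i}.cif", "w") as f:
        f.write(inference_result.predicted_structure.to_mmcif())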
