
Overview

The features module handles the data-side processing of input features for AlphaFold 3. It provides dataclasses and functions for converting raw input data into model-ready tensors, including MSA processing, template features, token features, and atom layout management.

Core Types

BatchDict

Type alias for feature dictionaries passed to the model.
BatchDict: TypeAlias = dict[str, xnp_ndarray]
Here, xnp_ndarray is a union type covering both NumPy and JAX arrays.

PaddingShapes

Defines padding dimensions for batched model inputs.
@dataclasses.dataclass(frozen=True)
class PaddingShapes:
    num_tokens: int
    msa_size: int
    num_chains: int
    num_templates: int
    num_atoms: int
num_tokens
int
required
Maximum number of tokens (residues + ligand atoms) in the sequence.
msa_size
int
required
Maximum number of MSA rows to include.
num_chains
int
required
Maximum number of chains in the complex.
num_templates
int
required
Maximum number of structural templates.
num_atoms
int
required
Maximum number of atoms per token.
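PaddingShapes only records target sizes. The padding itself happens in the feature code. The sketch below zero-pads a per-token array up to num_tokens with NumPy, and the mask then tells real entries from padded ones. It is minimal and assumes a hypothetical helper, pad_axis, that is not part of the module. The local PaddingShapes class mirrors only the fields documented above.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class PaddingShapes:
    # Local stand-in mirroring the documented fields.
    num_tokens: int
    msa_size: int
    num_chains: int
    num_templates: int
    num_atoms: int


def pad_axis(x: np.ndarray, axis: int, size: int) -> np.ndarray:
    """Hypothetical helper: zero-pad `x` along `axis` up to `size`."""
    if x.shape[axis] > size:
        raise ValueError(f"axis {axis} has {x.shape[axis]} entries, exceeds {size}")
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, size - x.shape[axis])
    return np.pad(x, widths)  # constant mode, fills with 0 / False


shapes = PaddingShapes(num_tokens=8, msa_size=4, num_chains=2,
                       num_templates=1, num_atoms=64)
residue_index = np.array([1, 2, 3, 4, 5], dtype=np.int32)
mask = np.ones_like(residue_index, dtype=bool)

padded_index = pad_axis(residue_index, axis=0, size=shapes.num_tokens)
padded_mask = pad_axis(mask, axis=0, size=shapes.num_tokens)
print(padded_index.tolist())  # [1, 2, 3, 4, 5, 0, 0, 0]
print(int(padded_mask.sum()))  # 5
```

The example raises an error rather than truncating when an input is larger than the padded size. That way an undersized PaddingShapes surfaces immediately instead of silently dropping data.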

MSA Features

MSA Dataclass

Contains multiple sequence alignment features.
@dataclasses.dataclass(frozen=True)
class MSA:
    rows: xnp_ndarray
    mask: xnp_ndarray
    deletion_matrix: xnp_ndarray
    profile: xnp_ndarray
    deletion_mean: xnp_ndarray
    num_alignments: xnp_ndarray
rows
xnp_ndarray
MSA sequences encoded as integers. Shape: (msa_size, num_tokens)
mask
xnp_ndarray
Binary mask for valid MSA positions. Shape: (msa_size, num_tokens)
deletion_matrix
xnp_ndarray
Number of deletions at each MSA position. Shape: (msa_size, num_tokens)
profile
xnp_ndarray
Occurrence of each residue type along the sequence, averaged over MSA rows. Shape: (num_tokens, num_residue_types)
deletion_mean
xnp_ndarray
Occurrence of deletions along the sequence, averaged over MSA rows. Shape: (num_tokens,)
num_alignments
xnp_ndarray
Total number of MSA alignments (scalar).
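The profile and deletion_mean fields are described as averages over MSA rows. The NumPy sketch below shows that averaging under two assumptions. First, padded rows are excluded through mask. Second, raw deletion counts are averaged without rescaling; the actual module may transform deletion values. The function name msa_profile and the 4-letter toy alphabet are hypothetical.

```python
import numpy as np

NUM_RESIDUE_TYPES = 4  # toy alphabet size for this sketch


def msa_profile(rows, mask, deletion_matrix, num_types=NUM_RESIDUE_TYPES):
    """Return (profile, deletion_mean) averaged over valid MSA rows."""
    one_hot = np.eye(num_types, dtype=np.float32)[rows]  # (msa, tokens, types)
    weights = mask.astype(np.float32)                    # (msa, tokens)
    counts = np.maximum(weights.sum(axis=0), 1.0)        # avoid divide-by-zero
    profile = (one_hot * weights[..., None]).sum(axis=0) / counts[:, None]
    deletion_mean = (deletion_matrix * weights).sum(axis=0) / counts
    return profile, deletion_mean


rows = np.array([[0, 1, 2],
                 [0, 3, 2],
                 [1, 1, 0]])
mask = np.array([[1, 1, 1],
                 [1, 1, 1],
                 [0, 0, 0]])  # third row is padding
deletions = np.array([[0, 2, 0],
                      [0, 0, 4],
                      [9, 9, 9]])
profile, deletion_mean = msa_profile(rows, mask, deletions)
print(profile[1].tolist())     # [0.0, 0.5, 0.0, 0.5]
print(deletion_mean.tolist())  # [0.0, 1.0, 2.0]
```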

compute_features

Computes MSA features from folding input.
@classmethod
def compute_features(
    cls,
    *,
    all_tokens: atom_layout.AtomLayout,
    standard_token_idxs: np.ndarray,
    padding_shapes: PaddingShapes,
    fold_input: folding_input.Input,
    logging_name: str,
    max_paired_sequence_per_species: int,
    resolve_msa_overlaps: bool = True,
) -> Self:
    """Compute the MSA features."""
all_tokens
atom_layout.AtomLayout
required
Atom layout containing one representative atom per token.
standard_token_idxs
np.ndarray
required
Token indices for non-flattened standard residues.
padding_shapes
PaddingShapes
required
Padding dimensions for the output tensors.
fold_input
folding_input.Input
required
Input data containing MSAs for each chain.
logging_name
str
required
Name for logging (typically mmCIF ID).
max_paired_sequence_per_species
int
required
Maximum number of paired sequences per species.
resolve_msa_overlaps
bool
default: True
Whether to deduplicate the unpaired MSA against the paired MSA.
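max_paired_sequence_per_species bounds how many sequences from one species can take part in cross-chain pairing. The plain-Python sketch below illustrates such a cap. It assumes each alignment hit is tagged with a species identifier and that hits are already sorted best-first. It shows the idea of the limit only, not the module's pairing algorithm, and cap_per_species is a hypothetical name.

```python
from collections import defaultdict


def cap_per_species(hits, max_per_species):
    """Keep at most `max_per_species` hits per species, preserving order.

    `hits` is a list of (species_id, sequence) tuples; this layout is an
    assumption made for the sketch.
    """
    kept = []
    seen = defaultdict(int)  # species_id -> hits kept so far
    for species, seq in hits:
        if seen[species] < max_per_species:
            kept.append((species, seq))
            seen[species] += 1
    return kept


hits = [("HUMAN", "MKV"), ("HUMAN", "MKI"), ("MOUSE", "MRV"),
        ("HUMAN", "MKL"), ("MOUSE", "MRI")]
print(cap_per_species(hits, max_per_species=2))
# [('HUMAN', 'MKV'), ('HUMAN', 'MKI'), ('MOUSE', 'MRV'), ('MOUSE', 'MRI')]
```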

Methods

index_msa_rows

Subsample MSA rows by indices.
def index_msa_rows(self, indices: xnp_ndarray) -> Self:
    """Returns new MSA with selected rows."""

from_data_dict

Create MSA from batch dictionary.
@classmethod
def from_data_dict(cls, batch: BatchDict) -> Self:
    """Reconstruct MSA from feature dictionary."""

as_data_dict

Convert MSA to batch dictionary.
def as_data_dict(self) -> BatchDict:
    """Convert to dictionary of arrays for model input."""

Template Features

Templates Dataclass

Contains structural template features.
@dataclasses.dataclass(frozen=True)
class Templates:
    aatype: xnp_ndarray  # Shape: (num_templates, num_tokens)
    atom_positions: xnp_ndarray  # Shape: (num_templates, num_tokens, 24, 3)
    atom_mask: xnp_ndarray  # Shape: (num_templates, num_tokens, 24)
aatype
xnp_ndarray
Amino acid type encoded as integers. Shape: (num_templates, num_tokens)
atom_positions
xnp_ndarray
3D coordinates of template atoms. Shape: (num_templates, num_tokens, 24, 3)
atom_mask
xnp_ndarray
Binary mask for valid template atoms. Shape: (num_templates, num_tokens, 24)

compute_features

Computes template features from protein chain templates.
@classmethod
def compute_features(
    cls,
    all_tokens: atom_layout.AtomLayout,
    standard_token_idxs: np.ndarray,
    padding_shapes: PaddingShapes,
    fold_input: folding_input.Input,
    max_templates: int,
    logging_name: str,
) -> Self:
    """Compute template features."""
all_tokens
atom_layout.AtomLayout
required
Atom layout with representative atom per token.
standard_token_idxs
np.ndarray
required
Indices for standard (non-flattened) tokens.
padding_shapes
PaddingShapes
required
Padding dimensions.
fold_input
folding_input.Input
required
Input containing template structures.
max_templates
int
required
Maximum number of templates to use.
logging_name
str
required
Name for logging.
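The template arrays have a fixed leading dimension. A variable number of templates therefore has to be truncated to max_templates and padded to padding_shapes.num_templates. The sketch below shows that step for atom_positions. It assumes the 24-slot dense atom layout shown in the Templates shapes, templates ordered best-first, and zero-filled padding. stack_templates is a hypothetical helper.

```python
import numpy as np

NUM_DENSE_ATOMS = 24  # the 24-slot atom axis used by the Templates shapes


def stack_templates(positions, num_tokens, max_templates, num_templates_padded):
    """Hypothetical helper: truncate, then zero-pad, per-template coordinates.

    `positions` is a list of (num_tokens, 24, 3) arrays, best template first.
    Returns the padded (num_templates_padded, num_tokens, 24, 3) array and a
    boolean vector marking which template slots hold real data.
    """
    kept = positions[:min(max_templates, num_templates_padded)]
    out = np.zeros((num_templates_padded, num_tokens, NUM_DENSE_ATOMS, 3),
                   dtype=np.float32)
    present = np.zeros(num_templates_padded, dtype=bool)
    for i, pos in enumerate(kept):
        out[i] = pos
        present[i] = True
    return out, present


templates = [np.full((5, NUM_DENSE_ATOMS, 3), k, dtype=np.float32) for k in (1, 2, 3)]
out, present = stack_templates(templates, num_tokens=5, max_templates=2,
                               num_templates_padded=4)
print(out.shape)         # (4, 5, 24, 3)
print(present.tolist())  # [True, True, False, False]
```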

Token Features

TokenFeatures Dataclass

Per-token features including chain identifiers and token types.
@dataclasses.dataclass(frozen=True)
class TokenFeatures:
    residue_index: xnp_ndarray
    token_index: xnp_ndarray
    aatype: xnp_ndarray
    mask: xnp_ndarray
    seq_length: xnp_ndarray
    
    # Chain symmetry identifiers
    asym_id: xnp_ndarray
    entity_id: xnp_ndarray
    sym_id: xnp_ndarray
    
    # Token type features
    is_protein: xnp_ndarray
    is_rna: xnp_ndarray
    is_dna: xnp_ndarray
    is_ligand: xnp_ndarray
    is_nonstandard_polymer_chain: xnp_ndarray
    is_water: xnp_ndarray
residue_index
xnp_ndarray
Residue index from input structure. Shape: (num_tokens,)
token_index
xnp_ndarray
Sequential token index (1-indexed). Shape: (num_tokens,)
aatype
xnp_ndarray
Encoded residue/ligand type. Shape: (num_tokens,)
mask
xnp_ndarray
Binary mask for valid tokens. Shape: (num_tokens,)
seq_length
xnp_ndarray
Total sequence length (scalar).
asym_id
xnp_ndarray
Per-token chain (asymmetric unit) ID; every chain gets a distinct value. For an A3B2 complex the five chains are numbered 1, 2, 3, 4, 5. Shape: (num_tokens,)
entity_id
xnp_ndarray
Per-token entity ID; chains with identical sequences share a value. For A3B2 the five chains get 1, 1, 1, 2, 2. Shape: (num_tokens,)
sym_id
xnp_ndarray
Per-token copy index of the chain within its entity. For A3B2 the five chains get 1, 2, 3, 1, 2. Shape: (num_tokens,)
is_protein
xnp_ndarray
Boolean mask for protein tokens. Shape: (num_tokens,)
is_rna
xnp_ndarray
Boolean mask for RNA tokens. Shape: (num_tokens,)
is_dna
xnp_ndarray
Boolean mask for DNA tokens. Shape: (num_tokens,)
is_ligand
xnp_ndarray
Boolean mask for ligand tokens. Shape: (num_tokens,)
is_nonstandard_polymer_chain
xnp_ndarray
Boolean mask for non-standard polymer chains. Shape: (num_tokens,)
is_water
xnp_ndarray
Boolean mask for water molecules. Shape: (num_tokens,)
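The A3B2 examples for asym_id, entity_id and sym_id follow a simple rule. Every chain gets a new asym_id, chains with identical sequences share an entity_id, and sym_id counts copies within an entity. The plain-Python sketch below applies that rule and then broadcasts chain-level values to tokens. It assumes entities are identified by exact sequence match and numbered in order of first appearance; the real featurization may also consider chain type. chain_ids and broadcast_to_tokens are hypothetical names.

```python
def chain_ids(chain_sequences):
    """Return per-chain (asym_id, entity_id, sym_id) lists, all 1-based."""
    entity_of = {}  # sequence -> entity_id
    copies = {}     # sequence -> copies seen so far
    asym, entity, sym = [], [], []
    for i, seq in enumerate(chain_sequences, start=1):
        if seq not in entity_of:
            entity_of[seq] = len(entity_of) + 1
        copies[seq] = copies.get(seq, 0) + 1
        asym.append(i)
        entity.append(entity_of[seq])
        sym.append(copies[seq])
    return asym, entity, sym


def broadcast_to_tokens(per_chain, tokens_per_chain):
    """Repeat each chain-level value once per token of that chain."""
    return [v for v, n in zip(per_chain, tokens_per_chain) for _ in range(n)]


# A3B2: three copies of one sequence, two copies of another.
seqs = ["MKV", "MKV", "MKV", "GSG", "GSG"]
asym, entity, sym = chain_ids(seqs)
print(asym)    # [1, 2, 3, 4, 5]
print(entity)  # [1, 1, 1, 2, 2]
print(sym)     # [1, 2, 3, 1, 2]
```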

Tokenization

tokenizer

Maps flat atom layout to tokens for the Evoformer.
def tokenizer(
    flat_output_layout: atom_layout.AtomLayout,
    ccd: chemical_components.Ccd,
    max_atoms_per_token: int,
    flatten_non_standard_residues: bool,
    logging_name: str,
) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout, np.ndarray]:
    """Maps flat atom layout to tokens for evoformer.
    
    Creates one token per polymer residue and one token per ligand atom.
    
    Returns:
        all_tokens: AtomLayout with 1 representative atom per token
        all_token_atoms_layout: AtomLayout with all atoms per token
        standard_token_idxs: Token indices if not flattening non-standard residues
    """
flat_output_layout
atom_layout.AtomLayout
required
Flat atom layout containing all atoms to predict.
ccd
chemical_components.Ccd
required
Chemical components dictionary.
max_atoms_per_token
int
required
Number of atom slots per token.
flatten_non_standard_residues
bool
required
Whether to use one token per atom for non-standard residues.
logging_name
str
required
Name for logging (typically mmCIF ID).
Tokenization Rules:
  • Standard protein residues: 1 token per residue (CA representative atom)
  • Standard nucleic residues: 1 token per residue (C1' representative atom)
  • Non-standard polymer residues: 1 token per atom if flatten_non_standard_residues=True, otherwise 1 token per residue
  • Ligands: 1 token per atom
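The tokenization rules determine how many tokens each input unit produces. The sketch below counts tokens under those rules. It assumes non-standard polymer residues fall back to one token per residue when flattening is off. The Residue record and its fields are invented for illustration; the real tokenizer works on an AtomLayout and the CCD.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Residue:
    # Hypothetical record; not part of the module.
    name: str
    num_atoms: int
    is_polymer: bool
    is_standard: bool


def count_tokens(residues, flatten_non_standard_residues=True):
    total = 0
    for r in residues:
        if r.is_polymer and r.is_standard:
            total += 1              # one token per standard residue
        elif r.is_polymer and not flatten_non_standard_residues:
            total += 1              # non-standard residue kept whole
        else:
            total += r.num_atoms    # ligands and flattened residues: per atom
    return total


residues = [
    Residue("ALA", 5, is_polymer=True, is_standard=True),
    Residue("MSE", 8, is_polymer=True, is_standard=False),   # selenomethionine
    Residue("ATP", 31, is_polymer=False, is_standard=False),  # ligand
]
print(count_tokens(residues))                                       # 40
print(count_tokens(residues, flatten_non_standard_residues=False))  # 33
```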

Additional Feature Classes

PredictedStructureInfo

Information for working with predicted structures.
@dataclasses.dataclass(frozen=True)
class PredictedStructureInfo:
    atom_mask: xnp_ndarray
    residue_center_index: xnp_ndarray

PolymerLigandBondInfo

Information about polymer-ligand bonds.
@dataclasses.dataclass(frozen=True)
class PolymerLigandBondInfo:
    tokens_to_polymer_ligand_bonds: atom_layout.GatherInfo
    token_atoms_to_bonds: atom_layout.GatherInfo
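The bond-info classes are built from atom_layout.GatherInfo objects. The NumPy sketch below shows a masked gather. It assumes a gather spec consists of integer indices into a flattened atom axis plus a validity mask. The masked_gather helper is hypothetical, and the real GatherInfo fields may be named or shaped differently.

```python
import numpy as np


def masked_gather(positions, gather_idxs, gather_mask):
    """Hypothetical masked gather over a (num_tokens, atoms_per_token, 3) array.

    gather_idxs indexes the flattened (num_tokens * atoms_per_token) atom axis;
    entries where gather_mask is False are padding and come back as zeros.
    """
    flat = positions.reshape(-1, positions.shape[-1])  # (num_atoms_flat, 3)
    gathered = flat[gather_idxs]                        # gather_idxs.shape + (3,)
    return np.where(gather_mask[..., None], gathered, 0.0)


# Two tokens with three atom slots each; flat index = token * 3 + slot.
positions = np.arange(18, dtype=np.float64).reshape(2, 3, 3)
# One real bond (token 0 slot 1 <-> token 1 slot 0) plus one padded row.
gather_idxs = np.array([[1, 3], [0, 0]])
gather_mask = np.array([[True, True], [False, False]])
bonded = masked_gather(positions, gather_idxs, gather_mask)
print(bonded[0].tolist())  # [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]]
```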

LigandLigandBondInfo

Information about ligand-ligand bonds.
@dataclasses.dataclass(frozen=True)
class LigandLigandBondInfo:
    tokens_to_ligand_ligand_bonds: atom_layout.GatherInfo

PseudoBetaInfo

Information for extracting pseudo-beta and equivalent atoms.
@dataclasses.dataclass(frozen=True)
class PseudoBetaInfo:
    token_atoms_to_pseudo_beta: atom_layout.GatherInfo
Pseudo-beta atom selection:
  • Protein: CB (or CA for glycine)
  • Nucleic acids (purines A/G/DA/DG): C4
  • Nucleic acids (pyrimidines C/T/U/DC/DT): C2
  • Ligands: First atom
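The pseudo-beta rules translate directly into a lookup. In the plain-Python sketch below, the function name and residue-name sets are illustrative, and only the residue names listed in the rules are covered. A ligand token is assumed to list its atom first.

```python
PURINES = {"A", "G", "DA", "DG"}
PYRIMIDINES = {"C", "T", "U", "DC", "DT"}


def pseudo_beta_atom(res_name: str, atom_names: list[str], is_protein: bool) -> str:
    """Return the atom name used as pseudo-beta for one token (hypothetical)."""
    if is_protein:
        # Glycine has no CB, so fall back to CA.
        return "CA" if res_name == "GLY" else "CB"
    if res_name in PURINES:
        return "C4"
    if res_name in PYRIMIDINES:
        return "C2"
    # Ligand tokens (and anything not covered above): first atom.
    return atom_names[0]


print(pseudo_beta_atom("GLY", ["N", "CA", "C", "O"], is_protein=True))  # CA
print(pseudo_beta_atom("DG", [], is_protein=False))                      # C4
```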

Chains

Chain identification dataclass.
@dataclasses.dataclass(frozen=True)
class Chains:
    chain_id: np.ndarray
    asym_id: np.ndarray
    entity_id: np.ndarray
    sym_id: np.ndarray

Usage Example

from alphafold3.model import features
from alphafold3.model.atom_layout import atom_layout
from alphafold3.constants import chemical_components

# Assumed to be prepared earlier: flat_layout (an atom_layout.AtomLayout
# holding every atom to predict) and fold_input (a folding_input.Input).

# Define padding shapes
padding_shapes = features.PaddingShapes(
    num_tokens=512,
    msa_size=1024,
    num_chains=10,
    num_templates=4,
    num_atoms=128,
)

# Tokenize input structure
all_tokens, all_token_atoms, standard_idxs = features.tokenizer(
    flat_output_layout=flat_layout,
    ccd=chemical_components.cached_ccd(),
    max_atoms_per_token=128,
    flatten_non_standard_residues=True,
    logging_name="7A4D",
)

# Compute MSA features
msa_features = features.MSA.compute_features(
    all_tokens=all_tokens,
    standard_token_idxs=standard_idxs,
    padding_shapes=padding_shapes,
    fold_input=fold_input,
    logging_name="7A4D",
    max_paired_sequence_per_species=50,
)

# Compute token features
token_features = features.TokenFeatures.compute_features(
    all_tokens=all_tokens,
    padding_shapes=padding_shapes,
)

# Create batch dictionary
batch = {
    **msa_features.as_data_dict(),
    **token_features.as_data_dict(),
}
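Merging the per-class dictionaries with {**a, **b} silently keeps the last value when two feature classes emit the same key. If that should be an error, a small helper can check for collisions before merging. merge_batches is hypothetical and not part of the module.

```python
def merge_batches(*dicts):
    """Merge feature dicts, raising if any key appears more than once."""
    merged = {}
    for d in dicts:
        overlap = merged.keys() & d.keys()
        if overlap:
            raise KeyError(f"duplicate feature keys: {sorted(overlap)}")
        merged.update(d)
    return merged


print(merge_batches({"msa": 1}, {"aatype": 2}))  # {'msa': 1, 'aatype': 2}
```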
