
Overview

The run_alphafold.py script is the main entry point for running AlphaFold 3 structure predictions. It orchestrates the complete prediction pipeline including data processing, MSA generation, template search, and model inference.

Main Functions

make_model_config

Creates a model configuration with customizable parameters.
def make_model_config(
    *,
    flash_attention_implementation: tokamax.DotProductAttentionImplementation = 'triton',
    num_diffusion_samples: int = 5,
    num_recycles: int = 10,
    return_embeddings: bool = False,
    return_distogram: bool = False,
) -> model.Model.Config
  • flash_attention_implementation (tokamax.DotProductAttentionImplementation, default: 'triton'): Flash attention implementation to use. Options: 'triton', 'cudnn', or 'xla'. Triton is fastest but requires Ampere GPUs or later.
  • num_diffusion_samples (int, default: 5): Number of diffusion samples to generate per seed.
  • num_recycles (int, default: 10): Number of recycling iterations during inference.
  • return_embeddings (bool, default: False): Whether to return the final trunk single and pair embeddings. Embeddings are large float16 arrays: num_tokens * 384 + num_tokens * num_tokens * 128 elements.
  • return_distogram (bool, default: False): Whether to return the final distogram. The distogram is a large float16 array: num_tokens * num_tokens * 64 elements.

Returns model.Model.Config: a model configuration ready to be passed to ModelRunner.
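The embedding and distogram sizes quoted above translate directly into memory cost. A quick sketch of the arithmetic (float16 means 2 bytes per element; the helper names are illustrative, not part of the library):

```python
def embedding_bytes(num_tokens: int) -> int:
    # Single embeddings: num_tokens * 384 elements;
    # pair embeddings: num_tokens^2 * 128 elements; float16 = 2 bytes each.
    return 2 * (num_tokens * 384 + num_tokens * num_tokens * 128)


def distogram_bytes(num_tokens: int) -> int:
    # Distogram: num_tokens^2 * 64 elements, float16.
    return 2 * num_tokens * num_tokens * 64


# For a 1000-token input:
print(embedding_bytes(1000) / 1e6)  # 256.768 (MB)
print(distogram_bytes(1000) / 1e6)  # 128.0 (MB)
```

Pair terms dominate, so memory grows quadratically with token count; enable these outputs only when you need them.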

predict_structure

Runs the full inference pipeline to predict structures for each seed.
def predict_structure(
    fold_input: folding_input.Input,
    model_runner: ModelRunner,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
) -> Sequence[ResultsForSeed]
  • fold_input (folding_input.Input, required): The input containing chains, sequences, MSAs, and templates.
  • model_runner (ModelRunner, required): The model runner instance for executing predictions.
  • buckets (Sequence[int] | None, default: None): Token bucket sizes for compilation caching. If None, an appropriate bucket is calculated from the token count.
  • ref_max_modified_date (datetime.date | None, default: None): Maximum date for using CCD model coordinates as a fallback.
  • conformer_max_iterations (int | None, default: None): Maximum number of iterations for the RDKit conformer search.
  • resolve_msa_overlaps (bool, default: True): Whether to deduplicate the unpaired MSA against the paired MSA.

Returns Sequence[ResultsForSeed]: a list of results, one per seed, each containing the inference results and the full fold input.

process_fold_input

Runs data pipeline and/or inference on a single fold input.
def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    *,
    model_runner: ModelRunner | None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
    compress_large_output_files: bool = False,
) -> folding_input.Input | Sequence[ResultsForSeed]
  • fold_input (folding_input.Input, required): Fold input to process.
  • data_pipeline_config (pipeline.DataPipelineConfig | None, required): Data pipeline config to use. If None, the data pipeline is skipped.
  • model_runner (ModelRunner | None, required): Model runner to use. If None, inference is skipped.
  • output_dir (os.PathLike[str] | str, required): Output directory to write results to.
  • buckets, ref_max_modified_date, conformer_max_iterations, resolve_msa_overlaps: As in predict_structure.
  • force_output_dir (bool, default: False): If True, use the existing output directory even if it is non-empty. If False, create a timestamped directory instead.
  • compress_large_output_files (bool, default: False): If True, compress large output files (mmCIF and confidences JSON) using zstandard.

Returns folding_input.Input | Sequence[ResultsForSeed]: the processed fold input when inference is skipped, otherwise the per-seed results.
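The force_output_dir behaviour described above amounts to a small directory-selection rule. A minimal sketch of that rule (resolve_output_dir is a hypothetical helper written for illustration, not the library's implementation):

```python
import datetime
import pathlib


def resolve_output_dir(output_dir: str, force: bool) -> pathlib.Path:
    """Illustrative sketch: reuse the directory if forced, missing, or
    empty; otherwise fall back to a timestamped sibling directory."""
    path = pathlib.Path(output_dir)
    if force or not path.exists() or not any(path.iterdir()):
        return path
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    return path.with_name(f'{path.name}_{stamp}')
```

The timestamped fallback prevents a second run from silently overwriting results already written to the same directory.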

ModelRunner Class

Helper class to run structure prediction stages.

Constructor

def __init__(
    self,
    config: model.Model.Config,
    device: jax.Device,
    model_dir: pathlib.Path,
)
  • config (model.Model.Config, required): Model configuration.
  • device (jax.Device, required): JAX device to run inference on (e.g., a GPU).
  • model_dir (pathlib.Path, required): Path to the directory containing the model parameters.

Methods

run_inference

def run_inference(
    self, 
    featurised_example: features.BatchDict, 
    rng_key: jnp.ndarray
) -> model.ModelResult
Computes a forward pass of the model on a featurised example.

extract_inference_results

def extract_inference_results(
    self,
    batch: features.BatchDict,
    result: model.ModelResult,
    target_name: str,
) -> list[model.InferenceResult]
Extracts inference results from model outputs.

extract_embeddings

def extract_embeddings(
    self, 
    result: model.ModelResult, 
    num_tokens: int
) -> dict[str, np.ndarray] | None
Extracts single and pair embeddings from model outputs.

extract_distogram

def extract_distogram(
    self, 
    result: model.ModelResult, 
    num_tokens: int
) -> np.ndarray | None
Extracts distogram from model outputs.
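Assuming the (num_tokens, num_tokens, 64) float16 shape documented for make_model_config, a returned distogram can be reduced to a most-likely-bin map per token pair. The array below is a random stand-in for an extract_distogram result, and the mapping from bin index to physical distance is model specific and not shown:

```python
import numpy as np

num_tokens = 8
# Stand-in for an extract_distogram result: one histogram over 64
# distance bins for every token pair.
distogram = np.random.rand(num_tokens, num_tokens, 64).astype(np.float16)

# Most likely distance bin for every token pair.
bin_map = distogram.argmax(axis=-1)
print(bin_map.shape)  # (8, 8)
```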

ResultsForSeed

Dataclass storing inference results for a single seed.
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ResultsForSeed:
    seed: int
    inference_results: Sequence[model.InferenceResult]
    full_fold_input: folding_input.Input
    embeddings: dict[str, np.ndarray] | None = None
    distogram: np.ndarray | None = None
  • seed (int): The random seed used to generate the samples.
  • inference_results (Sequence[model.InferenceResult]): The inference results, one per diffusion sample.
  • full_fold_input (folding_input.Input): The fold input, including the MSA and templates produced by the data pipeline.
  • embeddings (dict[str, np.ndarray] | None): The final trunk single and pair embeddings, if requested.
  • distogram (np.ndarray | None): The token-pair distance histogram, if requested.

Command Line Flags

Input/Output

  • --json_path: Path to input JSON file
  • --input_dir: Path to directory containing input JSON files
  • --output_dir: Path to output directory (required)
  • --model_dir: Path to model directory (default: ~/models)

Pipeline Control

  • --run_data_pipeline: Whether to run data pipeline (default: True)
  • --run_inference: Whether to run inference (default: True)

Database Paths

  • --db_dir: Database directory path (can specify multiple)
  • --small_bfd_database_path: Small BFD database path
  • --mgnify_database_path: Mgnify database path
  • --uniref90_database_path: UniRef90 database path
  • --uniprot_cluster_annot_database_path: UniProt database path
  • --ntrna_database_path: NT-RNA database path
  • --rfam_database_path: Rfam database path
  • --rna_central_database_path: RNAcentral database path
  • --pdb_database_path: PDB mmCIF files directory
  • --seqres_database_path: PDB sequence database path

Performance Tuning

  • --num_recycles: Number of recycles (default: 10)
  • --num_diffusion_samples: Number of diffusion samples (default: 5)
  • --num_seeds: Number of seeds to generate
  • --gpu_device: GPU device index (default: 0)
  • --flash_attention_implementation: Flash attention type: triton, cudnn, or xla (default: triton)
  • --buckets: Token bucket sizes for compilation caching

Output Control

  • --save_embeddings: Save final embeddings (default: False)
  • --save_distogram: Save distogram (default: False)
  • --compress_large_output_files: Compress output files (default: False)
  • --force_output_dir: Use existing output directory (default: False)

Usage Example

from alphafold3.common import folding_input
import jax
import pathlib

# Create model configuration
model_config = make_model_config(
    num_diffusion_samples=5,
    num_recycles=10,
    flash_attention_implementation='triton',
)

# Initialize model runner
model_runner = ModelRunner(
    config=model_config,
    device=jax.devices('gpu')[0],
    model_dir=pathlib.Path('~/models').expanduser(),
)

# Load a fold input. load_fold_inputs_from_path yields one Input per
# entry in the JSON file; here we take the first.
fold_input = next(iter(folding_input.load_fold_inputs_from_path(
    pathlib.Path('input.json')
)))

# Run prediction
results = predict_structure(
    fold_input=fold_input,
    model_runner=model_runner,
    buckets=[256, 512, 1024],
)

Command Line Usage

python run_alphafold.py \
  --json_path=/path/to/input.json \
  --output_dir=/path/to/output \
  --model_dir=/path/to/models \
  --db_dir=/path/to/databases \
  --num_diffusion_samples=5 \
  --num_recycles=10
