
Overview

Llama 2 uses model parallelism to shard large models across multiple GPUs. Each model size requires a specific model-parallel (MP) value, which must be set correctly when running inference.

Model Parallel Values

The MP value determines how many GPUs the model is sharded across:
Model   MP Value   Description
7B      1          Runs on a single GPU
13B     2          Sharded across 2 GPUs
70B     8          Sharded across 8 GPUs

Setting Model Parallel

Command Line

Use the --nproc_per_node flag with torchrun to set the MP value:
# 7B model (MP=1)
torchrun --nproc_per_node 1 example_chat_completion.py \
    --ckpt_dir llama-2-7b-chat/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 512 --max_batch_size 6
# 13B model (MP=2)
torchrun --nproc_per_node 2 example_chat_completion.py \
    --ckpt_dir llama-2-13b-chat/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 512 --max_batch_size 6
# 70B model (MP=8)
torchrun --nproc_per_node 8 example_chat_completion.py \
    --ckpt_dir llama-2-70b-chat/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 512 --max_batch_size 6

Python API

The Llama.build() method accepts an optional model_parallel_size parameter:
from llama import Llama

generator = Llama.build(
    ckpt_dir="llama-2-13b-chat/",
    tokenizer_path="tokenizer.model",
    max_seq_len=512,
    max_batch_size=8,
    model_parallel_size=2,  # Set MP value explicitly
)
If model_parallel_size is not provided, the MP value is taken from the WORLD_SIZE environment variable (defaulting to 1):
# Automatically uses WORLD_SIZE from environment
generator = Llama.build(
    ckpt_dir="llama-2-13b-chat/",
    tokenizer_path="tokenizer.model",
    max_seq_len=512,
    max_batch_size=8,
)

How It Works

Checkpoint Files

Each model directory contains .pth checkpoint files corresponding to the MP value:
  • 7B models: 1 checkpoint file (consolidated.00.pth)
  • 13B models: 2 checkpoint files (consolidated.00.pth, consolidated.01.pth)
  • 70B models: 8 checkpoint files (consolidated.00.pth through consolidated.07.pth)
The number of checkpoint files must match the MP value you specify.
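Because the shard count and the MP value must agree, you can determine the right --nproc_per_node value by counting the shards before launching. A minimal sketch; infer_mp_value is a hypothetical helper, not part of the llama package:

```python
from pathlib import Path

def infer_mp_value(ckpt_dir):
    """Infer the required MP value by counting .pth shards in ckpt_dir."""
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert checkpoints, f"no checkpoint files found in {ckpt_dir}"
    return len(checkpoints)
```

For a 13B directory containing consolidated.00.pth and consolidated.01.pth, this returns 2, the value to pass to --nproc_per_node.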

Initialization

Model parallel initialization happens in the build() method:
if not model_parallel_is_initialized():
    if model_parallel_size is None:
        model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)
Each process loads its corresponding checkpoint based on its rank:
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert model_parallel_size == len(checkpoints), \
    f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
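The rank-to-shard mapping can be illustrated without any GPUs: sorting the shard names puts consolidated.00.pth first, so the process with model-parallel rank 0 loads it, rank 1 loads consolidated.01.pth, and so on. A sketch for a 13B (MP=2) directory:

```python
# Shard names as they would appear in a 13B (MP=2) checkpoint directory
shards = sorted(["consolidated.01.pth", "consolidated.00.pth"])

for rank, ckpt_path in enumerate(shards):
    print(f"model-parallel rank {rank} loads {ckpt_path}")
```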

Memory Considerations

Sequence Length and Batch Size

All models support sequences up to 4096 tokens, but memory is pre-allocated based on max_seq_len and max_batch_size:
generator = Llama.build(
    ckpt_dir="llama-2-7b-chat/",
    tokenizer_path="tokenizer.model",
    max_seq_len=512,      # Reduce for lower memory usage
    max_batch_size=4,     # Reduce for lower memory usage
)
Set these parameters according to your hardware capabilities:
  • Higher values: More memory usage, can process longer sequences and larger batches
  • Lower values: Less memory usage, suitable for limited GPU memory
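The pre-allocated memory is dominated by the key/value cache, which grows linearly with both parameters. A rough back-of-the-envelope estimate, assuming fp16 and the 7B model's 32 layers and hidden dimension of 4096 (treat this as an illustration, not an exact accounting of the implementation's allocations):

```python
def kv_cache_bytes(n_layers, dim, max_batch_size, max_seq_len, bytes_per_elem=2):
    """Approximate KV-cache size: keys + values for every layer and position."""
    return 2 * n_layers * dim * max_batch_size * max_seq_len * bytes_per_elem

# Llama 2 7B (32 layers, dim 4096) with max_seq_len=512, max_batch_size=4:
gib = kv_cache_bytes(32, 4096, max_batch_size=4, max_seq_len=512) / 2**30
print(f"{gib:.1f} GiB")  # 1.0 GiB
```

Doubling either max_seq_len or max_batch_size doubles this figure, which is why reducing them is the first remedy for out-of-memory errors.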

Text Completion Example

torchrun --nproc_per_node 1 example_text_completion.py \
    --ckpt_dir llama-2-7b/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 128 --max_batch_size 4

Chat Completion Example

torchrun --nproc_per_node 1 example_chat_completion.py \
    --ckpt_dir llama-2-7b-chat/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 512 --max_batch_size 6

Distributed Setup

Llama 2 uses:
  • torch.distributed: For multi-GPU communication (NCCL backend)
  • fairscale.nn.model_parallel: For model parallel operations
The distributed process group is automatically initialized:
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group("nccl")
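torchrun exports the environment variables that init_process_group reads with the default env:// rendezvous (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). If you launch a single-process MP=1 run without torchrun, you can set them yourself; a sketch, where the values shown are typical defaults and MASTER_PORT can be any free port:

```python
import os

# torchrun normally exports these; set them manually for a single-process run
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
```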

Troubleshooting

Checkpoint Mismatch Error

AssertionError: Loading a checkpoint for MP=2 but world size is 1
Solution: Set --nproc_per_node to match the number of checkpoint files in your model directory.

Out of Memory Error

Solution: Reduce max_seq_len and max_batch_size parameters to lower memory usage.

No Checkpoint Files Found

AssertionError: no checkpoint files found in {ckpt_dir}
Solution: Verify that the ckpt_dir path is correct and contains .pth checkpoint files.
