## Overview
Llama 2 models use model parallelism to distribute large models across multiple GPUs. Different model sizes require different model-parallel (MP) values, which must be set correctly when running inference.

## Model Parallel Values
The MP value determines how many GPUs the model is sharded across:

| Model | MP Value | Description |
|---|---|---|
| 7B | 1 | Runs on a single GPU |
| 13B | 2 | Sharded across 2 GPUs |
| 70B | 8 | Sharded across 8 GPUs |
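As a quick sanity check, the table above can be expressed as a small helper. This is purely illustrative; `required_gpus` and `MP_VALUES` are not part of the Llama API:

```python
# Illustrative helper (not part of the Llama API): map a Llama 2
# model size to the model-parallel (MP) value its checkpoints expect.
MP_VALUES = {"7B": 1, "13B": 2, "70B": 8}

def required_gpus(model_size: str) -> int:
    """Return the number of GPUs (MP value) a Llama 2 checkpoint expects."""
    try:
        return MP_VALUES[model_size.upper()]
    except KeyError:
        raise ValueError(f"Unknown model size: {model_size!r}") from None

print(required_gpus("70b"))  # -> 8
```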
## Setting Model Parallel
### Command Line
Use the `--nproc_per_node` flag with `torchrun` to set the MP value:
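For example (a sketch: the checkpoint directory, tokenizer path, and script name are placeholders for your local setup):

```shell
# The MP value is set via --nproc_per_node and must match the model size.
# Paths below are placeholders for your local checkpoint/tokenizer locations.

# 7B model: MP = 1 (single GPU)
torchrun --nproc_per_node 1 example_text_completion.py \
    --ckpt_dir llama-2-7b/ \
    --tokenizer_path tokenizer.model

# 70B model: MP = 8 (eight GPUs)
torchrun --nproc_per_node 8 example_text_completion.py \
    --ckpt_dir llama-2-70b/ \
    --tokenizer_path tokenizer.model
```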
### Python API
The `Llama.build()` method accepts an optional `model_parallel_size` parameter:
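A sketch of such a call (the paths are placeholders, and this only runs inside a `torchrun` launch with the matching number of processes):

```python
# Sketch only: requires a torchrun launch, GPUs, and local checkpoints.
from llama import Llama

generator = Llama.build(
    ckpt_dir="llama-2-13b/",          # placeholder path
    tokenizer_path="tokenizer.model",  # placeholder path
    max_seq_len=512,
    max_batch_size=4,
    model_parallel_size=2,  # 13B checkpoints are sharded across 2 GPUs
)
```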
If `model_parallel_size` is not provided, it defaults to the `WORLD_SIZE` environment variable, which `torchrun` sets to the value of `--nproc_per_node`:
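A minimal sketch of that fallback logic (simplified from the actual `build()` implementation; the function name here is illustrative):

```python
import os

def resolve_model_parallel_size(model_parallel_size=None) -> int:
    """Sketch of the MP fallback: use WORLD_SIZE when no value is given.

    torchrun sets WORLD_SIZE to the value of --nproc_per_node, so by
    default the MP value matches the number of launched processes.
    """
    if model_parallel_size is not None:
        return model_parallel_size
    return int(os.environ.get("WORLD_SIZE", 1))

os.environ["WORLD_SIZE"] = "8"
print(resolve_model_parallel_size())   # -> 8
print(resolve_model_parallel_size(2))  # -> 2
```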
## How It Works
### Checkpoint Files
Each model directory contains `.pth` checkpoint files corresponding to the MP value:

- 7B models: 1 checkpoint file (`consolidated.00.pth`)
- 13B models: 2 checkpoint files (`consolidated.00.pth`, `consolidated.01.pth`)
- 70B models: 8 checkpoint files (`consolidated.00.pth` through `consolidated.07.pth`)
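Since each rank loads exactly one shard, the number of `.pth` files must equal the MP value. A hypothetical pre-flight check (the helper name is not part of the Llama API):

```python
from pathlib import Path

def checkpoint_count(ckpt_dir: str) -> int:
    """Count the .pth shards in a checkpoint directory.

    Loading fails unless this count equals the number of launched
    processes (the MP value), since each rank loads one shard.
    """
    return len(sorted(Path(ckpt_dir).glob("*.pth")))
```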
### Initialization
Model parallel initialization happens in the `build()` method:
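A simplified sketch of that sequence, assuming fairscale's `initialize_model_parallel` and the environment variables `torchrun` provides; this is not runnable outside a `torchrun` launch with GPUs:

```python
# Simplified sketch of the model-parallel setup performed in build().
# Only runs inside a torchrun launch (which sets WORLD_SIZE/LOCAL_RANK).
import os

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

def setup_model_parallel(model_parallel_size=None):
    # torch.distributed handles multi-GPU communication over NCCL.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")
    if model_parallel_size is None:
        model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    # fairscale sets up the model-parallel process groups.
    initialize_model_parallel(model_parallel_size)
    # Each rank drives one GPU, selected via torchrun's LOCAL_RANK.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```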
## Memory Considerations
### Sequence Length and Batch Size
All models support sequences up to 4096 tokens, but memory is pre-allocated based on `max_seq_len` and `max_batch_size`:
- Higher values: more memory usage, but support for longer sequences and larger batches
- Lower values: less memory usage, suitable for GPUs with limited memory
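To see why these two parameters dominate memory usage, note that the attention KV cache is pre-allocated at build time with shape proportional to both. A rough back-of-the-envelope estimate, using assumed 7B-like dimensions (32 layers, 32 heads, head dim 128, fp16):

```python
def kv_cache_bytes(n_layers, n_heads, head_dim, max_batch_size, max_seq_len,
                   bytes_per_elem=2):
    """Rough size of a pre-allocated KV cache (illustrative estimate).

    Two tensors (keys and values) of shape
    (max_batch_size, max_seq_len, n_heads, head_dim) per layer, in fp16.
    """
    per_layer = (2 * max_batch_size * max_seq_len
                 * n_heads * head_dim * bytes_per_elem)
    return n_layers * per_layer

# Assumed 7B-like dimensions: 32 layers, 32 heads, head dim 128.
gib = kv_cache_bytes(32, 32, 128, max_batch_size=4, max_seq_len=2048) / 2**30
print(f"{gib:.1f} GiB")  # -> 4.0 GiB
```

Halving `max_seq_len` or `max_batch_size` halves this allocation, which is why reducing them is the standard fix for out-of-memory errors.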
### Text Completion Example
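An illustrative invocation (paths are placeholders; the modest `max_seq_len` and `max_batch_size` values keep pre-allocated memory low):

```shell
# Illustrative text-completion run on the 7B model (MP = 1).
# --ckpt_dir and --tokenizer_path are placeholders for local paths.
torchrun --nproc_per_node 1 example_text_completion.py \
    --ckpt_dir llama-2-7b/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 128 --max_batch_size 4
```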
### Chat Completion Example
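An illustrative invocation (paths are placeholders; chat use cases typically need a longer `max_seq_len` than plain completion, and the fine-tuned `-chat` checkpoints):

```shell
# Illustrative chat-completion run on the 7B chat model (MP = 1).
# --ckpt_dir and --tokenizer_path are placeholders for local paths.
torchrun --nproc_per_node 1 example_chat_completion.py \
    --ckpt_dir llama-2-7b-chat/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 512 --max_batch_size 6
```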
## Distributed Setup
Llama 2 uses:

- `torch.distributed`: for multi-GPU communication (NCCL backend)
- `fairscale.nn.model_parallel`: for model parallel operations
## Troubleshooting
### Checkpoint Mismatch Error

Solution: Set `--nproc_per_node` to match the number of checkpoint files in your model directory.
### Out of Memory Error

Solution: Reduce the `max_seq_len` and `max_batch_size` parameters to lower memory usage.
### No Checkpoint Files Found

Solution: Verify that the `ckpt_dir` path is correct and contains `.pth` checkpoint files.