Tensor parallelism and pipeline parallelism enable vLLM to run models that don’t fit on a single GPU by distributing the model across multiple GPUs and nodes.
Parallelism strategies
vLLM supports three distributed inference strategies:
Tensor parallelism (single-node multi-GPU)
Splits individual weight tensors across multiple GPUs. All GPUs work on the same batch of requests simultaneously.
When to use:
- Model fits on a single node but not a single GPU
- You have multiple GPUs with fast interconnect (NVLink)
- Low latency is critical
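Conceptually, tensor parallelism splits a linear layer's weight matrix by columns: each GPU computes a partial output, and the partials are concatenated (an all-gather). A minimal pure-Python sketch, with plain lists standing in for GPU shards:

```python
def matmul(x, w):
    # x: input feature vector; w: list of weight columns
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in w]

def column_parallel_matmul(x, w, num_gpus):
    # Split the weight columns into one shard per "GPU",
    # compute each partial output, then concatenate (all-gather).
    shard_size = len(w) // num_gpus
    shards = [w[i * shard_size:(i + 1) * shard_size] for i in range(num_gpus)]
    partial_outputs = [matmul(x, shard) for shard in shards]
    return [y for partial in partial_outputs for y in partial]

x = [1.0, 2.0]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # 4 output columns
assert column_parallel_matmul(x, w, num_gpus=2) == matmul(x, w)
```

Because every GPU needs the full input and the partial results must be exchanged every layer, this is why fast interconnect matters.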
Configuration:
from vllm import LLM
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,  # Use 4 GPUs
)
Pipeline parallelism (multi-node)
Splits the model into stages by layers, with each stage running on a different GPU or node. Requests flow through stages sequentially.
When to use:
- Model doesn’t fit on a single node
- You have multiple nodes
- The GPU count doesn't evenly divide the model for tensor parallelism
- GPUs lack NVLink (e.g., L40S)
Configuration:
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,    # 4 GPUs per node
    pipeline_parallel_size=2,  # 2 nodes
)
Expert parallelism (MoE models)
For Mixture-of-Experts models, distribute expert layers separately for better load balancing.
When to use:
- Running MoE models (Mixtral, DeepSeek, etc.)
- Want to optimize expert-level parallelism
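As a sketch (not a definitive recipe), expert parallelism is enabled alongside tensor parallelism via the enable_expert_parallel engine argument in recent vLLM releases; treat the model name and sizes here as illustrative:

```python
from vllm import LLM

# Illustrative sizing: adjust to your GPU count and MoE model
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    tensor_parallel_size=4,
    enable_expert_parallel=True,  # shard expert layers across the parallel ranks
)
```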
See Data Parallel Deployment for details.
Quick start
Single-node (tensor parallelism)
Offline inference:
from vllm import LLM, SamplingParams
# Model requires ~140GB (doesn't fit on 1x A100 80GB)
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=2,  # Use 2 GPUs
)
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=256,
)
outputs = llm.generate(["Explain tensor parallelism:"], sampling_params)
print(outputs[0].outputs[0].text)
API server:
vllm serve meta-llama/Llama-3.1-70B-Instruct \
--tensor-parallel-size 4
Multi-node (tensor + pipeline parallelism)
Using multiprocessing (simple setup):
On head node:
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--nnodes 2 \
--node-rank 0 \
--master-addr <HEAD_NODE_IP>
On worker node:
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--nnodes 2 \
--node-rank 1 \
--master-addr <HEAD_NODE_IP> \
--headless
Using Ray (production deployments):
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 4 \
--distributed-executor-backend ray
See Multi-node deployment for complete setup instructions.
Choosing parallelism strategy
Use this decision tree:
Does the model fit on a single GPU?
├─ Yes → No parallelism needed
└─ No → Does the model fit on a single node (multiple GPUs)?
    ├─ Yes → Use tensor_parallel_size = number of GPUs
    └─ No → Does the GPU count evenly divide the model?
        ├─ Yes → Use tensor_parallel_size = GPUs per node
        │        AND pipeline_parallel_size = number of nodes
        └─ No → Use pipeline parallelism with uneven splits:
                set tensor_parallel_size = 1
                set pipeline_parallel_size = number of GPUs
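The decision tree can be approximated by a small capacity-based helper (illustrative names; sizes in GB, ignoring KV cache headroom):

```python
def choose_parallelism(model_size_gb, gpu_mem_gb, gpus_per_node, num_nodes):
    """Return (tensor_parallel_size, pipeline_parallel_size) per the decision tree."""
    if model_size_gb <= gpu_mem_gb:
        return 1, 1  # fits on a single GPU: no parallelism needed
    if model_size_gb <= gpu_mem_gb * gpus_per_node:
        return gpus_per_node, 1  # fits on one node: tensor parallelism only
    if num_nodes > 1 and model_size_gb <= gpu_mem_gb * gpus_per_node * num_nodes:
        return gpus_per_node, num_nodes  # tensor within nodes, pipeline across
    # Fall back to pure pipeline parallelism with uneven splits
    return 1, gpus_per_node * num_nodes

# 70B model (~140GB) on one node of 4x 80GB GPUs -> tensor parallelism
assert choose_parallelism(140, 80, 4, 1) == (4, 1)
# 405B model (~810GB) on 2 nodes of 8x 80GB GPUs -> tensor + pipeline
assert choose_parallelism(810, 80, 8, 2) == (8, 2)
```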
Example configurations
Single node with 4x A100 80GB:
# For Llama-3.1-70B (~140GB): fits on 2x 80GB; use tensor_parallel_size=4 for extra KV cache
llm = LLM(model="meta-llama/Llama-3.1-70B-Instruct", tensor_parallel_size=2)
# For Llama-3.1-405B (~810GB) - won't fit, need multiple nodes
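The memory figures in these comments come from a rough rule of thumb: weight memory ≈ parameter count × bytes per parameter (2 bytes for FP16/BF16), before KV cache and activations. A quick sketch:

```python
def weight_memory_gb(num_params_billions, bytes_per_param=2):
    # FP16/BF16 weights only; KV cache and activations need extra headroom
    return num_params_billions * 1e9 * bytes_per_param / 1e9

assert weight_memory_gb(70) == 140.0   # ~140GB for Llama-3.1-70B
assert weight_memory_gb(405) == 810.0  # ~810GB for Llama-3.1-405B
```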
Two nodes with 8x A100 80GB each:
# For Llama-3.1-405B (~810GB)
llm = LLM(
    model="meta-llama/Llama-3.1-405B-Instruct",
    tensor_parallel_size=8,    # Use all 8 GPUs per node
    pipeline_parallel_size=2,  # Across 2 nodes
)
Single node with 8x L40S (no NVLink):
# Without NVLink, pipeline parallelism outperforms tensor parallelism
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=1,
    pipeline_parallel_size=8,  # Each GPU handles different layers
)
Memory and capacity planning
After configuring parallelism, check GPU memory usage:
vllm serve model --tensor-parallel-size 4
Look for log messages:
INFO: GPU KV cache size: 643,232 tokens
INFO: Maximum concurrency for 40,960 tokens per request: 15.70x
- GPU KV cache size: Total tokens that fit in GPU KV cache
- Maximum concurrency: How many requests can run simultaneously
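The reported maximum concurrency is simply the KV cache capacity divided by the per-request token budget; plugging in the example log values:

```python
kv_cache_tokens = 643_232          # from the "GPU KV cache size" log line
max_tokens_per_request = 40_960    # max_model_len in this example
max_concurrency = kv_cache_tokens / max_tokens_per_request
assert round(max_concurrency, 2) == 15.7  # matches the "15.70x" log line
```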
If these numbers are too low:
- Add more GPUs to increase tensor_parallel_size
- Add more nodes to increase pipeline_parallel_size
- Reduce gpu_memory_utilization if you have other memory needs
llm = LLM(
    model="large-model",
    tensor_parallel_size=8,
    gpu_memory_utilization=0.85,  # Reserve 15% of GPU memory for other uses
)
Multi-node deployment
vLLM supports two backends for multi-node deployment:
Ray (recommended for production)
Ray provides better fault tolerance, resource management, and scaling.
1. Set up Ray cluster using containers:
On head node:
bash examples/online_serving/run_cluster.sh \
vllm/vllm-openai \
<HEAD_NODE_IP> \
--head \
/path/to/huggingface/cache \
-e VLLM_HOST_IP=<HEAD_NODE_IP>
On each worker node:
bash examples/online_serving/run_cluster.sh \
vllm/vllm-openai \
<HEAD_NODE_IP> \
--worker \
/path/to/huggingface/cache \
-e VLLM_HOST_IP=<WORKER_NODE_IP>
2. Verify cluster:
# Inside any container
ray status
ray list nodes
3. Run vLLM:
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 4 \
--distributed-executor-backend ray
vLLM automatically discovers all GPUs in the Ray cluster. Set tensor_parallel_size to GPUs per node and pipeline_parallel_size to number of nodes.
Multiprocessing (simpler setup)
For quick testing without Ray overhead.
On head node (rank 0):
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--nnodes 2 \
--node-rank 0 \
--master-addr <HEAD_NODE_IP>
On worker node (rank 1):
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--nnodes 2 \
--node-rank 1 \
--master-addr <HEAD_NODE_IP> \
--headless
Network optimization
Tensor parallelism requires fast inter-GPU communication. Optimize for:
InfiniBand (recommended)
For multi-node tensor parallelism, use InfiniBand for best performance.
# Add to run_cluster.sh
--privileged -e NCCL_IB_HCA=mlx5
Contact your system administrator for the correct NCCL_IB_HCA value.
GPUDirect RDMA
Enable GPU-to-GPU communication over InfiniBand without CPU involvement.
Docker:
docker run --gpus all \
--ipc=host \
--shm-size=16G \
-v /dev/shm:/dev/shm \
vllm/vllm-openai
Kubernetes:
apiVersion: v1
kind: Pod
spec:
  containers:
  - name: vllm
    image: vllm/vllm-openai
    securityContext:
      capabilities:
        add: ["IPC_LOCK"]
    volumeMounts:
    - mountPath: /dev/shm
      name: dshm
    resources:
      limits:
        nvidia.com/gpu: 8
  volumes:
  - name: dshm
    emptyDir:
      medium: Memory
Verify GPUDirect RDMA:
NCCL_DEBUG=TRACE vllm serve model --tensor-parallel-size 8
Look for:
- ✅ [send] via NET/IB/GDRDMA - InfiniBand with GPUDirect (good)
- ❌ [send] via NET/Socket - TCP fallback (inefficient for tensor parallelism)
Configuration parameters
Tensor parallelism
tensor_parallel_size
Number of GPUs to split tensor weights across.
- Typically set to number of GPUs per node
- Requires fast interconnect (NVLink or InfiniBand)
Pipeline parallelism
pipeline_parallel_size
Number of pipeline stages (typically number of nodes).
- Model is split into stages by layers
- Requests flow through the stages in order, so multiple requests can be in flight in different stages at once
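How layers map to stages can be sketched as an even split, with any remainder going to the earliest stages (vLLM's actual partitioning may differ; this is illustrative):

```python
def partition_layers(num_layers, num_stages):
    # Distribute layer indices across pipeline stages as evenly as possible
    base, extra = divmod(num_layers, num_stages)
    sizes = [base + (1 if i < extra else 0) for i in range(num_stages)]
    stages, start = [], 0
    for size in sizes:
        stages.append(list(range(start, start + size)))
        start += size
    return stages

# 80 transformer layers (Llama-3.1-70B) across 8 pipeline stages
stages = partition_layers(80, 8)
assert [len(s) for s in stages] == [10] * 8
assert stages[0] == list(range(10))
```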
Distributed backend
distributed_executor_backend
Backend for distributed execution.
"ray": Use Ray (multi-node recommended)
"mp": Use multiprocessing (single-node default)
"auto": Automatically choose based on environment
Multi-node configuration
--nnodes
Total number of nodes (for multiprocessing backend).
--node-rank
Rank of the current node, starting from 0 (for multiprocessing backend).
--master-addr
IP address of the head node (for multiprocessing backend).
Complete examples
Example 1: Single-node tensor parallelism
from vllm import LLM, SamplingParams
# 4x A100 80GB
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.9,
)
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=512,
)
prompts = [
    "Explain how tensor parallelism works in vLLM:",
    "What are the benefits of distributed inference?",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Output: {output.outputs[0].text}")
    print("-" * 80)
Example 2: Multi-node pipeline parallelism
# 2 nodes, 8 GPUs each, for Llama-3.1-405B
# On head node (rank 0):
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--nnodes 2 \
--node-rank 0 \
--master-addr 192.168.1.100
# On worker node (rank 1):
vllm serve meta-llama/Llama-3.1-405B-Instruct \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--nnodes 2 \
--node-rank 1 \
--master-addr 192.168.1.100 \
--headless
# Make requests to head node:
curl http://192.168.1.100:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-3.1-405B-Instruct",
"prompt": "Explain distributed inference:",
"max_tokens": 256
}'
Example 3: Uneven GPU split with pipeline parallelism
# Single node, 6 GPUs (not evenly divisible)
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=1,    # No tensor parallelism
    pipeline_parallel_size=6,  # Split across 6 pipeline stages
)
Troubleshooting
Out of memory errors
Symptom: CUDA out of memory errors
Solutions:
- Increase tensor_parallel_size or pipeline_parallel_size
- Reduce gpu_memory_utilization (default 0.9)
- Reduce max_model_len
- Enable quantization (FP8, INT8, etc.)
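For the quantization option, a sketch using vLLM's quantization engine argument (FP8 shown; requires compatible hardware, and the sizing is illustrative):

```python
from vllm import LLM

# FP8 weights roughly halve memory versus FP16, easing OOM pressure
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=2,
    quantization="fp8",
)
```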
Performance problems
Symptom: No speedup or slower than single GPU
Possible causes:
- No NVLink between GPUs
- Network bottleneck in multi-node setup
- Small batch sizes (communication overhead dominates)
Solutions:
- Use pipeline parallelism instead if no NVLink
- Check InfiniBand/GPUDirect RDMA configuration
- Increase batch size or concurrent requests
Communication errors
Symptom: NCCL errors, timeout errors
Solutions:
- Verify all nodes can reach each other via IP
- Check firewall settings
- Ensure consistent vLLM versions across nodes
- Set NCCL_DEBUG=INFO for detailed logs
Different outputs across runs
Symptom: Same prompt produces different outputs
Cause: Batch size variations affecting numerical stability
Solution: Set explicit seed in SamplingParams
sampling_params = SamplingParams(temperature=0.8, seed=42)
See FAQ for more details.
Best practices
Network security
Distributed vLLM traffic is unencrypted and can be exploited for remote code execution. Always use private network segments and restrict access.
Set VLLM_HOST_IP to private network addresses:
-e VLLM_HOST_IP=10.0.1.10 # Private network
Pre-download models
Download models on all nodes before starting vLLM to avoid concurrent download issues.
# On all nodes:
python -c "from huggingface_hub import snapshot_download; \
snapshot_download('meta-llama/Llama-3.1-70B-Instruct')"
Or use a shared filesystem (NFS, Lustre) accessible by all nodes.
Container consistency
Use identical container images across all nodes:
docker pull vllm/vllm-openai:latest
# Verify same image digest on all nodes
Resource allocation
For Ray clusters on Kubernetes, use KubeRay for automated resource management.