Overview
Batch inference allows you to process multiple prompts simultaneously, significantly improving throughput. With flash attention enabled, Qwen can achieve up to 40% speedup with batch inference compared to sequential processing.
Prerequisites
Install Flash Attention
For optimal batch inference performance, install flash-attention: git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention && pip install .
Install Dependencies
Ensure you have the required packages: pip install transformers>=4.32.0
pip install torch>=2.0.0
Flash attention 2 is now supported and provides the best performance for batch inference.
Basic Batch Inference
Here’s a complete example of batch inference with Qwen:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids

# Batch generation for a causal LM needs distinct pad/eos tokens and
# LEFT padding so that real content is right-aligned.
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token='<|extra_0|>',
    eos_token='<|endoftext|>',
    padding_side='left',
    trust_remote_code=True,
)

# Both the model and its generation config must know the pad token id.
model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token_id=tokenizer.pad_token_id,
    device_map="auto",
    trust_remote_code=True,
).eval()

model.generation_config = GenerationConfig.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token_id=tokenizer.pad_token_id,
)

# Queries to answer in a single forward batch.
all_raw_text = [
    "我想听你说爱我。",
    "今天我想吃点啥,甜甜的,推荐下",
    "我马上迟到了,怎么做才能不迟到",
]

# Wrap every query in the ChatML conversation template; make_context
# returns (raw_text, context_tokens) and only the text is needed here.
batch_raw_text = [
    make_context(
        tokenizer,
        query,
        system="You are a helpful assistant.",
        max_window_size=model.generation_config.max_window_size,
        chat_format=model.generation_config.chat_format,
    )[0]
    for query in all_raw_text
]

# Pad all prompts to the longest one, then move onto the model device.
encoded = tokenizer(batch_raw_text, padding='longest')
batch_input_ids = torch.LongTensor(encoded['input_ids']).to(model.device)

batch_out_ids = model.generate(
    batch_input_ids,
    return_dict_in_generate=False,
    generation_config=model.generation_config,
)

# Count left-padding tokens per row so they can be stripped before decoding.
padding_lens = [
    row.eq(tokenizer.pad_token_id).sum().item() for row in batch_input_ids
]

batch_response = [
    decode_tokens(
        out_row[pad_len:],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=(in_row.size(0) - pad_len),
        chat_format="chatml",
        verbose=False,
        errors='replace',
    )
    for out_row, in_row, pad_len, raw_text in zip(
        batch_out_ids, batch_input_ids, padding_lens, batch_raw_text
    )
]
print(batch_response)
Key Configuration Details
For batch inference, you must configure the tokenizer with distinct pad and eos tokens: tokenizer = AutoTokenizer.from_pretrained(
'Qwen/Qwen-7B-Chat' ,
pad_token = '<|extra_0|>' , # Special padding token
eos_token = '<|endoftext|>' , # End of sequence token
padding_side = 'left' , # Left padding for causal LM
trust_remote_code = True
)
Why left padding? Causal language models generate from left to right. Left padding ensures that the actual content is right-aligned, which is crucial for proper attention mask generation.
Pass the pad_token_id to the model and generation config: model = AutoModelForCausalLM.from_pretrained(
'Qwen/Qwen-7B-Chat' ,
pad_token_id = tokenizer.pad_token_id, # Important!
device_map = "auto" ,
trust_remote_code = True
).eval()
model.generation_config = GenerationConfig.from_pretrained(
'Qwen/Qwen-7B-Chat' ,
pad_token_id = tokenizer.pad_token_id # Important!
)
Use make_context to format each query properly: from qwen_generation_utils import make_context
batch_raw_text = []
for query in all_queries:
raw_text, _ = make_context(
tokenizer,
query,
system = "You are a helpful assistant." ,
max_window_size = model.generation_config.max_window_size,
chat_format = model.generation_config.chat_format,
)
batch_raw_text.append(raw_text)
Batch Size Selection
Choose batch size based on your GPU memory and sequence length:
# Example for different GPU memory sizes
# 16GB GPU (e.g., V100)
batch_size = 4
max_sequence_length = 2048
# 24GB GPU (e.g., RTX 3090)
batch_size = 8
max_sequence_length = 2048
# 40GB GPU (e.g., A100)
batch_size = 16
max_sequence_length = 2048
# 80GB GPU (e.g., A100 80GB)
batch_size = 32
max_sequence_length = 2048
Dynamic Batching
Process queries in batches dynamically:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from qwen_generation_utils import make_context, decode_tokens
def process_in_batches(queries, model, tokenizer, batch_size=8):
    """
    Answer *queries* with *model*, grouped into batches of *batch_size*.

    Returns one decoded response string per query, in input order.
    Assumes the tokenizer was created with left padding and a distinct
    pad token (see the setup section above).
    """
    all_responses = []
    for start in range(0, len(queries), batch_size):
        chunk = queries[start:start + batch_size]

        # Wrap each query in the chat template expected by the model.
        raw_texts = [
            make_context(
                tokenizer,
                query,
                system="You are a helpful assistant.",
                max_window_size=model.generation_config.max_window_size,
                chat_format=model.generation_config.chat_format,
            )[0]
            for query in chunk
        ]

        # Left-pad to the longest prompt and move onto the model device.
        encoded = tokenizer(raw_texts, padding='longest')
        input_ids = torch.LongTensor(encoded['input_ids']).to(model.device)

        out_ids = model.generate(
            input_ids,
            return_dict_in_generate=False,
            generation_config=model.generation_config,
        )

        # Per-row padding counts, needed to strip pads before decoding.
        pad_lens = [
            row.eq(tokenizer.pad_token_id).sum().item() for row in input_ids
        ]

        all_responses.extend(
            decode_tokens(
                out_ids[j][pad_lens[j]:],
                tokenizer,
                raw_text_len=len(raw_texts[j]),
                context_length=(input_ids[j].size(0) - pad_lens[j]),
                chat_format="chatml",
                verbose=False,
                errors='replace',
            )
            for j in range(len(chunk))
        )
    return all_responses
# Usage: build the tokenizer/model pair once, then run the batched helper.
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token='<|extra_0|>',
    eos_token='<|endoftext|>',
    padding_side='left',
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token_id=tokenizer.pad_token_id,
    device_map="auto",
    trust_remote_code=True,
).eval()

queries = [
    "What is machine learning?",
    "Explain quantum computing",
    "How does photosynthesis work?",
    # ... many more queries
]

responses = process_in_batches(queries, model, tokenizer, batch_size=8)
for q, r in zip(queries, responses):
    print(f"Q: {q}\nA: {r}\n")
Comparing Single vs Batch
Let’s compare single and batch inference:
Single Inference
Batch Inference
# Sequential baseline: one model.chat call per query, no batching.
responses = [
    model.chat(tokenizer, query, history=None)[0]
    for query in all_queries
]
# Time: ~5.0 seconds for 10 queries
Advanced: Mixed-Length Batching
Handle queries of very different lengths efficiently:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def smart_batching(queries, model, tokenizer, max_batch_tokens=8192):
    """
    Answer *queries* in length-sorted batches capped by total token count.

    Sorting by tokenized length groups prompts of similar size, which
    minimises padding waste. Responses are returned in the original
    query order regardless of batch assignment.
    """
    # Only the token counts matter for packing, not the tokens themselves.
    lengths = [len(tokenizer.encode(q)) for q in queries]
    order = sorted(range(len(queries)), key=lengths.__getitem__)

    # Greedily pack sorted indices into batches whose combined token
    # count stays at or below max_batch_tokens; an oversized single
    # query still gets its own batch.
    batches = []
    current, current_tokens = [], 0
    for idx in order:
        if current and current_tokens + lengths[idx] > max_batch_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += lengths[idx]
    if current:
        batches.append(current)

    # Run each batch, then scatter answers back to their input slots.
    responses = [None] * len(queries)
    for indices in batches:
        answers = process_in_batches(
            [queries[i] for i in indices],
            model,
            tokenizer,
            batch_size=len(indices),
        )
        for i, answer in zip(indices, answers):
            responses[i] = answer
    return responses
Memory Management
Monitor and optimize memory usage:
import torch
import gc
def batch_inference_with_memory_management(queries, model, tokenizer, batch_size=8):
    """
    Batch inference that releases memory between batches.

    Produces the same responses as process_in_batches, but empties the
    CUDA cache and runs the garbage collector after every batch to keep
    peak memory bounded on long query lists.
    """
    all_responses = []
    for start in range(0, len(queries), batch_size):
        chunk = queries[start:start + batch_size]

        all_responses.extend(
            process_in_batches(chunk, model, tokenizer, batch_size=len(chunk))
        )

        # Free cached CUDA allocations and collect Python garbage
        # between batches so peak usage does not accumulate.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        # Report current GPU allocation for monitoring.
        if torch.cuda.is_available():
            print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    return all_responses
Typical speedups with batch inference (with flash attention):
| Batch Size | Speedup vs Sequential | GPU Memory (7B) |
|------------|-----------------------|-----------------|
| 1 (baseline) | 1.0x | 16GB |
| 4 | 1.3x | 18GB |
| 8 | 1.4x | 22GB |
| 16 | 1.4x | 28GB |
| 32 | 1.4x | 40GB |
Maximum speedup of ~40% is typically achieved with batch sizes of 8-16. Larger batches may not provide additional speedup due to GPU saturation.
Troubleshooting
Reduce batch size or sequence length: # Reduce batch size
batch_size = 4 # Instead of 8
# Or reduce max tokens
model.generation_config.max_new_tokens = 256 # Instead of 512
Ensure you’re using left padding: tokenizer = AutoTokenizer.from_pretrained(
'Qwen/Qwen-7B-Chat' ,
padding_side = 'left' , # Critical for causal LM!
trust_remote_code = True
)
Check that flash attention is installed: import torch
# Probe for the flash_attn package without assuming it is installed.
try:
    import flash_attn  # noqa: F401
except ImportError:
    print("Flash attention not installed")
    print("Install: pip install flash-attn")
else:
    print("Flash attention available")
Next Steps
Streaming Responses Stream tokens as they’re generated
Multi-GPU Inference Scale across multiple GPUs
vLLM Integration Production-grade serving with vLLM
Quantization Reduce memory with quantization