
Overview

Batch inference allows you to process multiple prompts simultaneously, significantly improving throughput. With flash attention enabled, Qwen can achieve up to 40% speedup with batch inference compared to sequential processing.

Prerequisites


Install Flash Attention

For optimal batch inference performance, install flash-attention:
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention && pip install .

Install Dependencies

Ensure you have the required packages:
pip install "transformers>=4.32.0"
pip install "torch>=2.0.0"
Flash attention 2 is now supported and provides the best performance for batch inference.
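Before moving on, you can sanity-check that the installed versions meet these minimums. This is a small sketch that compares dotted version strings numerically (a local suffix such as `+cu118` is ignored); it only reads package metadata and does not import torch or transformers:

```python
from importlib.metadata import version, PackageNotFoundError

def meets_minimum(installed: str, required: str) -> bool:
    """Compare dotted version strings numerically, e.g. '4.32.0' >= '4.9.1'."""
    def parts(v):
        # Drop local version labels like '+cu118', keep the first three numbers
        return tuple(int(p) for p in v.split("+")[0].split(".")[:3] if p.isdigit())
    return parts(installed) >= parts(required)

for pkg, minimum in [("transformers", "4.32.0"), ("torch", "2.0.0")]:
    try:
        installed = version(pkg)
        status = "OK" if meets_minimum(installed, minimum) else "too old"
        print(f"{pkg} {installed}: {status} (need >= {minimum})")
    except PackageNotFoundError:
        print(f"{pkg}: not installed (need >= {minimum})")
```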

Basic Batch Inference

Here’s a complete example of batch inference with Qwen:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids

# Configure tokenizer for batch inference
# Assign distinct token_ids to pad_token and eos_token
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token='<|extra_0|>',
    eos_token='<|endoftext|>',
    padding_side='left',
    trust_remote_code=True
)

# Load model with pad_token_id
model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token_id=tokenizer.pad_token_id,
    device_map="auto",
    trust_remote_code=True
).eval()

# Set generation config with pad_token_id
model.generation_config = GenerationConfig.from_pretrained(
    'Qwen/Qwen-7B-Chat', 
    pad_token_id=tokenizer.pad_token_id
)

# Prepare batch of queries
all_raw_text = [
    "我想听你说爱我。",               # "I want to hear you say you love me."
    "今天我想吃点啥,甜甜的,推荐下",   # "What should I eat today? Something sweet; any recommendations?"
    "我马上迟到了,怎么做才能不迟到"   # "I'm about to be late; what can I do to avoid being late?"
]

# Create context for each query
batch_raw_text = []
for q in all_raw_text:
    raw_text, _ = make_context(
        tokenizer,
        q,
        system="You are a helpful assistant.",
        max_window_size=model.generation_config.max_window_size,
        chat_format=model.generation_config.chat_format,
    )
    batch_raw_text.append(raw_text)

# Tokenize with padding
batch_input_ids = tokenizer(batch_raw_text, padding='longest')
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)

# Generate responses
batch_out_ids = model.generate(
    batch_input_ids,
    return_dict_in_generate=False,
    generation_config=model.generation_config
)

# Calculate padding lengths
padding_lens = [
    batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() 
    for i in range(batch_input_ids.size(0))
]

# Decode responses
batch_response = [
    decode_tokens(
        batch_out_ids[i][padding_lens[i]:],
        tokenizer,
        raw_text_len=len(batch_raw_text[i]),
        context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
        chat_format="chatml",
        verbose=False,
        errors='replace'
    ) for i in range(len(all_raw_text))
]

print(batch_response)

Key Configuration Details

For batch inference, you must configure the tokenizer with distinct pad and eos tokens:
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token='<|extra_0|>',      # Special padding token
    eos_token='<|endoftext|>',    # End of sequence token
    padding_side='left',           # Left padding for causal LM
    trust_remote_code=True
)
Why left padding? Causal language models generate new tokens to the right of the input. With left padding, each prompt's last real token sits at the right edge of the batch, so generation continues directly from it. With right padding, pad tokens would sit between the prompt and the generated tokens, corrupting generation.
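To see the difference concretely, here is a framework-free sketch (with a hypothetical pad id of 0) showing how the two padding sides align a batch of token id lists:

```python
PAD = 0  # hypothetical pad token id, for illustration only

def pad_batch(sequences, side="left"):
    """Pad variable-length token id lists to the longest sequence."""
    width = max(len(s) for s in sequences)
    if side == "left":
        return [[PAD] * (width - len(s)) + s for s in sequences]
    return [s + [PAD] * (width - len(s)) for s in sequences]

batch = [[11, 12, 13], [21, 22]]

# Left padding: every prompt ends at the rightmost position, so
# generation continues directly from each prompt's last token.
print(pad_batch(batch, side="left"))   # [[11, 12, 13], [0, 21, 22]]

# Right padding: a pad token would sit between the second prompt
# and anything generated after it.
print(pad_batch(batch, side="right"))  # [[11, 12, 13], [21, 22, 0]]
```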
Pass the pad_token_id to the model and generation config:
model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token_id=tokenizer.pad_token_id,  # Important!
    device_map="auto",
    trust_remote_code=True
).eval()

model.generation_config = GenerationConfig.from_pretrained(
    'Qwen/Qwen-7B-Chat', 
    pad_token_id=tokenizer.pad_token_id  # Important!
)
Use make_context to format each query properly:
from qwen_generation_utils import make_context

batch_raw_text = []
for query in all_queries:
    raw_text, _ = make_context(
        tokenizer,
        query,
        system="You are a helpful assistant.",
        max_window_size=model.generation_config.max_window_size,
        chat_format=model.generation_config.chat_format,
    )
    batch_raw_text.append(raw_text)

Performance Optimization

Batch Size Selection

Choose batch size based on your GPU memory and sequence length:
# Example for different GPU memory sizes

# 16GB GPU (e.g., V100)
batch_size = 4
max_sequence_length = 2048

# 24GB GPU (e.g., RTX 3090)
batch_size = 8
max_sequence_length = 2048

# 40GB GPU (e.g., A100)
batch_size = 16
max_sequence_length = 2048

# 80GB GPU (e.g., A100 80GB)
batch_size = 32
max_sequence_length = 2048
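These figures can be folded into a small helper that suggests a starting batch size from the detected GPU memory. The thresholds simply mirror the table above; treat them as a starting point to tune from, not a guarantee:

```python
# (minimum GPU memory in GB, suggested batch size), mirroring the table above
BATCH_SIZE_TABLE = [(80, 32), (40, 16), (24, 8), (16, 4)]

def suggest_batch_size(gpu_memory_gb: float) -> int:
    """Return a starting batch size for the given amount of GPU memory."""
    for min_gb, batch_size in BATCH_SIZE_TABLE:
        if gpu_memory_gb >= min_gb:
            return batch_size
    return 1  # fall back to unbatched inference on smaller GPUs

# Optional: detect memory of the first CUDA device, if torch is available
try:
    import torch
except ImportError:
    torch = None

if torch is not None and torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Detected {total_gb:.0f} GB -> batch_size={suggest_batch_size(total_gb)}")
```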

Dynamic Batching

Process queries in batches dynamically:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from qwen_generation_utils import make_context, decode_tokens

def process_in_batches(queries, model, tokenizer, batch_size=8):
    """
    Process queries in batches of specified size.
    """
    all_responses = []
    
    for i in range(0, len(queries), batch_size):
        batch_queries = queries[i:i + batch_size]
        
        # Prepare batch
        batch_raw_text = []
        for q in batch_queries:
            raw_text, _ = make_context(
                tokenizer,
                q,
                system="You are a helpful assistant.",
                max_window_size=model.generation_config.max_window_size,
                chat_format=model.generation_config.chat_format,
            )
            batch_raw_text.append(raw_text)
        
        # Tokenize
        batch_input_ids = tokenizer(batch_raw_text, padding='longest')
        batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
        
        # Generate
        batch_out_ids = model.generate(
            batch_input_ids,
            return_dict_in_generate=False,
            generation_config=model.generation_config
        )
        
        # Decode
        padding_lens = [
            batch_input_ids[j].eq(tokenizer.pad_token_id).sum().item() 
            for j in range(batch_input_ids.size(0))
        ]
        
        batch_response = [
            decode_tokens(
                batch_out_ids[j][padding_lens[j]:],
                tokenizer,
                raw_text_len=len(batch_raw_text[j]),
                context_length=(batch_input_ids[j].size(0)-padding_lens[j]),
                chat_format="chatml",
                verbose=False,
                errors='replace'
            ) for j in range(len(batch_queries))
        ]
        
        all_responses.extend(batch_response)
    
    return all_responses

# Usage
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token='<|extra_0|>',
    eos_token='<|endoftext|>',
    padding_side='left',
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    pad_token_id=tokenizer.pad_token_id,
    device_map="auto",
    trust_remote_code=True
).eval()

queries = [
    "What is machine learning?",
    "Explain quantum computing",
    "How does photosynthesis work?",
    # ... many more queries
]

responses = process_in_batches(queries, model, tokenizer, batch_size=8)
for q, r in zip(queries, responses):
    print(f"Q: {q}\nA: {r}\n")

Comparing Single vs Batch

Let’s compare the two approaches. As a sequential baseline, process the queries one at a time with model.chat:
# Process queries one by one
responses = []
for query in all_queries:
    response, _ = model.chat(tokenizer, query, history=None)
    responses.append(response)

# Time: ~5.0 seconds for 10 queries
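To time the batched side the same way, you can wrap both code paths in a small timing harness. The sketch below uses stub generate functions so the harness logic is clear and self-contained; for a real measurement, swap in model.chat for generate_one and the process_in_batches function from above for generate_batch:

```python
import time

def time_sequential(queries, generate_one):
    """Time one-at-a-time generation; generate_one(query) -> response."""
    start = time.perf_counter()
    responses = [generate_one(q) for q in queries]
    return responses, time.perf_counter() - start

def time_batched(queries, generate_batch, batch_size=8):
    """Time batched generation; generate_batch(list of queries) -> list of responses."""
    start = time.perf_counter()
    responses = []
    for i in range(0, len(queries), batch_size):
        responses.extend(generate_batch(queries[i:i + batch_size]))
    return responses, time.perf_counter() - start

# Stub generators for illustration only
queries = [f"query {i}" for i in range(10)]
seq, t_seq = time_sequential(queries, lambda q: f"answer to {q}")
bat, t_bat = time_batched(queries, lambda qs: [f"answer to {q}" for q in qs])

# Whatever the timings, both paths must return the same responses in order
assert seq == bat
print(f"sequential: {t_seq:.4f}s, batched: {t_bat:.4f}s")
```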

Advanced: Mixed-Length Batching

Handle queries of very different lengths efficiently:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def smart_batching(queries, model, tokenizer, max_batch_tokens=8192):
    """
    Group queries into batches based on total token count.
    """
    # Tokenize all queries
    tokenized = [tokenizer.encode(q) for q in queries]
    lengths = [len(t) for t in tokenized]
    
    # Sort by length
    sorted_indices = sorted(range(len(queries)), key=lambda i: lengths[i])
    
    batches = []
    current_batch = []
    current_tokens = 0
    
    for idx in sorted_indices:
        query_tokens = lengths[idx]
        
        # Check if adding this query exceeds batch token limit
        if current_tokens + query_tokens > max_batch_tokens and current_batch:
            batches.append(current_batch)
            current_batch = []
            current_tokens = 0
        
        current_batch.append(idx)
        current_tokens += query_tokens
    
    if current_batch:
        batches.append(current_batch)
    
    # Process each batch
    responses = [None] * len(queries)
    for batch_indices in batches:
        batch_queries = [queries[i] for i in batch_indices]
        batch_responses = process_in_batches(
            batch_queries, model, tokenizer, batch_size=len(batch_queries)
        )
        for i, response in zip(batch_indices, batch_responses):
            responses[i] = response
    
    return responses
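The grouping step can be checked in isolation with made-up token counts, with no model or tokenizer involved. This sketch factors out just the bucketing logic from smart_batching above:

```python
def group_by_token_budget(lengths, max_batch_tokens):
    """Group item indices (sorted by length) into batches whose total
    token count stays within max_batch_tokens."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, total = [], [], 0
    for idx in order:
        # Start a new batch once the budget would be exceeded
        if total + lengths[idx] > max_batch_tokens and current:
            batches.append(current)
            current, total = [], 0
        current.append(idx)
        total += lengths[idx]
    if current:
        batches.append(current)
    return batches

# Hypothetical per-query token counts
lengths = [120, 40, 900, 60, 300]
print(group_by_token_budget(lengths, max_batch_tokens=500))
# [[1, 3, 0], [4], [2]] -- short queries grouped, the long one alone
```

Note that an item longer than the budget (900 above) still forms its own batch rather than being dropped.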

Memory Management

Monitor and optimize memory usage:
import torch
import gc

def batch_inference_with_memory_management(queries, model, tokenizer, batch_size=8):
    """
    Batch inference with explicit memory management.
    """
    all_responses = []
    
    for i in range(0, len(queries), batch_size):
        batch_queries = queries[i:i + batch_size]
        
        # Process batch
        batch_responses = process_in_batches(
            batch_queries, model, tokenizer, batch_size=len(batch_queries)
        )
        all_responses.extend(batch_responses)
        
        # Clear cache after each batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        
        # Optional: Print memory usage
        if torch.cuda.is_available():
            print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    
    return all_responses

Performance Benchmarks

Typical speedups with batch inference (with flash attention):
Batch Size      Speedup vs Sequential   GPU Memory (7B)
1 (baseline)    1.0x                    16GB
4               1.3x                    18GB
8               1.4x                    22GB
16              1.4x                    28GB
32              1.4x                    40GB
Maximum speedup of ~40% is typically achieved with batch sizes of 8-16. Larger batches may not provide additional speedup due to GPU saturation.

Troubleshooting

If you hit out-of-memory errors, reduce the batch size or sequence length:
# Reduce batch size
batch_size = 4  # Instead of 8

# Or reduce max tokens
model.generation_config.max_new_tokens = 256  # Instead of 512
If batch outputs are garbled or truncated, ensure you’re using left padding:
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat',
    padding_side='left',  # Critical for causal LM!
    trust_remote_code=True
)
If throughput is lower than expected, check that flash attention is installed:
import torch

# Check flash attention
try:
    import flash_attn
    print("Flash attention available")
except ImportError:
    print("Flash attention not installed")
    print("Install: pip install flash-attn")

Next Steps

Streaming Responses

Stream tokens as they’re generated

Multi-GPU Inference

Scale across multiple GPUs

vLLM Integration

Production-grade serving with vLLM

Quantization

Reduce memory with quantization
