Overview

The sampling module provides a standalone script for generating text from trained GPT models. It supports loading models from checkpoints or using pretrained GPT-2 variants, with full control over generation parameters.

Basic usage

python sample.py --init_from=resume --start="Once upon a time"

Configuration parameters

Model initialization

init_from
str
default:"'resume'"
Model source:
  • 'resume': Load from checkpoint in out_dir
  • 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl': Use pretrained GPT-2 models
out_dir
str
default:"'out'"
Directory containing checkpoint file (only used when init_from='resume')

Generation settings

start
str
default:"'\n'"
Prompt text to condition generation on. Can be:
  • Direct text: "Once upon a time"
  • Special token: "<|endoftext|>"
  • File reference: "FILE:prompt.txt" (reads prompt from file)
num_samples
int
default:"10"
Number of independent samples to generate
max_new_tokens
int
default:"500"
Maximum number of tokens to generate for each sample

Sampling parameters

temperature
float
default:"0.8"
Controls randomness in sampling:
  • 1.0: No change to the probability distribution
  • < 1.0: Less random, more confident predictions (sharper distribution)
  • > 1.0: More random, more diverse predictions (flatter distribution)
  • Approaching 0.0: Deterministic (argmax)
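
The effect of temperature on the sampling distribution can be sketched with a few hypothetical logits (the values below are made up for illustration):

```python
import torch

# Illustrative logits for a tiny 3-token vocabulary
logits = torch.tensor([2.0, 1.0, 0.0])

for t in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / t, dim=-1)
    print(f"temperature={t}: {probs.tolist()}")
```

Dividing by a temperature below 1.0 stretches the gaps between logits, so the softmax concentrates on the top token; a temperature above 1.0 compresses them, flattening the distribution.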
top_k
int
default:"200"
Restricts sampling to the top-k most likely tokens. Tokens outside the top-k have their probability set to zero. Set to None to disable.
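
The filtering step can be sketched as follows, using the common masking trick of setting everything below the k-th largest logit to negative infinity (the random logits here are purely illustrative):

```python
import torch

logits = torch.randn(1, 1000)  # hypothetical logits over a 1000-token vocab
top_k = 200

# Keep only the top_k largest logits; mask the rest before softmax
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf')
probs = torch.softmax(logits, dim=-1)

print(int((probs > 0).sum()))  # 200 tokens remain sampleable
```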

System settings

seed
int
default:"1337"
Random seed for reproducible generation
device
str
default:"'cuda'"
Device to run inference on: 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'mps', etc.
dtype
str
default:"'bfloat16' or 'float16'"
Data type for inference: 'float32', 'bfloat16', or 'float16'. Auto-selects bfloat16 if GPU supports it.
compile
bool
default:"False"
Use PyTorch 2.0 compilation for faster inference (requires PyTorch >= 2.0)

Text encoding/decoding

The script automatically handles text encoding:

From checkpoint

When loading from a trained checkpoint, it looks for meta.pkl in the dataset directory:
import os, pickle

if init_from == 'resume' and 'dataset' in checkpoint['config']:
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    if os.path.exists(meta_path):
        # Load the dataset's character-level vocabulary
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])

From GPT-2 pretrained

When using GPT-2 variants, it uses tiktoken for GPT-2 BPE encoding:
import tiktoken
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)
The special token <|endoftext|> is recognized by the GPT-2 tokenizer and can be used to signal document boundaries.

Generation flow

The complete generation process:

1. Model setup

# Set random seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Load model
if init_from == 'resume':
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = GPT(GPTConfig(**checkpoint['model_args']))
    model.load_state_dict(checkpoint['model'])
elif init_from.startswith('gpt2'):
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)

2. Prompt encoding

# Handle file-based prompts
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()

# Encode prompt
start_ids = encode(start)
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

3. Generation loop

with torch.no_grad():
    with ctx:  # Automatic mixed precision context
        for k in range(num_samples):
            y = model.generate(x, max_new_tokens, 
                             temperature=temperature, 
                             top_k=top_k)
            print(decode(y[0].tolist()))
            print('---------------')

Examples

Creative story generation

python sample.py \
  --init_from=gpt2-medium \
  --start="In a distant galaxy" \
  --num_samples=3 \
  --max_new_tokens=300 \
  --temperature=0.9 \
  --top_k=40

Code completion

python sample.py \
  --init_from=resume \
  --start="def fibonacci(n):" \
  --num_samples=5 \
  --max_new_tokens=150 \
  --temperature=0.7 \
  --top_k=50

Deterministic generation

python sample.py \
  --init_from=gpt2 \
  --start="The capital of France is" \
  --temperature=0.1 \
  --top_k=1 \
  --num_samples=1

Prompt from a file

Create prompts.txt:
Once upon a time in a land far away
Run:
python sample.py \
  --start="FILE:prompts.txt" \
  --num_samples=10 \
  --max_new_tokens=500

Performance considerations

  • Compilation: enable with --compile=True for repeated sampling. The first run is slower due to compilation overhead, but subsequent generations are faster.
  • Precision: use bfloat16 on supported GPUs. It provides better numerical stability than float16 during generation.
  • Batch generation: the model's generate() method supports batch inference. You can pass multiple prompts simultaneously by stacking them along the batch dimension.
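
Stacking prompts can be sketched as below (the token ids are made up for illustration; real ids come from the tokenizer, and unequal-length prompts would need padding):

```python
import torch

# Two equal-length prompts stacked along the batch dimension
prompt_a = torch.tensor([101, 102, 103, 104], dtype=torch.long)
prompt_b = torch.tensor([201, 202, 203, 204], dtype=torch.long)
x = torch.stack([prompt_a, prompt_b])  # shape (2, 4)

print(tuple(x.shape))
# y = model.generate(x, max_new_tokens=50, temperature=0.8, top_k=200)
# y would have shape (2, 54): each row is one prompt plus its continuation
```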

Troubleshooting

  • Out of memory errors: reduce max_new_tokens or use a smaller model variant. Generation requires keeping the full context in memory.
  • Repetitive text: try increasing temperature (e.g., to 1.0) or raising top_k to allow more diverse tokens.
  • Incoherent text: decrease temperature (e.g., to 0.7) or lower top_k to restrict sampling to the most likely tokens.

Programmatic usage

You can also use the generation functionality directly in your code:
import torch
from model import GPT

# Load model
model = GPT.from_pretrained('gpt2')
model.eval()
model.to('cuda')

# Encode prompt
import tiktoken
enc = tiktoken.get_encoding('gpt2')
prompt_ids = enc.encode("Once upon a time")
x = torch.tensor(prompt_ids, dtype=torch.long, device='cuda')[None, ...]

# Generate
with torch.no_grad():
    y = model.generate(x, max_new_tokens=100, temperature=0.8, top_k=200)
    
# Decode
generated_text = enc.decode(y[0].tolist())
print(generated_text)
For more control over generation, you can implement your own sampling loop using the model’s forward() method instead of generate().
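
A minimal sketch of such a custom loop is shown below. To keep the example self-contained, a stub model with random logits stands in for the real GPT, whose forward() similarly returns a (logits, loss) pair:

```python
import torch

# Stub standing in for the GPT model: forward returns (logits, loss)
class StubGPT(torch.nn.Module):
    def forward(self, idx):
        return torch.randn(idx.size(0), idx.size(1), 100), None

torch.manual_seed(0)
model = StubGPT()
x = torch.zeros(1, 4, dtype=torch.long)  # placeholder 4-token prompt
temperature, top_k, max_new_tokens = 0.8, 10, 16

with torch.no_grad():
    for _ in range(max_new_tokens):
        logits, _ = model(x)
        logits = logits[:, -1, :] / temperature      # only the last position is sampled
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = -float('inf')  # top-k filtering
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        x = torch.cat([x, next_id], dim=1)           # append and continue

print(tuple(x.shape))  # (1, 20): 4 prompt tokens + 16 generated
```

This is essentially what generate() does internally; writing it yourself lets you add custom stopping criteria, repetition penalties, or alternative sampling schemes.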
