Text Generation

Llama 2 generates text autoregressively, one token at a time, using temperature scaling and nucleus (top-p) sampling to produce high-quality text. The generation process exposes parameters that control randomness (temperature, top_p), whether the prompt is echoed in the output (echo), and whether per-token log probabilities are returned (logprobs).

Core Generation Method

The generate method is the foundation for all text generation:
@torch.inference_mode()
def generate(
    self,
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    logprobs: bool = False,
    echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:

Parameters

  • prompt_tokens: List of tokenized prompts (batch of token sequences)
  • max_gen_len: Maximum number of tokens to generate
  • temperature: Controls randomness (higher = more random, lower = more deterministic)
  • top_p: Nucleus sampling threshold (0.0-1.0)
  • logprobs: Whether to return log probabilities for each token
  • echo: Whether to include prompt tokens in the output
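A minimal direct call might look like this (a sketch, assuming llama is an already-loaded Llama instance; most applications use the higher-level text_completion interface shown later):
prompt_tokens = [llama.tokenizer.encode("Hello, world", bos=True, eos=False)]
out_tokens, out_logprobs = llama.generate(
    prompt_tokens=prompt_tokens,
    max_gen_len=32,
    temperature=0.6,
    top_p=0.9,
    logprobs=True,
)
print(llama.tokenizer.decode(out_tokens[0]))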

Temperature Sampling

Temperature controls the randomness of predictions by scaling the logits before applying softmax:
if temperature > 0:
    probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
    next_token = sample_top_p(probs, top_p)
else:
    next_token = torch.argmax(logits[:, -1], dim=-1)
How temperature works:
  • temperature = 0: Greedy decoding (always pick the highest probability token)
  • temperature < 1 (e.g., 0.6): Sharper distribution, more focused and deterministic
  • temperature = 1: Use raw probabilities
  • temperature > 1: Flatter distribution, more random and creative
Example effect on probabilities:
# Original probabilities: [0.5, 0.3, 0.2]

# temperature = 0.5 (sharper)
# Result: [0.66, 0.24, 0.11]  # Favors top choice more

# temperature = 1.0
# Result: [0.5, 0.3, 0.2]  # Unchanged

# temperature = 2.0 (flatter)
# Result: [0.42, 0.32, 0.26]  # More uniform
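These numbers can be reproduced in a few lines of PyTorch (a standalone sketch; logits are recovered from the probabilities via log, which matches the true logits up to an additive constant that softmax ignores):
import torch

probs = torch.tensor([0.5, 0.3, 0.2])
logits = probs.log()
for t in (0.5, 1.0, 2.0):
    print(t, torch.softmax(logits / t, dim=-1))
# 0.5 -> [0.66, 0.24, 0.11]; 1.0 -> [0.50, 0.30, 0.20]; 2.0 -> [0.42, 0.32, 0.26]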

Top-p (Nucleus) Sampling

Nucleus sampling selects from the smallest set of tokens whose cumulative probability exceeds p:
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

How Nucleus Sampling Works

  1. Sort tokens by probability in descending order
  2. Compute cumulative sum of probabilities
  3. Create a mask for tokens whose preceding cumulative mass (cumulative sum minus the token's own probability) exceeds p
  4. Zero out probabilities of tokens outside the nucleus
  5. Renormalize the remaining probabilities
  6. Sample from the renormalized distribution
Example: Given token probabilities [0.4, 0.3, 0.2, 0.05, 0.03, 0.02] with top_p = 0.9:
  1. Probabilities are already sorted
  2. Cumulative: [0.4, 0.7, 0.9, 0.95, 0.98, 1.0]
  3. The nucleus is the smallest prefix whose cumulative mass strictly exceeds 0.9: the first 4 tokens (0.95), since the 3-token mass of exactly 0.9 does not exceed the threshold
  4. Renormalize: [0.42, 0.32, 0.21, 0.05, 0, 0]
  5. Sample from these 4 tokens only
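A quick empirical check using the sample_top_p function above (a sketch; the tallies vary with the seed):
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.4, 0.3, 0.2, 0.05, 0.03, 0.02]])
counts = torch.zeros(6)
for _ in range(10_000):
    counts[sample_top_p(probs, 0.9).item()] += 1
print(counts / counts.sum())
# The last two tokens are never drawn; the kept tokens appear in roughly
# the renormalized proportions [0.42, 0.32, 0.21, 0.05].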

Top-p vs Top-k

  • Top-p (nucleus): Dynamic vocabulary size based on probability distribution
  • Top-k: Fixed vocabulary size (always use top k tokens)
Llama 2 uses top-p, which adapts better to varying probability distributions.
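For contrast, a top-k counterpart might look like the following (an illustrative sketch only; this function is not part of the Llama 2 codebase):
def sample_top_k(probs, k):
    """Illustrative top-k sampling: always keep exactly the k most likely tokens."""
    topk_probs, topk_idx = torch.topk(probs, k, dim=-1)
    topk_probs /= topk_probs.sum(dim=-1, keepdim=True)  # renormalize the top k
    next_token = torch.multinomial(topk_probs, num_samples=1)
    return torch.gather(topk_idx, -1, next_token)
On a flat distribution a fixed k may cut off many plausible tokens, while on a peaked one it may include implausible ones; top-p sidesteps both cases.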

Log Probabilities

When logprobs=True, the model returns the log probability of each generated token, computed as the negative of the per-token cross-entropy loss:
if logprobs:
    token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
        input=logits.transpose(1, 2),
        target=tokens[:, prev_pos + 1 : cur_pos + 1],
        reduction="none",
        ignore_index=pad_id,
    )
Log probabilities are useful for:
  • Model confidence assessment: Log probabilities closer to 0 indicate higher confidence
  • Beam search: Ranking multiple generation paths
  • Fine-tuning: Computing training objectives
  • Debugging: Understanding model behavior
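For example, per-token log probabilities convert directly into probabilities and a sequence-level perplexity (the values below are hypothetical):
import math

logprobs = [-0.11, -2.30, -0.69]            # hypothetical per-token log probs
probs = [math.exp(lp) for lp in logprobs]   # [0.90, 0.10, 0.50]
avg_nll = -sum(logprobs) / len(logprobs)    # mean negative log-likelihood
perplexity = math.exp(avg_nll)              # ~2.8 for these values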

Echo Parameter

The echo parameter controls whether prompt tokens are included in the output:
for i, toks in enumerate(tokens.tolist()):
    start = 0 if echo else len(prompt_tokens[i])
    toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
  • echo=True: Output includes both prompt and generated tokens
  • echo=False: Output includes only newly generated tokens
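A toy illustration of the slicing (hypothetical token values; EOS and padding trimming are handled separately in the real code):
toks = [10, 11, 12, 50, 51, 0, 0]  # 3 prompt tokens, 2 generated, 2 padding
prompt_len, max_gen_len = 3, 2
for echo in (True, False):
    start = 0 if echo else prompt_len
    print(echo, toks[start : prompt_len + max_gen_len])
# True  -> [10, 11, 12, 50, 51]
# False -> [50, 51]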

Generation Loop

The autoregressive generation loop processes one token at a time:
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id

for cur_pos in range(min_prompt_len, total_len):
    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
    
    if temperature > 0:
        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
        next_token = sample_top_p(probs, top_p)
    else:
        next_token = torch.argmax(logits[:, -1], dim=-1)

    next_token = next_token.reshape(-1)
    # only replace token if prompt has already been generated
    next_token = torch.where(
        input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
    )
    tokens[:, cur_pos] = next_token
    
    eos_reached |= (~input_text_mask[:, cur_pos]) & (
        next_token == self.tokenizer.eos_id
    )
    prev_pos = cur_pos
    if all(eos_reached):
        break

Key steps:

  1. Forward pass: Compute logits for the current position
  2. Sampling: Apply temperature and top-p to select next token
  3. Masking: Preserve original prompt tokens using input_text_mask (see the toy sketch after this list)
  4. EOS detection: Stop generation when all sequences reach EOS
  5. KV cache update: prev_pos tracks cache position for efficiency
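To make the masking step concrete, here is a toy version of the torch.where logic (hypothetical values; pad_id and the sampled tokens are stand-ins):
import torch

pad_id = -1
tokens = torch.tensor([[5, 6, 7], [9, pad_id, pad_id]])  # prompt lengths 3 and 1
input_text_mask = tokens != pad_id
cur_pos = 1
sampled = torch.tensor([11, 12])  # tokens proposed by the sampler at cur_pos
# Sequence 0 still has a real prompt token at cur_pos, so it is preserved;
# sequence 1 has padding there, so the sampled token is accepted.
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], sampled)
print(next_token)  # tensor([ 6, 12])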

Text Completion Interface

The text_completion method provides a high-level interface:
def text_completion(
    self,
    prompts: List[str],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
    echo: bool = False,
) -> List[CompletionPrediction]:
    if max_gen_len is None:
        max_gen_len = self.model.params.max_seq_len - 1
    
    prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=prompt_tokens,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logprobs=logprobs,
        echo=echo,
    )
    
    if logprobs:
        return [
            {
                "generation": self.tokenizer.decode(t),
                "tokens": [self.tokenizer.decode(x) for x in t],
                "logprobs": logprobs_i,
            }
            for t, logprobs_i in zip(generation_tokens, generation_logprobs)
        ]
    return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

Example Usage

Basic generation

results = llama.text_completion(
    prompts=["The capital of France is"],
    temperature=0.6,
    top_p=0.9,
    max_gen_len=64,
)
print(results[0]["generation"])
# Example output: " Paris." (with the default echo=False, the prompt is not included)

Deterministic generation (greedy)

results = llama.text_completion(
    prompts=["2 + 2 ="],
    temperature=0.0,  # Greedy decoding
    max_gen_len=10,
)

Creative generation

results = llama.text_completion(
    prompts=["Write a creative story:"],
    temperature=0.9,  # Higher temperature for creativity
    top_p=0.95,       # Larger nucleus
    max_gen_len=256,
)

With log probabilities

results = llama.text_completion(
    prompts=["The weather today is"],
    temperature=0.6,
    logprobs=True,
    max_gen_len=32,
)

import math

for token, logprob in zip(results[0]["tokens"], results[0]["logprobs"]):
    confidence = math.exp(logprob)  # log probability -> probability
    print(f"{token}: {confidence:.3f}")

Echo prompt with generation

results = llama.text_completion(
    prompts=["Complete this: Hello"],
    echo=True,  # Include prompt in output
    max_gen_len=16,
)
print(results[0]["generation"])
# Output: "Complete this: Hello, how can I help you today?"

Best Practices

For factual/precise tasks:

  • Use low temperature (0.0 - 0.3)
  • Use lower top_p (0.7 - 0.85)

For creative/diverse tasks:

  • Use higher temperature (0.7 - 1.0)
  • Use higher top_p (0.9 - 0.95)

For production systems:

  • Enable logprobs to monitor model confidence
  • Set max_gen_len to prevent excessive generation
  • Use temperature=0 for deterministic outputs when needed
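Putting these together, a sketch of a production-style call (assuming llama is a loaded Llama instance; the confidence floor of -2.0 is an arbitrary illustrative value):
results = llama.text_completion(
    prompts=["Summarize the quarterly report:"],
    temperature=0.0,  # deterministic output
    max_gen_len=128,  # hard cap on generation length
    logprobs=True,    # monitor confidence
)
avg_logprob = sum(results[0]["logprobs"]) / len(results[0]["logprobs"])
if avg_logprob < -2.0:  # hypothetical confidence floor
    print("Low-confidence generation; consider flagging for review.")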
