
Transformer Architecture

Llama 2 uses an optimized transformer architecture with several key innovations that improve performance and efficiency. The architecture is available in three model sizes: 7B, 13B, and 70B parameters.

Model Sizes

| Model | Parameters | Heads | KV Heads | Context Length | GQA |
|---|---|---|---|---|---|
| Llama 2 7B | 7B | 32 | 32 | 4096 | No |
| Llama 2 13B | 13B | 40 | 40 | 4096 | No |
| Llama 2 70B | 70B | 64 | 8 | 4096 | Yes |
All models were trained on 2 trillion tokens with a global batch size of 4M tokens. The 70B model uses Grouped-Query Attention (GQA) for improved inference scalability.
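To see why GQA matters at the 70B scale, here is a back-of-envelope KV-cache sizing sketch (plain Python, fp16 elements). The 70B hyperparameters assumed here beyond the table above (dim=8192, 80 layers) are the published values from the Llama 2 paper:

```python
# Back-of-envelope KV cache sizing (fp16) for one full-length sequence.
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # 2x for keys and values
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

head_dim = 8192 // 64                          # 128 for the 70B model
mha = kv_cache_bytes(80, 64, head_dim, 4096)   # hypothetical: 64 KV heads (no GQA)
gqa = kv_cache_bytes(80, 8, head_dim, 4096)    # actual: 8 KV heads (GQA)
print(mha // 2**30, gqa, mha // gqa)           # ~10 GiB shrinks by 8x
```

Cutting the KV heads from 64 to 8 shrinks the cache (and the memory bandwidth spent reading it each step) by the same factor of 8.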

Architecture Components

ModelArgs Configuration

The model architecture is defined through ModelArgs, which specifies the hyperparameters:
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
Key parameters:
  • dim: Model dimension (4096 for 7B; 5120 for 13B; 8192 for 70B)
  • n_layers: Number of transformer layers (32 for 7B; 40 for 13B; 80 for 70B)
  • n_heads: Number of attention heads
  • n_kv_heads: Number of key-value heads (for GQA; defaults to n_heads when unset)
  • max_seq_len: Maximum sequence length the inference caches are sized for (the dataclass default is 2048; Llama 2 was trained with a 4096-token context)
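To make the derived quantities concrete, here is a torch-free sketch using a trimmed-down re-declaration of ModelArgs. The 70B values (dim=8192, 80 layers, 64 heads, 8 KV heads) are the published hyperparameters, not something computed by this snippet:

```python
from dataclasses import dataclass
from typing import Optional

# Trimmed-down, torch-free re-declaration of ModelArgs for illustration.
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None

args_70b = ModelArgs(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8)

# n_kv_heads falls back to n_heads when unset (plain multi-head attention)
n_kv = args_70b.n_kv_heads if args_70b.n_kv_heads is not None else args_70b.n_heads
head_dim = args_70b.dim // args_70b.n_heads   # per-head dimension
n_rep = args_70b.n_heads // n_kv              # query heads per KV head
print(head_dim, n_rep)  # 128 8
```

Each of the 8 KV heads serves a group of 8 query heads, which is exactly the repetition factor `n_rep` used by the attention code shown later.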

RMSNorm

Llama 2 uses Root Mean Square Layer Normalization (RMSNorm) instead of standard LayerNorm for improved efficiency:
import torch
from torch import nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
RMSNorm normalizes using only the root mean square statistic, eliminating the mean centering operation from standard layer normalization. This reduces computation while maintaining effective normalization.
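The same computation can be written in plain Python on a single vector (no torch), which makes it easy to check that the output really does have unit root mean square:

```python
import math

# Plain-Python sketch of the RMSNorm math above, applied to one vector.
def rms_norm(x, weight, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]

x = [1.0, 2.0, 3.0, 4.0]
out = rms_norm(x, [1.0] * 4)  # unit weight, as at initialization

# The normalized output has root mean square ~1 (up to eps)
rms_out = math.sqrt(sum(v * v for v in out) / len(out))
print(round(rms_out, 4))  # 1.0
```

Note that, unlike LayerNorm, the mean of `out` is not zero: only the scale is normalized, which is what saves the mean-centering pass.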

Rotary Position Embeddings (RoPE)

Instead of absolute position embeddings, Llama 2 uses Rotary Position Embeddings (RoPE) which encode position information through rotation in complex space:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
The rotary embeddings are applied to queries and keys:
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # reshape_for_broadcast (defined alongside in the repo) aligns the
    # freqs_cis shape with xq_ so it broadcasts over batch and heads
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
RoPE provides better length extrapolation and relative position modeling compared to absolute position embeddings.
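The relative-position property can be verified with ordinary Python complex numbers: rotating a query at position m by e^(imθ) and a key at position n by e^(inθ) leaves their inner product depending only on the offset m − n. This is a one-frequency sketch, not the full multi-frequency embedding:

```python
import cmath

theta = 0.5  # a single rotation frequency, for illustration
q, k = complex(1.0, 2.0), complex(0.5, -1.0)

def score(m, n):
    qm = q * cmath.exp(1j * m * theta)   # rotate query by its position
    kn = k * cmath.exp(1j * n * theta)   # rotate key by its position
    return qm * kn.conjugate()           # complex inner product

# Same offset (m - n == 2) gives the same score at any absolute position
print(abs(score(5, 3) - score(9, 7)) < 1e-12)  # True
```

Algebraically, `score(m, n) = q * conj(k) * exp(1j * (m - n) * theta)`, so the absolute positions cancel.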

Grouped-Query Attention (GQA)

The 70B model uses Grouped-Query Attention, where multiple query heads share the same key-value heads:
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # fs_init is fairscale.nn.model_parallel.initialize in the reference repo
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
For GQA, key-value heads are repeated to match the number of query heads:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
GQA reduces memory bandwidth requirements during inference by reducing the KV cache size while maintaining model quality.
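The expand-then-reshape in repeat_kv duplicates each KV head contiguously along the head axis. A nested-list analogue (illustrative only, not the tensor code) makes the resulting ordering explicit:

```python
# Plain-Python analogue of repeat_kv on nested lists:
# each KV head is repeated n_rep times, back to back.
def repeat_kv_list(x, n_rep):
    # x has shape [bs][slen][n_kv_heads][head_dim]
    return [
        [[head for head in pos for _ in range(n_rep)] for pos in seq]
        for seq in x
    ]

x = [[[[1.0], [2.0]]]]  # bs=1, slen=1, n_kv_heads=2, head_dim=1
print(repeat_kv_list(x, 3)[0][0])
# [[1.0], [1.0], [1.0], [2.0], [2.0], [2.0]]
```

In the real implementation `expand` creates a broadcast view rather than copying, so the duplication costs no extra memory until a downstream op materializes it.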

SwiGLU Feedforward

The feedforward layer uses SwiGLU activation instead of ReLU:
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        # Llama passes hidden_dim = 4 * dim; scaling by 2/3 keeps the three
        # weight matrices near the parameter count of a standard 2-matrix FFN
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round up to the nearest multiple of multiple_of
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
The SwiGLU activation gates a linear projection (w3) with a SiLU (Swish) nonlinearity applied to a second projection (w1); this gated formulation has been shown to outperform a plain ReLU feedforward at comparable parameter count.
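A plain-Python sketch shows both the hidden-dimension rounding and a scalar version of the activation. It assumes the 7B defaults (dim=4096, multiple_of=256) and that the caller passes hidden_dim as 4 * dim, as the reference model does:

```python
import math

# Reproduce the hidden-dim computation from FeedForward.__init__ above,
# assuming the incoming hidden_dim is 4 * dim as in the reference model.
def swiglu_hidden_dim(dim, multiple_of=256, ffn_dim_multiplier=None):
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(swiglu_hidden_dim(4096))  # 11008, the 7B model's FFN width

def silu(x):
    # SiLU / Swish: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def swiglu(x, w1, w2, w3):
    # scalar version of w2(silu(w1 @ x) * (w3 @ x))
    return w2 * (silu(w1 * x) * (w3 * x))
```

With dim=4096 the rounding yields 11008, which matches the 7B checkpoint's feedforward width.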

Transformer Block

Each transformer block applies pre-normalization with residual connections:
class TransformerBlock(nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
The pre-normalization architecture (normalizing before attention/FFN rather than after) has been shown to improve training stability.
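The ordering is the whole point: normalization applies to each sublayer's input, while the residual stream itself is never normalized. A scalar sketch with toy stand-in functions (not real layers) makes the data flow explicit:

```python
# Scalar sketch of the pre-norm residual pattern above. norm, attention,
# and feed_forward are toy stand-ins, not real layers.
def block(x, norm, attention, feed_forward):
    h = x + attention(norm(x))          # pre-norm attention + residual
    return h + feed_forward(norm(h))    # pre-norm FFN + residual

out = block(4.0,
            norm=lambda v: v / 2,
            attention=lambda v: v + 1,
            feed_forward=lambda v: 2 * v)
print(out)  # 14.0
```

In a post-norm block the normalization would instead wrap each residual sum, which tends to require learning-rate warmup for stable training.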

KV Caching

Llama 2 implements efficient key-value caching for fast autoregressive generation:
self.cache_k = torch.zeros(
    (
        args.max_batch_size,
        args.max_seq_len,
        self.n_local_kv_heads,
        self.head_dim,
    )
).cuda()
self.cache_v = torch.zeros(
    (
        args.max_batch_size,
        args.max_seq_len,
        self.n_local_kv_heads,
        self.head_dim,
    )
).cuda()
During generation, previously computed keys and values are cached and reused:
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
This avoids recomputing attention for all previous tokens at each generation step.
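A toy generation loop (Python lists instead of tensors) shows the access pattern: each step appends only the new token's keys and values, and attention then reads the whole cached prefix.

```python
# Toy sketch of the KV-cache access pattern above, using lists for clarity.
cache_k, cache_v = [], []

def attend(step_k, step_v):
    # write this step's K/V into the cache (the slice assignment above)
    cache_k.append(step_k)
    cache_v.append(step_v)
    # attention would read all cached entries for the prefix (the reads above)
    return len(cache_k)

for t in range(4):
    visible = attend(f"k{t}", f"v{t}")

print(cache_k, visible)  # ['k0', 'k1', 'k2', 'k3'] 4
```

Per generated token the cost of building keys and values is thus O(1) rather than O(sequence length), at the price of the cache memory sized by max_batch_size and max_seq_len.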
