Overview

The embedding modules provide token embeddings and optional learned positional embeddings for sequence models.

TokenEmbedding

TokenEmbedding is the main embedding module that converts token IDs to dense vectors. It optionally supports learned positional embeddings.

Class Definition

from lrnnx.architectures.embedding import TokenEmbedding

embedding = TokenEmbedding(
    vocab_size=50000,
    embedding_dim=512,
    padding_idx=0,
    use_position=False,
    dropout=0.1
)

Parameters

vocab_size (int, required): Size of the vocabulary (number of unique tokens).

embedding_dim (int, required): Dimension of the embedding vectors.

padding_idx (int, default None): Index of the padding token. Embeddings at this index are zero vectors and are not updated during training.

max_position_embeddings (int, default None): Maximum sequence length for positional embeddings. Required when use_position=True.

use_position (bool, default False): Whether to include learned positional embeddings. When True, positional embeddings are added to token embeddings.

dropout (float, default 0.1): Dropout probability applied to the final embeddings.

Methods

forward

embeddings = embedding.forward(token_ids)

Convert token IDs to embeddings.

token_ids (torch.Tensor, required): Tensor of token IDs of shape (batch_size, seq_len).

Returns embeddings (torch.Tensor): Embedded tokens of shape (batch_size, seq_len, embedding_dim).

Example Usage

Basic Token Embeddings

import torch
from lrnnx.architectures.embedding import TokenEmbedding

# Create token embedding (no positional embeddings)
embedding = TokenEmbedding(
    vocab_size=10000,
    embedding_dim=256,
    padding_idx=0,
    dropout=0.1
).cuda()

# Input token IDs
token_ids = torch.randint(0, 10000, (4, 50)).cuda()  # (batch=4, seq_len=50)

# Get embeddings
embeddings = embedding(token_ids)
print(embeddings.shape)  # (4, 50, 256)

Token + Positional Embeddings

import torch
from lrnnx.architectures.embedding import TokenEmbedding

# Create token embedding with learned positional embeddings
embedding = TokenEmbedding(
    vocab_size=50000,
    embedding_dim=512,
    padding_idx=0,
    max_position_embeddings=2048,
    use_position=True,  # Enable positional embeddings
    dropout=0.1
).cuda()

# Input token IDs
token_ids = torch.randint(0, 50000, (8, 100)).cuda()

# Get embeddings (token + position)
embeddings = embedding(token_ids)
print(embeddings.shape)  # (8, 100, 512)

Variable-Length Sequences with Padding

import torch
from lrnnx.architectures.embedding import TokenEmbedding

# Create embedding with padding
embedding = TokenEmbedding(
    vocab_size=10000,
    embedding_dim=256,
    padding_idx=0,  # Token ID 0 is padding
    dropout=0.1
).cuda()

# Create padded batch
token_ids = torch.tensor([
    [1, 2, 3, 4, 0, 0],      # Length 4, padded to 6
    [5, 6, 7, 8, 9, 10],     # Length 6
    [11, 12, 0, 0, 0, 0],    # Length 2, padded to 6
]).cuda()

# Embeddings for padding tokens will be zero
embeddings = embedding(token_ids)
print(embeddings.shape)  # (3, 6, 256)
print(embeddings[0, 4:].sum())  # 0.0 (padding positions are zero vectors)

PositionEmbedding

PositionEmbedding provides learned positional embeddings that can be used separately from token embeddings.

Class Definition

from lrnnx.architectures.embedding import PositionEmbedding

pos_embedding = PositionEmbedding(
    max_position_embeddings=2048,
    embedding_dim=512
)

Parameters

max_position_embeddings (int, required): Maximum sequence length supported (number of position indices).

embedding_dim (int, required): Dimension of the embedding vectors.

Methods

forward

pos_embeddings = pos_embedding.forward(positions)

Get positional embeddings for the given position indices.

positions (torch.Tensor, required): Tensor of position indices.

Returns pos_embeddings (torch.Tensor): Positional embeddings corresponding to the input positions.

Example Usage

import torch
from lrnnx.architectures.embedding import PositionEmbedding

# Create positional embedding
pos_embedding = PositionEmbedding(
    max_position_embeddings=1024,
    embedding_dim=256
).cuda()

# Create position indices
seq_len = 50
positions = torch.arange(seq_len).unsqueeze(0).cuda()  # (1, 50)

# Get positional embeddings
pos_emb = pos_embedding(positions)
print(pos_emb.shape)  # (1, 50, 256)

# Manually add to token embeddings
token_emb = torch.randn(4, 50, 256).cuda()
combined = token_emb + pos_emb
print(combined.shape)  # (4, 50, 256)

Design Philosophy

Explicit Positional Embeddings

By default, TokenEmbedding returns only token embeddings without positional information. This design choice allows:
  1. Flexibility: Many sequence models (like LRNNs) don’t require positional embeddings
  2. Explicit Control: Users must explicitly enable positional embeddings with use_position=True
  3. Custom Positioning: Users can implement their own positional encoding schemes (see the sketch below)
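
As an illustration of the third point, here is a minimal sketch that pairs TokenEmbedding (with use_position=False) with a fixed sinusoidal encoding added manually. The sinusoidal_encoding helper is illustrative only and is not part of lrnnx:

import math
import torch
from lrnnx.architectures.embedding import TokenEmbedding

# Token embeddings only; positions are handled by a custom encoding below
embedding = TokenEmbedding(
    vocab_size=10000,
    embedding_dim=256,
    padding_idx=0,
    use_position=False
)

def sinusoidal_encoding(seq_len, dim):
    """Fixed sinusoidal positional encoding of shape (1, seq_len, dim)."""
    position = torch.arange(seq_len).unsqueeze(1)                              # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))  # (dim/2,)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)

token_ids = torch.randint(0, 10000, (4, 50))
x = embedding(token_ids)                            # (4, 50, 256), token embeddings only
x = x + sinusoidal_encoding(50, 256).to(x.device)   # add custom positional information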

When to Use Positional Embeddings

  • Enable (use_position=True) for:
    • Transformer-based models
    • Models that need explicit position information
    • Tasks where token order is critical
  • Disable (use_position=False, default) for:
    • Recurrent models (LRNNs, RNNs, LSTMs)
    • Models with inherent sequential processing
    • When using custom positional encodings (sinusoidal, rotary, etc.)

Padding Handling

When using padding_idx, the embedding layer:
  • Sets padding token embeddings to zero vectors
  • Prevents gradient updates for the padding token
  • Enables proper masking in downstream layers (a mask-building sketch follows below)
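
A minimal sketch of deriving such a mask from the token IDs themselves; how the mask is consumed depends on the downstream layer, so the final comment is only indicative:

import torch
from lrnnx.architectures.embedding import TokenEmbedding

embedding = TokenEmbedding(vocab_size=10000, embedding_dim=256, padding_idx=0)

token_ids = torch.tensor([
    [1, 2, 3, 4, 0, 0],    # last two positions are padding
    [5, 6, 7, 8, 9, 10],   # no padding
])

embeddings = embedding(token_ids)   # padding positions embed to zero vectors
padding_mask = token_ids != 0       # (batch, seq_len), True for real tokens
# Pass `padding_mask` to downstream layers (attention, pooling, loss) so that
# padded positions are ignored.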

Integration Examples

With Classifier

from lrnnx.architectures import Classifier

model = Classifier(
    input_dim=0,  # Ignored when vocab_size is set
    num_classes=10,
    d_model=256,
    vocab_size=10000,
    embedding_dim=256,
    padding_idx=0,
    # The classifier internally uses TokenEmbedding
    lrnn_params={"d_model": 256, "d_state": 64}
)
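
A usage sketch for the model above, assuming that when vocab_size is set the Classifier's forward accepts integer token IDs directly (embedding happens inside the model) and produces one logit per class:

import torch

token_ids = torch.randint(0, 10000, (4, 128))  # (batch=4, seq_len=128)
logits = model(token_ids)
print(logits.shape)  # expected: (4, 10), one logit per class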

Custom Embedding Pipeline

import torch
import torch.nn as nn
from lrnnx.architectures.embedding import TokenEmbedding, PositionEmbedding
from lrnnx.models.lti.lru import LRU

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Token embeddings only
        self.token_emb = TokenEmbedding(
            vocab_size=10000,
            embedding_dim=256,
            padding_idx=0,
            use_position=False  # No positional embeddings
        )
        # Recurrent layer (doesn't need positional info)
        self.lru = LRU(d_model=256, d_state=64)
        
    def forward(self, token_ids):
        x = self.token_emb(token_ids)
        x = self.lru(x)
        return x
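
A short usage sketch for the pipeline above, assuming LRU returns a tensor with the same (batch, seq_len, d_model) layout as its input:

model = CustomModel()
token_ids = torch.randint(0, 10000, (4, 50))  # (batch=4, seq_len=50)
output = model(token_ids)
print(output.shape)  # expected: (4, 50, 256)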
