Overview
The TransformerBuilder class provides a complete Vision Transformer (ViT) implementation for image classification. It uses self-attention mechanisms to process image patches, offering an alternative to convolutional architectures.
Classes
TransformerBuilder
Builder class for Vision Transformer models.
from models.pytorch.transformer import TransformerBuilder
Location: app/models/pytorch/transformer.py:15
Constructor
__init__(config: Dict[str, Any])
Model configuration dictionary containing:
num_classes: Number of output classes
transformer_config: Transformer-specific configuration
Example:
config = {
"num_classes": 10,
"model_type": "Transformer",
"transformer_config": {
"patch_size": 16,
"embed_dim": 768,
"depth": 12,
"num_heads": 12,
"mlp_ratio": 4.0,
"dropout": 0.1
}
}
builder = TransformerBuilder(config)
model = builder.build()
Methods
build() -> nn.Module
Build Vision Transformer model.
Returns the built PyTorch model (a VisionTransformer instance)
get_parameters_count() -> Tuple[int, int]
Get parameter counts.
Returns a tuple of (total number of parameters, number of trainable parameters)
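For reference, the same counts can be computed directly with PyTorch; the sketch below uses a stand-in module, since the builder's exact return format is library-specific.

```python
import torch.nn as nn

# Stand-in module for illustration; in practice use builder.build()
model = nn.Linear(768, 10)

# Total vs. trainable parameter counts, as get_parameters_count() is described to return
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 7690 7690 (768*10 weights + 10 biases)
```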
validate_config() -> bool
Validate the transformer configuration. Returns True if the configuration is valid.
VisionTransformer
Main Vision Transformer model class.
from models.pytorch.transformer import VisionTransformer
Location: app/models/pytorch/transformer.py:223
Constructor
__init__(
image_size: int = 224,
patch_size: int = 16,
num_classes: int = 1000,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
in_channels: int = 3
)
image_size: Input image size (square images)
patch_size: Size of image patches. The image is divided into (image_size/patch_size)² patches
num_classes: Number of output classes
embed_dim: Embedding dimension for the transformer
depth: Number of transformer blocks
num_heads: Number of attention heads (must divide embed_dim evenly)
mlp_ratio: Ratio of MLP hidden dimension to embedding dimension
dropout: Dropout rate
in_channels: Number of input channels (3 for RGB)
Example:
import torch
from models.pytorch.transformer import VisionTransformer
# Standard ViT-Base configuration
model = VisionTransformer(
image_size=224,
patch_size=16,
num_classes=10,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
dropout=0.1
)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
Methods
forward(x: torch.Tensor) -> torch.Tensor
Forward pass through the transformer.
Input images of shape (B, C, H, W)
Class logits of shape (B, num_classes)
Example:
import torch
model = VisionTransformer(num_classes=10)
x = torch.randn(8, 3, 224, 224) # Batch of 8 images
logits = model(x)
print(logits.shape) # torch.Size([8, 10])
# Get predictions
probs = torch.softmax(logits, dim=1)
predictions = torch.argmax(probs, dim=1)
print(predictions)  # e.g. tensor([3, 7, 2, 1, 9, 0, 4, 6]); values vary with random weights
get_attention_maps(x: torch.Tensor) -> list
Extract attention maps for visualization.
Input images of shape (B, C, H, W)
List of attention maps from each transformer block
Note: Current implementation returns empty list. Requires modification of TransformerBlock to return attention weights.
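Until that modification lands, one common workaround is a forward hook. The sketch below uses torch.nn.MultiheadAttention as a stand-in (the document's custom MultiHeadAttention does not currently return its weights), so it illustrates the hook pattern rather than this library's API.

```python
import torch
import torch.nn as nn

# Stand-in attention layer; nn.MultiheadAttention returns
# (output, attention_weights) when need_weights=True (the default)
attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

maps = []

def save_attention(module, inputs, output):
    # output[1] is the (B, seq_len, seq_len) attention map,
    # averaged over heads by default
    maps.append(output[1])

handle = attn.register_forward_hook(save_attention)
x = torch.randn(2, 197, 768)
attn(x, x, x)
handle.remove()

print(maps[0].shape)  # torch.Size([2, 197, 197])
```

The same hook could be registered on each TransformerBlock's attention module once it exposes its softmaxed weights.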
PatchEmbedding
Convert images to patch embeddings.
from models.pytorch.transformer import PatchEmbedding
Location: app/models/pytorch/transformer.py:72
Constructor
__init__(
image_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768
)
image_size: Input image size (square images)
patch_size: Size of each square patch
in_channels: Number of input channels (3 for RGB)
embed_dim: Output embedding dimension
Example:
import torch
from models.pytorch.transformer import PatchEmbedding
patch_embed = PatchEmbedding(
image_size=224,
patch_size=16,
embed_dim=768
)
x = torch.randn(4, 3, 224, 224)
patches = patch_embed(x)
print(patches.shape) # torch.Size([4, 196, 768])
# 196 = (224/16)^2 patches
MultiHeadAttention
Multi-head self-attention mechanism.
from models.pytorch.transformer import MultiHeadAttention
Location: app/models/pytorch/transformer.py:117
Constructor
__init__(
embed_dim: int,
num_heads: int,
dropout: float = 0.0
)
embed_dim: Embedding dimension (must be divisible by num_heads)
num_heads: Number of attention heads
dropout: Dropout rate for attention weights
Example:
import torch
from models.pytorch.transformer import MultiHeadAttention
attn = MultiHeadAttention(
embed_dim=768,
num_heads=12,
dropout=0.1
)
x = torch.randn(8, 197, 768) # (batch, seq_len, embed_dim)
output = attn(x)
print(output.shape) # torch.Size([8, 197, 768])
TransformerBlock
Single transformer encoder block.
from models.pytorch.transformer import TransformerBlock
Location: app/models/pytorch/transformer.py:195
Constructor
__init__(
embed_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
dropout: float = 0.0
)
embed_dim: Embedding dimension
num_heads: Number of attention heads
mlp_ratio: MLP hidden dimension = embed_dim * mlp_ratio
dropout: Dropout rate
Example:
import torch
from models.pytorch.transformer import TransformerBlock
block = TransformerBlock(
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
dropout=0.1
)
x = torch.randn(8, 197, 768)
output = block(x)
print(output.shape) # torch.Size([8, 197, 768])
Standard Configurations
ViT-Tiny
config = {
"num_classes": 10,
"transformer_config": {
"patch_size": 16,
"embed_dim": 192,
"depth": 12,
"num_heads": 3,
"mlp_ratio": 4.0,
"dropout": 0.1
}
}
# ~5.7M parameters
ViT-Small
config = {
"num_classes": 10,
"transformer_config": {
"patch_size": 16,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
"mlp_ratio": 4.0,
"dropout": 0.1
}
}
# ~22M parameters
ViT-Base (Default)
config = {
"num_classes": 10,
"transformer_config": {
"patch_size": 16,
"embed_dim": 768,
"depth": 12,
"num_heads": 12,
"mlp_ratio": 4.0,
"dropout": 0.1
}
}
# ~86M parameters
ViT-Large
config = {
"num_classes": 10,
"transformer_config": {
"patch_size": 16,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"mlp_ratio": 4.0,
"dropout": 0.1
}
}
# ~307M parameters
Complete Example
from models.pytorch.transformer import TransformerBuilder
import torch
import torch.nn as nn
import torch.optim as optim
# Configure ViT-Base
config = {
"num_classes": 10,
"model_type": "Transformer",
"architecture": "ViT-Base",
"transformer_config": {
"patch_size": 16,
"embed_dim": 768,
"depth": 12,
"num_heads": 12,
"mlp_ratio": 4.0,
"dropout": 0.1
}
}
# Build model
builder = TransformerBuilder(config)
model = builder.build()
# Get summary
summary = builder.get_model_summary()
print(f"Model: {summary['architecture']}")
print(f"Total parameters: {summary['total_parameters']:,}")
print(f"Trainable parameters: {summary['trainable_parameters']:,}")
# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Setup optimizer with weight decay
optimizer = optim.AdamW(
model.parameters(),
lr=3e-4,
betas=(0.9, 0.999),
weight_decay=0.3
)
# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=100,
eta_min=1e-6
)
# Training loop (assumes train_loader is an existing DataLoader over the dataset)
model.train()
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
total_loss = 0
for batch_idx, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, LR = {scheduler.get_last_lr()[0]:.6f}")
# Inference
model.eval()
with torch.no_grad():
images = torch.randn(8, 3, 224, 224).to(device)
logits = model(images)
probs = torch.softmax(logits, dim=1)
preds = torch.argmax(probs, dim=1)
print(f"Predictions: {preds.cpu().numpy()}")
Training Tips
Data Augmentation
ViT benefits from strong augmentation:
from torchvision import transforms
transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.05, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
transforms.RandomErasing(p=0.25)
])
Optimizer Configuration
Use AdamW with specific settings:
optimizer = optim.AdamW(
model.parameters(),
lr=3e-4,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.3
)
Learning Rate Warmup
import math

def get_lr(step, warmup_steps, total_steps, base_lr, min_lr):
    # Linear warmup followed by cosine decay
    if step < warmup_steps:
        return base_lr * (step / warmup_steps)
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

total_steps = 100000
for step in range(total_steps):
    lr = get_lr(step, warmup_steps=10000, total_steps=total_steps, base_lr=3e-4, min_lr=1e-6)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
Layer-wise Learning Rate Decay
def get_layer_decay_params(model, lr, layer_decay=0.75):
param_groups = []
# Group by layer depth
for i, block in enumerate(model.blocks):
decay_factor = layer_decay ** (len(model.blocks) - i)
param_groups.append({
'params': block.parameters(),
'lr': lr * decay_factor
})
# Head uses full learning rate
param_groups.append({
'params': model.head.parameters(),
'lr': lr
})
return param_groups
optimizer = optim.AdamW(
get_layer_decay_params(model, lr=1e-3, layer_decay=0.75),
weight_decay=0.05
)
Architecture Details
Patch Embedding Process
- Input image:
(B, 3, 224, 224)
- Split into patches:
(B, 196, 3*16*16) where 196 = (224/16)²
- Linear projection:
(B, 196, 768)
- Add CLS token:
(B, 197, 768)
- Add position embeddings:
(B, 197, 768)
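The pipeline above can be sketched with a strided convolution, the standard ViT patchify trick. This is an illustration of the process, not necessarily how app/models/pytorch/transformer.py implements it.

```python
import torch
import torch.nn as nn

patch_size, embed_dim = 16, 768

# A conv with kernel_size == stride == patch_size splits the image into
# non-overlapping patches and linearly projects each one in a single op
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
cls_token = torch.zeros(1, 1, embed_dim)   # learnable in a real model
pos_embed = torch.zeros(1, 197, embed_dim)  # learnable in a real model

x = torch.randn(2, 3, 224, 224)               # (B, 3, 224, 224)
patches = proj(x)                             # (B, 768, 14, 14)
patches = patches.flatten(2).transpose(1, 2)  # (B, 196, 768)
tokens = torch.cat([cls_token.expand(2, -1, -1), patches], dim=1)  # (B, 197, 768)
tokens = tokens + pos_embed                   # (B, 197, 768)
print(tokens.shape)  # torch.Size([2, 197, 768])
```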
Attention Mechanism
- Query, Key, Value computed from input
- Attention:
softmax(QK^T / sqrt(d)) V
- Multi-head: Process H heads in parallel, concatenate results
- Output projection: Linear layer to embed_dim
- Layer Norm → Multi-Head Attention → Residual
- Layer Norm → MLP → Residual
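The attention formula above can be sketched for a single head as:

```python
import math
import torch

def attention(q, k, v):
    # softmax(Q K^T / sqrt(d)) V for a single head
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)  # (B, N, N)
    return torch.softmax(scores, dim=-1) @ v         # (B, N, d)

# Self-attention: queries, keys, and values all come from the same input
x = torch.randn(2, 197, 64)  # one head's slice of the embedding
out = attention(x, x, x)
print(out.shape)  # torch.Size([2, 197, 64])
```

Multi-head attention runs this in parallel over H such slices and concatenates the results before the output projection.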
Classification
- Extract CLS token output:
(B, 768)
- Layer Norm
- Linear projection:
(B, num_classes)
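A minimal sketch of this head (hypothetical layer names; the actual module may differ):

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(768)
head = nn.Linear(768, 10)  # num_classes = 10

tokens = torch.randn(2, 197, 768)  # transformer output
cls = tokens[:, 0]                 # CLS token, (B, 768)
logits = head(norm(cls))           # (B, num_classes)
print(logits.shape)  # torch.Size([2, 10])
```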
Memory Usage
ViT models use significant memory:
- ViT-Base: ~340MB (fp32), ~170MB (fp16)
- ViT-Large: ~1.2GB (fp32), ~600MB (fp16)
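These figures follow from a back-of-envelope calculation (parameter count × bytes per element, weights only; activations, gradients, and optimizer state add considerably more):

```python
# Approximate weight memory for ViT-Base (~86M parameters)
vit_base_params = 86e6
print(f"fp32: {vit_base_params * 4 / 1e6:.0f} MB")  # 344 MB
print(f"fp16: {vit_base_params * 2 / 1e6:.0f} MB")  # 172 MB
```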
Computation
Attention complexity: O(N²) where N is the number of tokens
- Patch size 16 → 196 patches
- Patch size 8 → 784 patches (4× the tokens, so attention cost grows roughly 16×)
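A quick calculation makes the scaling concrete for a 224×224 input:

```python
# Token count and attention-matrix size as a function of patch size
for patch in (16, 8):
    n = (224 // patch) ** 2  # number of patches (tokens, excluding CLS)
    print(f"patch {patch}: {n} tokens, {n * n} attention entries")
# patch 16: 196 tokens, 38416 attention entries
# patch 8: 784 tokens, 614656 attention entries
```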
Training Time
ViT typically needs longer training than CNNs when trained from scratch:
- CNNs: 50-100 epochs
- ViT: 300+ epochs
- Use pre-training when possible
See Also