Overview
The TransferLearningBuilder class provides an interface for building transfer learning models using pre-trained backbones. It supports multiple fine-tuning strategies, from feature extraction to full fine-tuning.
Classes
TransferLearningBuilder
Builder class for transfer learning models.
from models.pytorch.transfer import TransferLearningBuilder
Location: app/models/pytorch/transfer.py:14
Constructor
__init__(config: Dict[str, Any])
config: Model configuration dictionary containing:
- num_classes: Number of output classes
- transfer_config: Transfer-learning-specific configuration
Example:
config = {
"num_classes": 10,
"model_type": "Transfer",
"transfer_config": {
"base_model": "ResNet50",
"weights": "ImageNet",
"strategy": "Feature Extraction",
"global_pooling": True,
"add_dense": True,
"dense_units": 512,
"dropout": 0.5
}
}
builder = TransferLearningBuilder(config)
model = builder.build()
Methods
build() -> nn.Module
Build transfer learning model from configuration.
Returns:
Built PyTorch model (TransferModel instance)
Raises:
ValueError: If configuration is invalid
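Example (a minimal sketch; it assumes an unsupported backbone name is rejected, and the exact error message depends on the implementation):
from models.pytorch.transfer import TransferLearningBuilder

bad_config = {
    "num_classes": 10,
    "transfer_config": {"base_model": "ResNet34", "strategy": "Feature Extraction"}  # unsupported backbone
}
try:
    model = TransferLearningBuilder(bad_config).build()
except ValueError as err:
    print(f"Invalid configuration: {err}")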
get_parameters_count() -> Tuple[int, int]
Get parameter counts.
Returns:
- Total number of parameters
- Number of trainable parameters
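Example (reusing the config from the constructor example above):
builder = TransferLearningBuilder(config)
model = builder.build()
total, trainable = builder.get_parameters_count()
print(f"Trainable: {trainable:,} / {total:,}")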
validate_config() -> bool
Validate transfer learning configuration.
Required fields:
- transfer_config.base_model
- transfer_config.strategy
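Example (a minimal sketch; it assumes validate_config returns False for a missing field rather than raising):
config = {"num_classes": 10, "transfer_config": {"base_model": "ResNet50"}}  # "strategy" missing
builder = TransferLearningBuilder(config)
print(builder.validate_config())  # False: transfer_config.strategy is required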
TransferModel
PyTorch model with pre-trained backbone.
from models.pytorch.transfer import TransferModel
Location: app/models/pytorch/transfer.py:81
Constructor
__init__(
base_model_name: str,
num_classes: int,
weights: str = "ImageNet",
strategy: str = "Feature Extraction",
unfreeze_layers: int = 0,
global_pooling: bool = True,
add_dense: bool = False,
dense_units: int = 512,
dropout: float = 0.5
)
base_model_name: Name of the pre-trained model. Options:
- "VGG16"
- "VGG19"
- "ResNet50"
- "ResNet101"
- "InceptionV3"
- "EfficientNetB0"
weights: Pre-trained weights: "ImageNet" or "Random"
strategy (str, default "Feature Extraction"): Fine-tuning strategy:
- "Feature Extraction": Freeze all base model layers
- "Partial Fine-tuning": Unfreeze last N layers
- "Full Fine-tuning": Train all layers
unfreeze_layers: Number of layers to unfreeze (for partial fine-tuning)
global_pooling: Whether to use global average pooling
add_dense: Whether to add an extra dense layer before output
dense_units: Number of units in the extra dense layer (if add_dense=True)
dropout: Dropout rate before the final layer
Example:
import torch
from models.pytorch.transfer import TransferModel
# Feature extraction with ResNet50
model = TransferModel(
base_model_name="ResNet50",
num_classes=10,
weights="ImageNet",
strategy="Feature Extraction",
global_pooling=True,
add_dense=True,
dense_units=256,
dropout=0.3
)
# Check trainable parameters
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
Methods
forward(x: torch.Tensor) -> torch.Tensor
Forward pass through the model.
x: Input tensor of shape (batch, channels, height, width)
Returns:
Output logits of shape (batch, num_classes)
Example:
import torch
model = TransferModel("ResNet50", num_classes=10)
x = torch.randn(8, 3, 224, 224)
logits = model(x)
print(logits.shape) # torch.Size([8, 10])
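Input resolution depends on the backbone: torchvision's Inception V3 natively expects 299x299 inputs, while the other listed backbones are conventionally used at 224x224. A sketch, assuming TransferModel passes inputs straight through to the backbone:
model = TransferModel("InceptionV3", num_classes=10)
x = torch.randn(8, 3, 299, 299)  # Inception V3's native input size
logits = model(x)  # torch.Size([8, 10])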
get_trainable_layers() -> list
Get list of trainable layer names.
Returns:
List of parameter names where requires_grad=True
Example:
model = TransferModel(
"ResNet50",
num_classes=10,
strategy="Partial Fine-tuning",
unfreeze_layers=2
)
trainable = model.get_trainable_layers()
print(f"Trainable layers ({len(trainable)}):")
for name in trainable[:5]: # Show first 5
print(f" - {name}")
Fine-tuning Strategies
Feature Extraction
Freeze all base model layers, train only the new classifier head.
Use when:
- Small dataset
- Limited compute resources
- Dataset similar to ImageNet
Example:
config = {
"num_classes": 10,
"transfer_config": {
"base_model": "ResNet50",
"weights": "ImageNet",
"strategy": "Feature Extraction",
"global_pooling": True,
"dropout": 0.5
}
}
builder = TransferLearningBuilder(config)
model = builder.build()
# Only classifier is trainable
total, trainable = builder.get_parameters_count()
print(f"Trainable: {trainable:,} / {total:,}") # ~0.1% trainable
Partial Fine-tuning
Freeze most layers, unfreeze the last N layers for fine-tuning.
Use when:
- Medium-sized dataset
- Dataset somewhat different from ImageNet
- Want more adaptation than feature extraction
Example:
config = {
"num_classes": 10,
"transfer_config": {
"base_model": "ResNet50",
"weights": "ImageNet",
"strategy": "Partial Fine-tuning",
"unfreeze_layers": 3, # Unfreeze last 3 layers
"global_pooling": True,
"add_dense": True,
"dense_units": 512,
"dropout": 0.3
}
}
builder = TransferLearningBuilder(config)
model = builder.build()
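As with feature extraction, it is worth verifying what will actually train. This sketch reuses get_parameters_count() and get_trainable_layers() documented above:
total, trainable = builder.get_parameters_count()
print(f"Trainable: {trainable:,} / {total:,}")
for name in model.get_trainable_layers()[:5]:  # first few unfrozen parameter names
    print(f" - {name}")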
Full Fine-tuning
Train all layers, including base model.
Use when:
- Large dataset
- Dataset very different from ImageNet
- Need maximum model adaptation
Example:
config = {
"num_classes": 10,
"transfer_config": {
"base_model": "ResNet50",
"weights": "ImageNet",
"strategy": "Full Fine-tuning",
"global_pooling": True,
"add_dense": True,
"dense_units": 256,
"dropout": 0.5
}
}
builder = TransferLearningBuilder(config)
model = builder.build()
# All parameters trainable
total, trainable = builder.get_parameters_count()
print(f"Trainable: {trainable:,} / {total:,}") # 100% trainable
Supported Base Models
VGG Models
VGG16: 16-layer VGG network. 138M parameters.
VGG19: 19-layer VGG network. 144M parameters.
ResNet Models
ResNet50: 50-layer ResNet with residual connections. 25.6M parameters.
ResNet101: 101-layer ResNet. 44.5M parameters.
Inception Models
InceptionV3: Inception V3 architecture. 23.8M parameters.
EfficientNet Models
EfficientNetB0: EfficientNet B0 (baseline). 5.3M parameters.
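To weigh a backbone against your compute budget, one option is to instantiate each and count parameters. A sketch using the TransferModel constructor documented above (counts also include the new classifier head, and pre-trained weights are downloaded on first use):
from models.pytorch.transfer import TransferModel

for name in ["VGG16", "VGG19", "ResNet50", "ResNet101", "InceptionV3", "EfficientNetB0"]:
    model = TransferModel(name, num_classes=10)
    total = sum(p.numel() for p in model.parameters())
    print(f"{name}: {total / 1e6:.1f}M parameters")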
Complete Example
from models.pytorch.transfer import TransferLearningBuilder
import torch
import torch.nn as nn
import torch.optim as optim
# Configure transfer learning
config = {
"num_classes": 10,
"model_type": "Transfer",
"architecture": "ResNet50",
"transfer_config": {
"base_model": "ResNet50",
"weights": "ImageNet",
"strategy": "Partial Fine-tuning",
"unfreeze_layers": 2,
"global_pooling": True,
"add_dense": True,
"dense_units": 512,
"dropout": 0.5
}
}
# Build model
builder = TransferLearningBuilder(config)
model = builder.build()
# Get model summary
summary = builder.get_model_summary()
print(f"Model: {summary['architecture']}")
print(f"Total parameters: {summary['total_parameters']:,}")
print(f"Trainable parameters: {summary['trainable_parameters']:,}")
# Setup training with different learning rates
base_params = []
head_params = []
for name, param in model.named_parameters():
if param.requires_grad:
if 'classifier' in name:
head_params.append(param)
else:
base_params.append(param)
optimizer = optim.Adam([
{'params': base_params, 'lr': 1e-5}, # Lower LR for base
{'params': head_params, 'lr': 1e-3} # Higher LR for head
])
# Training loop (assumes train_loader is a DataLoader yielding (images, labels) batches)
model.train()
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
for batch_idx, (images, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Get trainable layers for inspection
trainable_layers = model.get_trainable_layers()
print(f"\nTrainable layers: {len(trainable_layers)}")
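After training, evaluation follows the standard PyTorch pattern. A sketch assuming val_loader is a DataLoader yielding (images, labels) batches:
model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in val_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Validation accuracy: {correct / total:.2%}")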
Best Practices
Learning Rate Selection
Use different learning rates for base model and classifier:
optimizer = optim.Adam([
{'params': model.base_model.parameters(), 'lr': 1e-5},
{'params': model.classifier.parameters(), 'lr': 1e-3}
])
Progressive Unfreezing
Start with feature extraction, gradually unfreeze more layers:
# Stage 1: Feature extraction (10 epochs); train() is a user-supplied training helper
for param in model.base_model.parameters():
param.requires_grad = False
train(model, epochs=10, lr=1e-3)
# Stage 2: Fine-tune top layers (10 epochs)
layers = list(model.base_model.children())
for layer in layers[-3:]:
for param in layer.parameters():
param.requires_grad = True
train(model, epochs=10, lr=1e-4)
# Stage 3: Fine-tune all (5 epochs)
for param in model.base_model.parameters():
param.requires_grad = True
train(model, epochs=5, lr=1e-5)
Data Augmentation
Use strong augmentation with transfer learning:
from torchvision import transforms
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
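For validation and test data, keep preprocessing deterministic. A standard torchvision counterpart to the training transform above:
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])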
See Also