Overview
The CoCaLoss class extends ClipLoss to support training CoCa models, which combine contrastive learning with autoregressive caption generation. It computes a weighted combination of:
- Contrastive loss - CLIP-style image-text matching (inherited from ClipLoss)
- Caption loss - Cross-entropy loss for next-token prediction in caption generation
This dual-objective training enables models to both match images with text descriptions and generate natural language captions.
Class Definition
```python
from open_clip import CoCaLoss
```
Initialization Parameters
- caption_loss_weight (float): Weight applied to the caption generation loss. Controls the relative importance of caption quality vs. contrastive matching.
- clip_loss_weight (float): Weight applied to the contrastive loss. Set to 0 to train only the caption decoder.
- pad_id (int, default 0): Padding token ID. Positions with this token are ignored when computing the caption loss.
- local_loss (bool, default False): If True, computes contrastive loss only between local and gathered features. See the ClipLoss documentation for details.
- gather_with_grad (bool, default False): If True, gathers features with gradient flow enabled for the contrastive loss.
- cache_labels (bool, default False): If True, caches ground-truth labels for the contrastive loss.
- rank (int, default 0): Current process rank in distributed training.
- world_size (int, default 1): Total number of processes in distributed training.
- use_horovod (bool, default False): If True, uses Horovod for distributed operations instead of torch.distributed.
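The pad_id parameter maps directly onto torch.nn.CrossEntropyLoss(ignore_index=...). A minimal sketch in plain PyTorch (independent of open_clip; the toy shapes and token values are made up for illustration) of how padded positions drop out of the caption loss:

```python
import torch
import torch.nn as nn

pad_id = 0
# Toy caption logits: batch of 2 sequences, 4 tokens each, vocab of 5.
logits = torch.randn(2, 4, 5)
labels = torch.tensor([[3, 1, 4, pad_id],      # last position is padding
                       [2, 2, pad_id, pad_id]])

loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
# CrossEntropyLoss expects (batch, vocab, seq), so permute the logits.
loss = loss_fn(logits.permute(0, 2, 1), labels)

# Positions holding pad_id contribute nothing: recomputing the mean
# cross-entropy over only the non-pad positions gives the same value.
mask = labels != pad_id
manual = nn.functional.cross_entropy(logits[mask], labels[mask])
assert torch.allclose(loss, manual)
```

Because the ignored positions are excluded from both the sum and the mean's denominator, the loss is unaffected by how much padding a batch contains.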
Attributes
Inherits all attributes from ClipLoss, plus:
- clip_loss_weight: Weight for contrastive loss component
- caption_loss_weight: Weight for caption generation loss component
- caption_loss: CrossEntropyLoss module with ignore_index=pad_id
Key Methods
forward
```python
def forward(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logits: torch.Tensor,
    labels: torch.Tensor,
    logit_scale: torch.Tensor,
    output_dict: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
```
Computes the combined CoCa loss.
Parameters:
image_features: Normalized contrastive image features of shape (batch_size, embed_dim)
text_features: Normalized contrastive text features of shape (batch_size, embed_dim)
logits: Caption generation logits of shape (batch_size, seq_len, vocab_size)
labels: Target token IDs for caption generation of shape (batch_size, seq_len)
logit_scale: Temperature parameter for contrastive loss (typically model.logit_scale.exp())
output_dict: If True, returns dict with named losses, else returns tuple
Returns:
- If output_dict=False: a tuple (clip_loss, caption_loss), each already weighted by its respective coefficient
- If output_dict=True: a dictionary with keys "contrastive_loss" and "caption_loss"
Usage Example
```python
import torch
from open_clip import create_model, CoCaLoss

# Create CoCa model
model = create_model('coca_ViT-B-32', pretrained=False)

# Create loss function
loss_fn = CoCaLoss(
    caption_loss_weight=2.0,  # Caption loss is weighted 2x
    clip_loss_weight=1.0,     # Standard contrastive loss weight
    pad_id=0,
)

# Dummy training batch
images = torch.randn(16, 3, 224, 224)
captions = torch.randint(0, 49408, (16, 77))

# Forward pass through the model
output = model(images, captions)

# Extract the tensors the loss needs
image_features = output['image_features']  # Contrastive features
text_features = output['text_features']    # Contrastive features
logits = output['logits']                  # Caption logits
labels = output['labels']                  # Shifted caption targets
logit_scale = output['logit_scale']

# Compute loss
clip_loss, caption_loss = loss_fn(
    image_features,
    text_features,
    logits,
    labels,
    logit_scale,
)
total_loss = clip_loss + caption_loss
total_loss.backward()
```
Distributed Training Example
```python
import torch
import torch.distributed as dist
from open_clip import create_model, CoCaLoss

# Initialize distributed training
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Create model and loss
model = create_model('coca_ViT-L-14').to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0,
    pad_id=0,
    rank=rank,
    world_size=world_size,
    cache_labels=True,
)

for images, captions in dataloader:
    images = images.to(rank)
    captions = captions.to(rank)

    output = model(images, captions)
    clip_loss, caption_loss = loss_fn(
        output['image_features'],
        output['text_features'],
        output['logits'],
        output['labels'],
        output['logit_scale'],
    )
    total_loss = clip_loss + caption_loss

    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
Dictionary Output for Logging
```python
loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0,
)

output = model(images, captions)
loss_dict = loss_fn(
    output['image_features'],
    output['text_features'],
    output['logits'],
    output['labels'],
    output['logit_scale'],
    output_dict=True,
)

print(loss_dict)
# {'contrastive_loss': tensor(4.2), 'caption_loss': tensor(3.1)}

# Easy logging to wandb/tensorboard
for key, value in loss_dict.items():
    logger.log({key: value.item()})
```
Caption-Only Training
```python
# Train only the caption decoder (disable contrastive learning)
loss_fn = CoCaLoss(
    caption_loss_weight=1.0,
    clip_loss_weight=0.0,  # Disable contrastive loss
    pad_id=0,
)

# clip_loss will be a zero tensor (contributes no gradient)
clip_loss, caption_loss = loss_fn(...)
print(clip_loss)  # tensor(0.)
```
Custom Loss Weighting Strategy
```python
class AdaptiveCoCaLoss(CoCaLoss):
    """Dynamically adjust loss weights during training."""

    def __init__(self, *args, initial_caption_weight=2.0, **kwargs):
        super().__init__(
            *args,
            caption_loss_weight=initial_caption_weight,
            **kwargs,
        )
        self.step = 0

    def forward(self, *args, **kwargs):
        # Gradually increase the caption loss weight
        self.caption_loss_weight = 2.0 + (self.step / 10000)
        self.step += 1
        return super().forward(*args, **kwargs)

loss_fn = AdaptiveCoCaLoss(
    clip_loss_weight=1.0,
    pad_id=0,
)
```
Monitoring Loss Components
```python
import torch
from collections import defaultdict

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0,
)

metrics = defaultdict(list)

for epoch in range(num_epochs):
    for images, captions in dataloader:
        output = model(images, captions)
        clip_loss, caption_loss = loss_fn(
            output['image_features'],
            output['text_features'],
            output['logits'],
            output['labels'],
            output['logit_scale'],
        )

        # Track both components
        metrics['clip_loss'].append(clip_loss.item())
        metrics['caption_loss'].append(caption_loss.item())
        metrics['total_loss'].append((clip_loss + caption_loss).item())

        total_loss = clip_loss + caption_loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Epoch summary
    print(f"Epoch {epoch}:")
    print(f"  Avg CLIP loss: {sum(metrics['clip_loss'])/len(metrics['clip_loss']):.4f}")
    print(f"  Avg Caption loss: {sum(metrics['caption_loss'])/len(metrics['caption_loss']):.4f}")
```
The total CoCa loss is:

$$\mathcal{L}_{\text{total}} = w_1 \cdot \mathcal{L}_{\text{contrastive}} + w_2 \cdot \mathcal{L}_{\text{caption}}$$

Where:

- Contrastive loss (inherited from ClipLoss):

$$\mathcal{L}_{\text{contrastive}} = \frac{1}{2}\left[\mathrm{CE}(\tau \cdot I T^{\top}, y) + \mathrm{CE}(\tau \cdot T I^{\top}, y)\right]$$

- Caption loss (cross-entropy with teacher forcing):

$$\mathcal{L}_{\text{caption}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{t=1}^{T} \mathbb{1}[y_{it} \neq \text{pad}] \cdot \log P(y_{it} \mid y_{i,<t}, x_i)$$

Where:
- $N$ = batch size
- $T$ = sequence length
- $y_{it}$ = target token at position $t$
- $x_i$ = image features
- Padding tokens are ignored via ignore_index
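The caption loss can be checked numerically against PyTorch's built-in cross-entropy with ignore_index, which is how CoCaLoss computes it in practice. A standalone sketch in plain PyTorch (toy shapes are made up; note that PyTorch's mean reduction normalizes by the number of non-pad tokens rather than a fixed constant):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, T, V, pad = 3, 5, 10, 0           # batch, seq len, vocab, pad token id
logits = torch.randn(N, T, V)
labels = torch.randint(1, V, (N, T))  # real tokens are all non-pad
labels[:, -2:] = pad                  # pad the tail of every sequence

# Manual masked negative log-likelihood, averaged over non-pad positions
log_probs = F.log_softmax(logits, dim=-1)
token_lp = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
mask = labels != pad
manual = -(token_lp * mask).sum() / mask.sum()

# Built-in equivalent, matching CoCaLoss's CrossEntropyLoss(ignore_index=pad_id)
builtin = F.cross_entropy(logits.permute(0, 2, 1), labels, ignore_index=pad)
assert torch.allclose(manual, builtin)
```

Both paths sum log-probabilities only over positions whose target is not the pad token, so trailing padding never influences the gradient.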
Hyperparameter Tuning
Recommended loss weight ratios:
| Dataset Type | Caption Weight | CLIP Weight | Notes |
|---|---|---|---|
| Large web data (LAION) | 1.0-2.0 | 1.0 | Balanced training |
| Caption-focused (COCO) | 2.0-3.0 | 1.0 | Prioritize generation quality |
| Retrieval-focused | 0.5-1.0 | 1.0 | Prioritize matching |
| Fine-tuning | 3.0-5.0 | 0.1-0.5 | Adapt caption style |
Guidelines:
- Start with caption_loss_weight=2.0, clip_loss_weight=1.0
- If captions are low quality, increase caption_loss_weight
- If retrieval performance is poor, increase clip_loss_weight
- Monitor both loss components separately