
Overview

The CoCaLoss class extends ClipLoss to support training CoCa models, which combine contrastive learning with autoregressive caption generation. It computes a weighted combination of:
  1. Contrastive loss - CLIP-style image-text matching (inherited from ClipLoss)
  2. Caption loss - Cross-entropy loss for next-token prediction in caption generation
This dual-objective training enables models to both match images with text descriptions and generate natural language captions.
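The weighting of the two objectives can be sketched in plain PyTorch. The weights and loss values below are illustrative placeholders; CoCaLoss computes and weights both terms internally:

```python
import torch

# Illustrative sketch of the dual objective; the two loss values are
# placeholders standing in for what CoCaLoss computes internally.
clip_loss_weight = 1.0       # weight on the contrastive term
caption_loss_weight = 2.0    # weight on the caption term
contrastive_loss = torch.tensor(0.9)  # stand-in CLIP-style matching loss
caption_loss = torch.tensor(3.2)      # stand-in next-token cross-entropy
total_loss = clip_loss_weight * contrastive_loss + caption_loss_weight * caption_loss
print(total_loss)  # weighted sum of both objectives
```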

Class Definition

from open_clip import CoCaLoss

Initialization Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| caption_loss_weight | float | required | Weight applied to the caption generation loss. Controls the relative importance of caption quality vs. contrastive matching. |
| clip_loss_weight | float | required | Weight applied to the contrastive loss. Set to 0 to train only the caption decoder. |
| pad_id | int | 0 | Padding token ID. Positions with this token are ignored when computing caption loss. |
| local_loss | bool | False | If True, computes contrastive loss only between local and gathered features. See ClipLoss documentation for details. |
| gather_with_grad | bool | False | If True, gathers features with gradient flow enabled for contrastive loss. |
| cache_labels | bool | False | If True, caches ground-truth labels for the contrastive loss. |
| rank | int | 0 | Current process rank in distributed training. |
| world_size | int | 1 | Total number of processes in distributed training. |
| use_horovod | bool | False | If True, uses Horovod for distributed operations instead of torch.distributed. |

Attributes

Inherits all attributes from ClipLoss, plus:
  • clip_loss_weight: Weight for contrastive loss component
  • caption_loss_weight: Weight for caption generation loss component
  • caption_loss: CrossEntropyLoss module with ignore_index=pad_id
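Because caption_loss is a plain CrossEntropyLoss with ignore_index=pad_id, padded positions contribute nothing to the loss. A minimal sketch of that masking behavior (shapes, vocab size, and labels are illustrative, not from open_clip):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pad_id = 0
ce = nn.CrossEntropyLoss(ignore_index=pad_id)  # mirrors CoCaLoss.caption_loss

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)           # (batch, seq_len, vocab)
labels = torch.tensor([[3, 4, 0, 0, 0],          # trailing 0s are padding
                       [5, 6, 7, 0, 0]])

# CrossEntropyLoss expects the class dim second, hence the permute
loss = ce(logits.permute(0, 2, 1), labels)

# Equivalent by hand: average cross-entropy over non-pad positions only
mask = labels.reshape(-1) != pad_id
manual = F.cross_entropy(logits.reshape(-1, vocab_size)[mask],
                         labels.reshape(-1)[mask])
```

The two values match: the mean is taken over the five non-padding tokens only, so long caption batches are not diluted by padding.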

Key Methods

forward

def forward(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logits: torch.Tensor,
    labels: torch.Tensor,
    logit_scale: torch.Tensor,
    output_dict: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
Computes the combined CoCa loss. Parameters:
  • image_features: Normalized contrastive image features of shape (batch_size, embed_dim)
  • text_features: Normalized contrastive text features of shape (batch_size, embed_dim)
  • logits: Caption generation logits of shape (batch_size, seq_len, vocab_size)
  • labels: Target token IDs for caption generation of shape (batch_size, seq_len)
  • logit_scale: Temperature parameter for contrastive loss (typically model.logit_scale.exp())
  • output_dict: If True, returns dict with named losses, else returns tuple
Returns:
  • If output_dict=False: Tuple of (clip_loss, caption_loss) - both weighted by their respective coefficients
  • If output_dict=True: Dictionary with keys "contrastive_loss" and "caption_loss"

Usage Example

import torch
from open_clip import create_model, CoCaLoss

# Create CoCa model
model = create_model('coca_ViT-B-32', pretrained=False)

# Create loss function
loss_fn = CoCaLoss(
    caption_loss_weight=2.0,  # Caption loss is weighted 2x
    clip_loss_weight=1.0,     # Standard contrastive loss weight
    pad_id=0
)

# Single training step with a dummy batch
# (CLIP BPE tokenizer: vocab size 49408, context length 77)
images = torch.randn(16, 3, 224, 224)
captions = torch.randint(0, 49408, (16, 77))

# Forward pass through model
output = model(images, captions)

# Extract required tensors
image_features = output['image_features']  # Contrastive features
text_features = output['text_features']    # Contrastive features
logits = output['logits']                  # Caption logits
labels = output['labels']                  # Shifted caption targets
logit_scale = output['logit_scale']

# Compute loss
clip_loss, caption_loss = loss_fn(
    image_features,
    text_features,
    logits,
    labels,
    logit_scale
)

total_loss = clip_loss + caption_loss
total_loss.backward()

Distributed Training Example

import torch.distributed as dist
from open_clip import create_model, CoCaLoss

# Initialize distributed
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Create model and loss
model = create_model('coca_ViT-L-14').to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0,
    pad_id=0,
    rank=rank,
    world_size=world_size,
    cache_labels=True
)

for images, captions in dataloader:
    images = images.to(rank)
    captions = captions.to(rank)
    
    output = model(images, captions)
    
    clip_loss, caption_loss = loss_fn(
        output['image_features'],
        output['text_features'],
        output['logits'],
        output['labels'],
        output['logit_scale']
    )
    
    total_loss = clip_loss + caption_loss
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Dictionary Output for Logging

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0
)

output = model(images, captions)

loss_dict = loss_fn(
    output['image_features'],
    output['text_features'],
    output['logits'],
    output['labels'],
    output['logit_scale'],
    output_dict=True
)

print(loss_dict)
# {'contrastive_loss': tensor(4.2), 'caption_loss': tensor(3.1)}

# Easy logging to wandb/tensorboard
for key, value in loss_dict.items():
    logger.log({key: value.item()})

Caption-Only Training

# Train only the caption decoder (freeze contrastive learning)
loss_fn = CoCaLoss(
    caption_loss_weight=1.0,
    clip_loss_weight=0.0,  # Disable contrastive loss
    pad_id=0
)

# With clip_loss_weight=0, the contrastive term is skipped; a zero tensor is
# returned in its place, so only the caption loss produces gradients
clip_loss, caption_loss = loss_fn(...)
print(clip_loss)  # tensor(0)

Custom Loss Weighting Strategy

class AdaptiveCoCaLoss(CoCaLoss):
    """Dynamically adjust loss weights during training."""
    
    def __init__(self, initial_caption_weight=2.0, **kwargs):
        # Forward the starting weight as caption_loss_weight; remaining
        # CoCaLoss arguments (clip_loss_weight, pad_id, ...) pass via **kwargs.
        # Keyword-only forwarding avoids the duplicate-argument error that
        # mixing *args with caption_loss_weight= would cause.
        super().__init__(caption_loss_weight=initial_caption_weight, **kwargs)
        self.step = 0
    
    def forward(self, *args, **kwargs):
        # Gradually increase caption loss weight
        self.caption_loss_weight = 2.0 + (self.step / 10000)
        self.step += 1
        return super().forward(*args, **kwargs)

loss_fn = AdaptiveCoCaLoss(
    clip_loss_weight=1.0,
    pad_id=0
)

Monitoring Loss Components

import torch
from collections import defaultdict

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0
)

for epoch in range(num_epochs):
    # Reset each epoch so the summary below averages over this epoch only
    metrics = defaultdict(list)
    for images, captions in dataloader:
        output = model(images, captions)
        
        clip_loss, caption_loss = loss_fn(
            output['image_features'],
            output['text_features'],
            output['logits'],
            output['labels'],
            output['logit_scale']
        )
        
        # Track both components
        metrics['clip_loss'].append(clip_loss.item())
        metrics['caption_loss'].append(caption_loss.item())
        metrics['total_loss'].append((clip_loss + caption_loss).item())
        
        total_loss = clip_loss + caption_loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    # Epoch summary
    print(f"Epoch {epoch}:")
    print(f"  Avg CLIP loss: {sum(metrics['clip_loss'])/len(metrics['clip_loss']):.4f}")
    print(f"  Avg Caption loss: {sum(metrics['caption_loss'])/len(metrics['caption_loss']):.4f}")

Mathematical Formulation

The total CoCa loss is:

$$\mathcal{L}_{\text{total}} = w_1 \cdot \mathcal{L}_{\text{contrastive}} + w_2 \cdot \mathcal{L}_{\text{caption}}$$

where $w_1$ = clip_loss_weight and $w_2$ = caption_loss_weight.
  1. Contrastive Loss (inherited from ClipLoss):

$$\mathcal{L}_{\text{contrastive}} = \frac{1}{2}\left[\mathrm{CE}(\tau \cdot I T^\top, y) + \mathrm{CE}(\tau \cdot T I^\top, y)\right]$$

  2. Caption Loss (cross-entropy with teacher forcing, averaged over the non-padding tokens):

$$\mathcal{L}_{\text{caption}} = -\frac{1}{M} \sum_{i=1}^{N} \sum_{t=1}^{T} \mathbb{1}[y_i^t \neq \text{pad}] \cdot \log P(y_i^t \mid y_i^{<t}, x_i), \quad M = \sum_{i,t} \mathbb{1}[y_i^t \neq \text{pad}]$$

Where:
  • $N$ = batch size
  • $T$ = sequence length
  • $y_i^t$ = target token at position $t$
  • $x_i$ = image features
  • $\tau$ = temperature (logit_scale), $I$, $T$ = normalized image/text feature matrices, $y$ = indices of the matching (diagonal) pairs
  • Padding tokens are ignored via ignore_index
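The contrastive term can be evaluated numerically in plain torch. This is an illustrative, single-process sketch; in practice ClipLoss handles the computation, including the distributed feature gather:

```python
import torch
import torch.nn.functional as F

# Symmetric cross-entropy over the image-text similarity matrix.
torch.manual_seed(0)
n, d = 4, 8
I = F.normalize(torch.randn(n, d), dim=-1)   # unit-norm image features
T = F.normalize(torch.randn(n, d), dim=-1)   # unit-norm text features
tau = 100.0                                  # logit_scale.exp(), typically ~100
y = torch.arange(n)                          # matched pairs sit on the diagonal

sim = tau * I @ T.t()                        # (n, n) similarity logits
contrastive = 0.5 * (F.cross_entropy(sim, y) + F.cross_entropy(sim.t(), y))
```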

Hyperparameter Tuning

Recommended loss weight ratios:
| Dataset Type | Caption Weight | CLIP Weight | Notes |
|---|---|---|---|
| Large web data (LAION) | 1.0-2.0 | 1.0 | Balanced training |
| Caption-focused (COCO) | 2.0-3.0 | 1.0 | Prioritize generation quality |
| Retrieval-focused | 0.5-1.0 | 1.0 | Prioritize matching |
| Fine-tuning | 3.0-5.0 | 0.1-0.5 | Adapt caption style |
Guidelines:
  • Start with caption_loss_weight=2.0, clip_loss_weight=1.0
  • If captions are low quality, increase caption_loss_weight
  • If retrieval performance is poor, increase clip_loss_weight
  • Monitor both loss components separately
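One way to keep these starting points handy is a small lookup of presets. The helper and preset names below are hypothetical, not part of open_clip; the weights are picked from the middle of each recommended range:

```python
# Hypothetical presets mirroring the table above (not part of open_clip)
WEIGHT_PRESETS = {
    "web_data":   {"caption_loss_weight": 1.5,  "clip_loss_weight": 1.0},
    "captioning": {"caption_loss_weight": 2.5,  "clip_loss_weight": 1.0},
    "retrieval":  {"caption_loss_weight": 0.75, "clip_loss_weight": 1.0},
    "finetune":   {"caption_loss_weight": 4.0,  "clip_loss_weight": 0.3},
}

def preset_kwargs(name: str) -> dict:
    """Return a copy of the chosen preset, ready to splat into CoCaLoss."""
    return dict(WEIGHT_PRESETS[name])

# Usage: loss_fn = CoCaLoss(pad_id=0, **preset_kwargs("captioning"))
```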
