Overview

The ClipLoss class implements the symmetric cross-entropy loss used to train CLIP models. It computes contrastive loss between image and text features by treating each image-text pair as a positive match and all other pairs in the batch as negatives. The loss is computed as the average of:
  • Image-to-text classification loss
  • Text-to-image classification loss
Key features:
  • Distributed training support - Efficient all-gather operations across GPUs
  • Local loss option - Compute loss only on local batch for memory efficiency
  • Label caching - Cache ground truth labels for faster training
  • Horovod support - Alternative distributed backend
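The core computation can be sketched in a few lines of plain PyTorch. This is an illustrative single-process reimplementation, not the library's internal code; the function name `symmetric_clip_loss` is ours:

```python
import torch
import torch.nn.functional as F

def symmetric_clip_loss(image_features, text_features, logit_scale):
    # Features are assumed to be L2-normalized, shape (batch, dim).
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T
    # The positive pair for row i sits at column i.
    labels = torch.arange(image_features.shape[0], device=image_features.device)
    loss_i2t = F.cross_entropy(logits_per_image, labels)
    loss_t2i = F.cross_entropy(logits_per_text, labels)
    return (loss_i2t + loss_t2i) / 2
```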

Class Definition

from open_clip import ClipLoss

Initialization Parameters

local_loss (bool, default: False)
If True, computes loss only between local image features and gathered text features (and vice versa). Reduces memory usage in distributed training but changes the gradient dynamics.

gather_with_grad (bool, default: False)
If True, gathers features with gradient flow enabled. Required for certain distributed training strategies but increases memory usage.

cache_labels (bool, default: False)
If True, caches ground truth labels to avoid recomputing them each forward pass. Saves computation at the cost of memory.

rank (int, default: 0)
Current process rank in distributed training. Should match the rank from your distributed backend.

world_size (int, default: 1)
Total number of processes in distributed training. Set to 1 for single-GPU training.

use_horovod (bool, default: False)
If True, uses Horovod for distributed operations instead of torch.distributed.

Attributes

  • local_loss: Whether local loss mode is enabled
  • gather_with_grad: Whether gradient flows through gather operations
  • cache_labels: Whether label caching is enabled
  • rank: Current process rank
  • world_size: Total number of processes
  • use_horovod: Whether using Horovod backend
  • prev_num_logits: Cached logits count for label caching
  • labels: Dictionary of cached label tensors per device

Key Methods

forward

def forward(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    logit_bias: Optional[torch.Tensor] = None,
    output_dict: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
Computes the contrastive loss. Parameters:
  • image_features: Normalized image features of shape (batch_size, embed_dim)
  • text_features: Normalized text features of shape (batch_size, embed_dim)
  • logit_scale: Temperature parameter (typically model.logit_scale.exp())
  • logit_bias: Optional bias term to add to logits
  • output_dict: If True, returns dict with key “contrastive_loss”, else returns scalar
Returns: Contrastive loss value (average of image-to-text and text-to-image losses)

get_logits

def get_logits(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    logit_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
Computes similarity logits between image and text features. Handles distributed gathering if world_size > 1. Returns: Tuple of (logits_per_image, logits_per_text)
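In the single-process case (world_size == 1) no gathering occurs and the two logit matrices are transposes of each other. A pure-PyTorch sketch of that path, mirroring the documented behavior (the name `get_logits_local` is ours, not the library's):

```python
import torch
import torch.nn.functional as F

def get_logits_local(image_features, text_features, logit_scale, logit_bias=None):
    # Single-process path (world_size == 1): no all-gather, just a
    # scaled similarity matrix between the two feature sets.
    logits_per_image = logit_scale * image_features @ text_features.T
    if logit_bias is not None:
        logits_per_image = logits_per_image + logit_bias
    logits_per_text = logits_per_image.T
    return logits_per_image, logits_per_text
```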

get_ground_truth

def get_ground_truth(self, device: torch.device, num_logits: int) -> torch.Tensor:
Generates or retrieves cached ground truth labels (the index of each sample's matching pair, i.e. the diagonal of the logit matrix). Returns: Label tensor of shape (num_logits,) with values [0, 1, 2, …, num_logits-1]

Usage Example

import torch
from open_clip import create_model_and_transforms, ClipLoss

# Create model
model, _, preprocess = create_model_and_transforms('ViT-B-32', pretrained='openai')

# Create loss function
loss_fn = ClipLoss()

# Training loop
images = torch.randn(32, 3, 224, 224)
texts = torch.randint(0, 49408, (32, 77))

# Forward pass
image_features = model.encode_image(images, normalize=True)
text_features = model.encode_text(texts, normalize=True)
logit_scale = model.logit_scale.exp()

# Compute loss
loss = loss_fn(image_features, text_features, logit_scale)
loss.backward()

Distributed Training Example

import torch
import torch.distributed as dist
from open_clip import create_model, ClipLoss

# Initialize distributed training
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Create model and loss
model = create_model('ViT-B-32').to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

loss_fn = ClipLoss(
    local_loss=False,  # Global loss across all GPUs
    gather_with_grad=True,  # Enable gradient through gather
    cache_labels=True,  # Cache labels for efficiency
    rank=rank,
    world_size=world_size
)

# Training loop
for images, texts in dataloader:
    images = images.to(rank)
    texts = texts.to(rank)
    
    # Forward
    image_features = model.module.encode_image(images, normalize=True)
    text_features = model.module.encode_text(texts, normalize=True)
    logit_scale = model.module.logit_scale.exp()
    
    # Loss computation (automatically gathers features from all GPUs)
    loss = loss_fn(image_features, text_features, logit_scale)
    
    # Backward
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Local Loss Example

# Use local loss for memory efficiency
loss_fn = ClipLoss(
    local_loss=True,  # Only compute loss on local-global pairs
    gather_with_grad=False,
    rank=rank,
    world_size=world_size
)

# With local loss:
# - Each GPU computes image_local @ text_global.T
# - Reduces memory for very large batch sizes
# - Changes gradient dynamics (no all-to-all comparisons)
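The shape difference is easy to see in a toy simulation (plain tensors standing in for gathered features): with local loss each rank scores only its local_batch images against the world_size * local_batch gathered texts, so the logit matrix is rectangular rather than square.

```python
import torch

local_batch, world_size, dim = 4, 2, 8
img_local = torch.randn(local_batch, dim)
# Simulated all-gather of text features from every rank.
txt_global = torch.randn(world_size * local_batch, dim)

# Local loss: rectangular (local_batch, world_size * local_batch) logits.
logits_local = img_local @ txt_global.T
assert logits_local.shape == (local_batch, world_size * local_batch)

# Global loss gathers both sides and builds a square matrix instead.
img_global = torch.randn(world_size * local_batch, dim)
logits_global = img_global @ txt_global.T
assert logits_global.shape == (world_size * local_batch, world_size * local_batch)
```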

With Logit Bias

# Train with learnable bias term
model = create_model('ViT-B-32', init_logit_bias=0.0)  # Initialize bias
loss_fn = ClipLoss()

image_features = model.encode_image(images, normalize=True)
text_features = model.encode_text(texts, normalize=True)
logit_scale = model.logit_scale.exp()
logit_bias = model.logit_bias  # Learned parameter

loss = loss_fn(image_features, text_features, logit_scale, logit_bias)

Dictionary Output

# Get loss in dictionary format (useful for logging)
loss_dict = loss_fn(
    image_features,
    text_features,
    logit_scale,
    output_dict=True
)

print(loss_dict)  # {'contrastive_loss': tensor(4.2)}

Performance Considerations

Memory vs. Accuracy Trade-offs:
| Configuration | Memory | Accuracy | Use Case |
|---|---|---|---|
| local_loss=False | High | Best | Standard training, moderate batch sizes |
| local_loss=True | Lower | Good | Very large batch sizes (>2048) |
| gather_with_grad=True | Higher | Better gradients | Advanced distributed strategies |
| gather_with_grad=False | Lower | Standard gradients | Default distributed training |
Best Practices:
  • Enable cache_labels=True for training (slight memory cost, faster)
  • Use local_loss=True only when memory is constrained
  • Set gather_with_grad=False for most use cases
  • Ensure features are normalized before passing to loss

Mathematical Formulation

Given normalized image features $I \in \mathbb{R}^{N \times D}$ and text features $T \in \mathbb{R}^{N \times D}$:
  1. Compute logits: $L = \tau \cdot (I T^\top) + b$
    • $\tau$ = logit_scale.exp()
    • $b$ = logit_bias (optional)
  2. Compute symmetric cross-entropy: $\mathcal{L} = \frac{1}{2} \left[ \text{CE}(L, \text{labels}) + \text{CE}(L^\top, \text{labels}) \right]$ where labels $= [0, 1, 2, \ldots, N-1]$ (the diagonal entries are the positives)
  3. In the distributed setting, features are gathered from all GPUs before computing logits
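A quick sanity check of this formulation: when the text features equal the image features, every pair matches itself, and at a high temperature the symmetric loss collapses toward zero. A small numerical verification (illustrative only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 8, 16
I = F.normalize(torch.randn(N, D), dim=-1)
tau = torch.tensor(100.0)

# T == I: the diagonal (correct pairs) dominates at high temperature.
L = tau * (I @ I.T)
labels = torch.arange(N)
loss = 0.5 * (F.cross_entropy(L, labels) + F.cross_entropy(L.T, labels))
assert loss.item() < 1e-3
```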
Related

  • CLIP - Model that uses this loss
  • CoCaLoss - Extended loss with caption generation
  • SigLipLoss - Alternative sigmoid-based loss