Overview
The ClipLoss class implements the symmetric cross-entropy loss used to train CLIP models. It computes contrastive loss between image and text features by treating each image-text pair as a positive match and all other pairs in the batch as negatives.
The loss is computed as the average of:
- Image-to-text classification loss
- Text-to-image classification loss
Key features:
- Distributed training support - Efficient all-gather operations across GPUs
- Local loss option - Compute loss only on local batch for memory efficiency
- Label caching - Cache ground truth labels for faster training
- Horovod support - Alternative distributed backend
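The core idea can be sketched in a few lines of plain PyTorch. This is a minimal single-process sketch, not the library implementation; `clip_loss_sketch` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F

def clip_loss_sketch(image_features, text_features, logit_scale):
    # Features are assumed L2-normalized, shape (batch_size, embed_dim).
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T
    # The i-th image matches the i-th text, so labels are just row indices.
    labels = torch.arange(image_features.shape[0], device=image_features.device)
    return (F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_text, labels)) / 2

# Toy check with random normalized features
img = F.normalize(torch.randn(8, 64), dim=-1)
txt = F.normalize(torch.randn(8, 64), dim=-1)
loss = clip_loss_sketch(img, txt, torch.tensor(100.0))
```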
Class Definition
from open_clip import ClipLoss
Initialization Parameters
- local_loss (bool, default False): If True, computes loss only between local image features and gathered text features (and vice versa). Reduces memory usage in distributed training but changes the gradient dynamics.
- gather_with_grad (bool, default False): If True, gathers features with gradient flow enabled. Required for certain distributed training strategies but increases memory usage.
- cache_labels (bool, default False): If True, caches ground-truth labels to avoid recomputing them on each forward pass. Saves computation at the cost of memory.
- rank (int, default 0): Current process rank in distributed training. Should match the rank from your distributed backend.
- world_size (int, default 1): Total number of processes in distributed training. Set to 1 for single-GPU training.
- use_horovod (bool, default False): If True, uses Horovod for distributed operations instead of torch.distributed.
Attributes
- local_loss: Whether local loss mode is enabled
- gather_with_grad: Whether gradient flows through gather operations
- cache_labels: Whether label caching is enabled
- rank: Current process rank
- world_size: Total number of processes
- use_horovod: Whether using Horovod backend
- prev_num_logits: Cached logits count for label caching
- labels: Dictionary of cached label tensors per device
Key Methods
forward
def forward(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    logit_bias: Optional[torch.Tensor] = None,
    output_dict: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
Computes the contrastive loss.
Parameters:
image_features: Normalized image features of shape (batch_size, embed_dim)
text_features: Normalized text features of shape (batch_size, embed_dim)
logit_scale: Temperature parameter (typically model.logit_scale.exp())
logit_bias: Optional bias term to add to logits
output_dict: If True, returns a dict with key 'contrastive_loss'; otherwise returns a scalar tensor
Returns: Contrastive loss value (average of image-to-text and text-to-image losses)
get_logits
def get_logits(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    logit_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
Computes similarity logits between image and text features. Handles distributed gathering if world_size > 1.
Returns: Tuple of (logits_per_image, logits_per_text)
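When `world_size > 1`, features from all ranks are concatenated before the matrix product. A minimal sketch of the gathering step (`gather_features_sketch` is a hypothetical helper under stated assumptions, not the library's internal function):

```python
import torch
import torch.distributed as dist
import torch.distributed.nn  # autograd-aware collectives

def gather_features_sketch(features, world_size, gather_with_grad=False):
    # Sketch of cross-rank feature gathering; assumes an initialized
    # process group when world_size > 1.
    if world_size == 1:
        return features
    if gather_with_grad:
        # Gradients flow back through the gather to every rank.
        return torch.cat(torch.distributed.nn.all_gather(features), dim=0)
    # Plain all_gather: the gathered copies carry no gradient history.
    gathered = [torch.zeros_like(features) for _ in range(world_size)]
    dist.all_gather(gathered, features)
    return torch.cat(gathered, dim=0)

# Single-process demo (world_size=1 short-circuits the collectives)
feats = torch.randn(4, 8)
out = gather_features_sketch(feats, world_size=1)
```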
get_ground_truth
def get_ground_truth(self, device: torch.device, num_logits: int) -> torch.Tensor:
Generates or retrieves cached ground-truth labels: the index of each sample's positive pair, i.e. the diagonal of the similarity matrix.
Returns: Label tensor of shape (num_logits,) with values [0, 1, 2, …, num_logits-1]
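The labels are simply row indices, since image `i` pairs with text `i`. A sketch of the construction (hypothetical `ground_truth_sketch` helper; the rank offset for `local_loss` mode reflects that each rank's local rows align with a slice of the gathered columns):

```python
import torch

def ground_truth_sketch(num_logits, rank=0, world_size=1, local_loss=False):
    labels = torch.arange(num_logits)
    if world_size > 1 and local_loss:
        # Each rank's local rows match its own slice of the gathered columns.
        labels = labels + num_logits * rank
    return labels

print(ground_truth_sketch(4))                                          # tensor([0, 1, 2, 3])
print(ground_truth_sketch(4, rank=2, world_size=4, local_loss=True))   # tensor([ 8,  9, 10, 11])
```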
Usage Example
import torch
from open_clip import create_model_and_transforms, ClipLoss
# Create model
model, _, preprocess = create_model_and_transforms('ViT-B-32', pretrained='openai')
# Create loss function
loss_fn = ClipLoss()
# Dummy batch: random images and tokenized text ids
images = torch.randn(32, 3, 224, 224)
texts = torch.randint(0, 49408, (32, 77))
# Forward pass
image_features = model.encode_image(images, normalize=True)
text_features = model.encode_text(texts, normalize=True)
logit_scale = model.logit_scale.exp()
# Compute loss
loss = loss_fn(image_features, text_features, logit_scale)
loss.backward()
Distributed Training Example
import torch
import torch.distributed as dist
from open_clip import create_model, ClipLoss
# Initialize distributed training
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()
# Create model and loss
model = create_model('ViT-B-32').to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
loss_fn = ClipLoss(
    local_loss=False,        # Global loss across all GPUs
    gather_with_grad=True,   # Enable gradient flow through gather
    cache_labels=True,       # Cache labels for efficiency
    rank=rank,
    world_size=world_size,
)
# Training loop
for images, texts in dataloader:
    images = images.to(rank)
    texts = texts.to(rank)

    # Forward through the DDP wrapper (calling model.module directly
    # would skip DDP's gradient-synchronization hooks)
    image_features, text_features, logit_scale = model(images, texts)

    # Loss computation (automatically gathers features from all GPUs)
    loss = loss_fn(image_features, text_features, logit_scale)

    # Backward
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Local Loss Example
# Use local loss for memory efficiency
loss_fn = ClipLoss(
    local_loss=True,          # Only compute loss on local-vs-gathered pairs
    gather_with_grad=False,
    rank=rank,
    world_size=world_size,
)
# With local loss:
# - Each GPU computes image_local @ text_global.T
# - Reduces memory for very large batch sizes
# - Changes gradient dynamics (no all-to-all comparisons)
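To make the memory difference concrete, compare the per-GPU logits-matrix sizes under illustrative assumptions (fp32 logits, per-GPU batch N across W GPUs; the numbers are examples, not measurements):

```python
# Illustrative arithmetic: per-GPU batch N on W GPUs, fp32 logits.
N, W, bytes_per_float = 1024, 8, 4

# Global loss: every GPU materializes the full (N*W, N*W) logits matrix.
global_logits_bytes = (N * W) ** 2 * bytes_per_float
# Local loss: each GPU only materializes its (N, N*W) slice.
local_logits_bytes = N * (N * W) * bytes_per_float

print(f"global: {global_logits_bytes / 2**20:.0f} MiB")  # 256 MiB
print(f"local:  {local_logits_bytes / 2**20:.0f} MiB")   # 32 MiB
```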
With Logit Bias
# Train with learnable bias term
model = create_model('ViT-B-32', init_logit_bias=0.0) # Initialize bias
loss_fn = ClipLoss()
image_features = model.encode_image(images, normalize=True)
text_features = model.encode_text(texts, normalize=True)
logit_scale = model.logit_scale.exp()
logit_bias = model.logit_bias # Learned parameter
loss = loss_fn(image_features, text_features, logit_scale, logit_bias)
Dictionary Output
# Get loss in dictionary format (useful for logging)
loss_dict = loss_fn(
image_features,
text_features,
logit_scale,
output_dict=True
)
print(loss_dict) # {'contrastive_loss': tensor(4.2)}
Memory vs. Accuracy Trade-offs:
| Configuration | Memory | Accuracy | Use Case |
|---|---|---|---|
| local_loss=False | High | Best | Standard training, moderate batch sizes |
| local_loss=True | Lower | Good | Very large batch sizes (>2048) |
| gather_with_grad=True | Higher | Better gradients | Advanced distributed strategies |
| gather_with_grad=False | Lower | Standard gradients | Default distributed training |
Best Practices:
- Enable cache_labels=True for training (slight memory cost, faster)
- Use local_loss=True only when memory is constrained
- Set gather_with_grad=False for most use cases
- Ensure features are normalized before passing them to the loss
Mathematical Formulation
Given normalized image features I ∈ ℝ^(N×D) and text features T ∈ ℝ^(N×D):
- Compute logits: L = τ · (I T⊤) + b, where
  - τ = logit_scale.exp()
  - b = logit_bias (optional)
- Compute the symmetric cross-entropy:
  ℒ = ½ [CE(L, labels) + CE(L⊤, labels)]
  where labels = [0, 1, 2, ..., N−1] (the diagonal entries are the positive pairs)
- In the distributed setting, features are gathered from all GPUs before the logits are computed
See Also
- CLIP - Model that uses this loss
- CoCaLoss - Extended loss with caption generation
- SigLipLoss - Alternative sigmoid-based loss