
Overview

The CoCa (Contrastive Captioner) class implements a multimodal model that combines contrastive learning (like CLIP) with generative captioning capabilities. It consists of:
  1. Image encoder - Encodes images into a latent space and produces token embeddings
  2. Text encoder - Encodes text into contrastive features for CLIP-style learning
  3. Multimodal text decoder - Generates captions conditioned on image embeddings
CoCa can be used for both zero-shot image-text matching and image captioning.

Class Definition

from open_clip import CoCa

Initialization Parameters

  • embed_dim (int, required): Dimensionality of the joint embedding space for contrastive image and text features.
  • multimodal_cfg (MultimodalCfg, required): Configuration for the multimodal text decoder. Controls the cross-attention layers that condition text generation on image features.
  • text_cfg (CLIPTextCfg, required): Configuration for the unimodal text encoder used for contrastive learning.
  • vision_cfg (CLIPVisionCfg, required): Configuration for the vision encoder.
  • quick_gelu (bool, default False): Use QuickGELU activation instead of standard GELU.
  • init_logit_scale (float, default np.log(1 / 0.07)): Initial value for the learned temperature parameter in contrastive learning.
  • init_logit_bias (Optional[float], default None): Optional learnable bias term added to contrastive logits.
  • nonscalar_logit_scale (bool, default False): If True, logit_scale has shape [1] instead of [].
  • cast_dtype (Optional[torch.dtype], default None): Precision for model computations (e.g., torch.float16, torch.bfloat16).
  • pad_id (int, default 0): Token ID used for padding in the vocabulary.
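The default init_logit_scale stores the contrastive temperature in log space, so the learned parameter stays unconstrained during optimization and the model exponentiates it before scaling similarities. A quick sketch of what the default value corresponds to:

```python
import math

# The contrastive temperature is stored as a log value; the model
# exponentiates it before multiplying cosine similarities.
init_logit_scale = math.log(1 / 0.07)  # default from the parameter list above

# Effective multiplier applied to image-text similarities:
effective_scale = math.exp(init_logit_scale)
print(round(effective_scale, 4))  # ~14.2857, i.e. 1 / 0.07
```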

Attributes

  • visual: Vision encoder module
  • text: Unimodal text encoder for contrastive learning
  • text_decoder: Multimodal transformer decoder for caption generation
  • logit_scale: Learned temperature parameter for contrastive learning
  • logit_bias: Optional learned bias for contrastive logits
  • pad_id: Padding token ID
  • context_length: Maximum sequence length for caption generation

Key Methods

encode_image

def encode_image(self, images: torch.Tensor, normalize: bool = True) -> torch.Tensor:
Encodes images into the contrastive embedding space. Parameters:
  • images: Image tensor of shape (batch_size, channels, height, width)
  • normalize: If True, L2-normalizes the output features
Returns: Image features of shape (batch_size, embed_dim)

encode_text

def encode_text(self, text: torch.Tensor, normalize: bool = True) -> torch.Tensor:
Encodes tokenized text into the contrastive embedding space. Parameters:
  • text: Tokenized text tensor of shape (batch_size, context_length)
  • normalize: If True, L2-normalizes the output features
Returns: Text features of shape (batch_size, embed_dim)

forward

def forward(
    self,
    image: torch.Tensor,
    text: Optional[torch.Tensor] = None,
    image_latent: Optional[torch.Tensor] = None,
    image_embs: Optional[torch.Tensor] = None,
    output_labels: bool = True,
) -> Dict[str, torch.Tensor]:
Forward pass through the model. Parameters:
  • image: Image tensor of shape (batch_size, channels, height, width)
  • text: Optional tokenized text for teacher-forcing caption generation
  • image_latent: Optional pre-computed contrastive image features
  • image_embs: Optional pre-computed image token embeddings
  • output_labels: If True, creates caption labels by shifting text input
Returns: Dictionary containing:
  • image_features: Contrastive image embeddings (batch_size, embed_dim)
  • text_features: Contrastive text embeddings (batch_size, embed_dim) (if text provided)
  • logits: Caption generation logits (batch_size, seq_len, vocab_size) (if text provided)
  • labels: Ground truth labels for caption loss (batch_size, seq_len-1) (if output_labels=True)
  • logit_scale: Exponential of learned temperature parameter
  • logit_bias: Learned bias (if initialized with init_logit_bias)
  • image_embs: Image token embeddings (if text not provided)
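The label shift performed by output_labels is the standard next-token setup (assumed here from the reported (batch_size, seq_len-1) label shape): the decoder predicts token t+1 from tokens up to t, so labels drop the first position of the tokenized caption. A minimal sketch with a made-up token sequence:

```python
# Made-up CLIP-style token IDs: <sot> a photo of a cat <eot>
caption = [49406, 320, 1125, 539, 320, 2368, 49407]

inputs = caption[:-1]  # what the decoder conditions on at each step
labels = caption[1:]   # what it is trained to predict (length seq_len - 1)

print(labels)  # [320, 1125, 539, 320, 2368, 49407]
```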

generate

def generate(
    self,
    image: torch.Tensor,
    text: Optional[torch.Tensor] = None,
    seq_len: int = 30,
    max_seq_len: int = 77,
    temperature: float = 1.0,
    generation_type: str = "beam_search",
    top_p: float = 0.1,
    top_k: int = 1,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    sot_token_id: Optional[int] = None,
    num_beams: int = 6,
    num_beam_groups: int = 3,
    min_seq_len: int = 5,
    stopping_criteria: Optional[List] = None,
    repetition_penalty: float = 1.0,
    fixed_output_length: bool = False
) -> torch.Tensor:
Generates captions for images using beam search or sampling strategies. Parameters:
  • image: Image tensor to caption
  • text: Optional text prefix to continue from
  • seq_len: Target sequence length for generation
  • max_seq_len: Maximum context length (default 77)
  • temperature: Sampling temperature (higher = more random)
  • generation_type: One of "beam_search", "top_p", or "top_k"
  • top_p: Nucleus sampling parameter (keep tokens in top-p probability mass)
  • top_k: Top-k sampling parameter (keep top k tokens)
  • pad_token_id: Padding token (default 0)
  • eos_token_id: End-of-sequence token (default 49407)
  • sot_token_id: Start-of-text token (default 49406)
  • num_beams: Number of beams for beam search
  • num_beam_groups: Number of beam groups for diverse beam search
  • min_seq_len: Minimum generated sequence length
  • stopping_criteria: Optional list of stopping criteria
  • repetition_penalty: Penalty for repeated tokens (1.0 = no penalty)
  • fixed_output_length: If True, pad output to seq_len
Returns: Generated token IDs of shape (batch_size, seq_len)
Note: Requires the transformers library: pip install transformers
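For intuition, nucleus (top-p) filtering, selected with generation_type="top_p", keeps the smallest set of highest-probability tokens whose cumulative mass reaches top_p and samples only among them. A minimal dependency-free sketch with made-up probabilities (not open_clip's actual implementation):

```python
def top_p_filter(probs, top_p):
    """Keep the smallest set of tokens whose cumulative probability >= top_p."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, total = [], 0.0
    for i in ranked:
        kept.append(i)
        total += probs[i]
        if total >= top_p:
            break
    return sorted(kept)

# Illustrative 4-token vocabulary distribution:
probs = [0.5, 0.3, 0.15, 0.05]
print(top_p_filter(probs, 0.9))  # [0, 1, 2]: the top 3 tokens cover 0.95 >= 0.9
```

A low top_p (the default here is 0.1) keeps only the most likely token or two, so sampling behaves almost greedily; raising it toward 1.0 admits more of the tail and increases diversity.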

set_grad_checkpointing

def set_grad_checkpointing(self, enable: bool = True):
Enables gradient checkpointing for all three model components (visual, text, text_decoder) to reduce memory usage.

Usage Example

import torch
from open_clip import CoCa, CLIPVisionCfg, CLIPTextCfg, MultimodalCfg

# Define configurations
vision_cfg = CLIPVisionCfg(
    layers=12,
    width=768,
    patch_size=16,
    image_size=224,
    output_tokens=True  # Required for caption generation
)

text_cfg = CLIPTextCfg(
    context_length=77,
    vocab_size=49408,
    width=512,
    heads=8,
    layers=12
)

multimodal_cfg = MultimodalCfg(
    context_length=77,
    vocab_size=49408,
    width=512,
    heads=8,
    layers=6  # Decoder layers
)

# Initialize model
model = CoCa(
    embed_dim=512,
    vision_cfg=vision_cfg,
    text_cfg=text_cfg,
    multimodal_cfg=multimodal_cfg
)

# Training: compute both contrastive and captioning loss
images = torch.randn(4, 3, 224, 224)
captions = torch.randint(0, 49408, (4, 77))

output = model(images, captions)
image_features = output['image_features']
text_features = output['text_features']
logits = output['logits']
labels = output['labels']

# Caption generation at inference
model.eval()
with torch.no_grad():
    generated_ids = model.generate(
        images,
        seq_len=30,
        generation_type="beam_search",
        num_beams=6
    )

Contrastive Learning Example

# Use CoCa for zero-shot image-text matching
model.eval()
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))

with torch.no_grad():
    image_emb = model.encode_image(images, normalize=True)
    text_emb = model.encode_text(texts, normalize=True)
    
    # Compute similarities
    logit_scale = model.logit_scale.exp()
    similarity = logit_scale * (image_emb @ text_emb.T)
    probs = similarity.softmax(dim=-1)
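Because cosine similarities are bounded in [-1, 1], a softmax over raw similarities is nearly flat; multiplying by logit_scale (about 14.29 at the default initialization) sharpens it into a usable match distribution. A small numeric sketch with made-up similarity values:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

# Illustrative cosine similarities between one image and three captions:
sims = [0.9, 0.2, 0.1]
flat = softmax(sims)                        # nearly uniform without scaling
sharp = softmax([14.29 * s for s in sims])  # strongly favors the best match
print([round(p, 3) for p in flat], [round(p, 3) for p in sharp])
```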

Caption Generation with Custom Parameters

# Generate diverse captions with nucleus sampling
generated = model.generate(
    images,
    seq_len=50,
    generation_type="top_p",
    top_p=0.9,
    temperature=0.8,
    repetition_penalty=1.2,
    min_seq_len=10
)

# Beam search for more focused captions
generated = model.generate(
    images,
    seq_len=30,
    generation_type="beam_search",
    num_beams=10,
    num_beam_groups=5,  # Diverse beam search
    min_seq_len=8
)
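The repetition_penalty used above follows (by assumption; the source does not spell out the formula) the common CTRL-style rule: logits of already-generated tokens are divided by the penalty when positive and multiplied by it when negative, so values above 1.0 discourage repeats while 1.0 is a no-op. A minimal sketch:

```python
def apply_repetition_penalty(logits, generated_ids, penalty):
    """Down-weight tokens that already appear in the generated sequence."""
    out = list(logits)
    for tok in set(generated_ids):
        if out[tok] > 0:
            out[tok] /= penalty
        else:
            out[tok] *= penalty
    return out

# Illustrative 3-token vocabulary; tokens 0 and 1 were already generated.
logits = [2.0, -1.0, 0.5]
print(apply_repetition_penalty(logits, [0, 1], 1.2))  # [1.666..., -1.2, 0.5]
```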
