
Overview

The CLIP class implements the core CLIP (Contrastive Language-Image Pre-Training) model architecture. It combines a vision encoder and text encoder to learn aligned multimodal representations through contrastive learning. The model outputs normalized image and text features that can be compared in a shared embedding space using cosine similarity.
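The core idea can be sketched without the model itself: once image and text features are L2-normalized, their dot product is exactly the cosine similarity. A minimal plain-Python illustration with toy 3-d embeddings (the vectors are invented for illustration):

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length, as CLIP does before comparison."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

# Toy "image" and "text" embeddings in a shared 3-d space.
image_feat = l2_normalize([0.9, 0.1, 0.2])
text_feat = l2_normalize([0.8, 0.2, 0.1])

# For unit vectors, the dot product equals the cosine similarity.
cosine_sim = sum(a * b for a, b in zip(image_feat, text_feat))
print(round(cosine_sim, 4))  # → 0.9866
```

In the real model the embeddings have `embed_dim` dimensions and are produced by the vision and text encoders, but the comparison step is this same dot product.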

Class Definition

from open_clip import CLIP

Initialization Parameters

  • embed_dim (int, required): Dimensionality of the joint embedding space for image and text features.
  • vision_cfg (CLIPVisionCfg, required): Configuration object for the vision encoder. Controls architecture (ViT, ResNet, or timm model), layer depth, width, and attention settings.
  • text_cfg (CLIPTextCfg, required): Configuration object for the text encoder. Specifies transformer architecture, vocabulary size, context length, and pooling strategy.
  • quick_gelu (bool, default False): Use QuickGELU activation (as in the original OpenAI models) instead of standard GELU. QuickGELU is less memory efficient but maintains compatibility with OpenAI checkpoints.
  • init_logit_scale (float, default np.log(1 / 0.07)): Initial value for the learned temperature parameter (logit scale) that controls the sharpness of the similarity distribution.
  • init_logit_bias (Optional[float], default None): Optional learnable bias term added to logits. When None, no bias is used.
  • nonscalar_logit_scale (bool, default False): If True, logit_scale has shape [1] instead of []. Some training frameworks require explicit dimensions.
  • cast_dtype (Optional[torch.dtype], default None): Precision for model computations (e.g., torch.float16, torch.bfloat16). Used for mixed-precision training.
  • output_dict (bool, default False): If True, forward() returns a dictionary with named outputs. If False, returns a tuple.
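The default init_logit_scale stores the log of the temperature's inverse; the model exponentiates it before scaling logits. A quick check of the arithmetic:

```python
import math

# The parameter stores log(1 / 0.07); the model applies exp() at use time.
init_logit_scale = math.log(1 / 0.07)

logit_scale = math.exp(init_logit_scale)
print(round(logit_scale, 2))  # → 14.29, the standard CLIP starting scale
```

Storing the log keeps the effective scale positive during optimization, since any real-valued parameter exponentiates to a positive number.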

Attributes

  • visual: Vision encoder module (VisionTransformer, ModifiedResNet, or TimmModel)
  • transformer: Text transformer encoder
  • token_embedding: Text token embedding layer
  • positional_embedding: Learned positional embeddings for text
  • ln_final: Final layer normalization for text features
  • text_projection: Projection matrix from text features to joint embedding space
  • logit_scale: Learned temperature parameter, stored as a log value and exponentiated when computing logits
  • logit_bias: Optional learned bias (if init_logit_bias is not None)
  • context_length: Maximum text sequence length
  • vocab_size: Size of text vocabulary

Key Methods

encode_image

def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
Encodes images into the joint embedding space. Parameters:
  • image: Image tensor of shape (batch_size, channels, height, width)
  • normalize: If True, L2-normalizes the output features
Returns: Image features of shape (batch_size, embed_dim)

encode_text

def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
Encodes tokenized text into the joint embedding space. Parameters:
  • text: Tokenized text tensor of shape (batch_size, context_length)
  • normalize: If True, L2-normalizes the output features
Returns: Text features of shape (batch_size, embed_dim)

get_logits

def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Computes similarity logits between image and text features. Parameters:
  • image: Image tensor
  • text: Tokenized text tensor
Returns: Tuple of (image_logits, text_logits), the temperature-scaled similarity scores; text_logits is the transpose of image_logits
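Conceptually, get_logits scales the pairwise similarities of the normalized features by the exponentiated temperature, and the text logits are the transpose of the image logits. A minimal plain-Python sketch with toy 2-d unit vectors (the feature values are invented for illustration):

```python
logit_scale = 14.29  # ≈ exp(log(1 / 0.07)), the default starting scale

# Toy L2-normalized image and text features (2 samples, 2-d embedding).
image_feats = [[1.0, 0.0], [0.0, 1.0]]
text_feats = [[0.8, 0.6], [0.6, 0.8]]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# image_logits[i][j]: similarity of image i to text j, scaled by temperature.
image_logits = [[logit_scale * dot(img, txt) for txt in text_feats]
                for img in image_feats]

# text_logits is simply the transpose.
text_logits = [list(row) for row in zip(*image_logits)]
```

In the real method the features come from encode_image and encode_text with normalize=True, and the matrix product is done in torch.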

forward

def forward(
    self,
    image: Optional[torch.Tensor] = None,
    text: Optional[torch.Tensor] = None
) -> Union[Dict, Tuple]:
Forward pass through the model. Parameters:
  • image: Optional image tensor
  • text: Optional tokenized text tensor
Returns:
  • If output_dict=True: Dictionary with keys image_features, text_features, logit_scale, and optionally logit_bias
  • If output_dict=False: Tuple of (image_features, text_features, logit_scale) or (image_features, text_features, logit_scale, logit_bias)

forward_intermediates

def forward_intermediates(
    self,
    image: Optional[torch.Tensor] = None,
    text: Optional[torch.Tensor] = None,
    image_indices: Optional[Union[int, List[int]]] = None,
    text_indices: Optional[Union[int, List[int]]] = None,
    stop_early: bool = False,
    normalize: bool = True,
    normalize_intermediates: bool = False,
    intermediates_only: bool = False,
    image_output_fmt: str = 'NCHW',
    image_output_extra_tokens: bool = False,
    text_output_fmt: str = 'NLC',
    text_output_extra_tokens: bool = False,
    output_logits: bool = False,
    output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
Forward pass that returns intermediate layer features. Useful for feature extraction, analysis, and distillation. Returns: Dictionary with intermediate features, final features, and optionally logits

lock_image_tower

def lock_image_tower(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
Freezes the image encoder for efficient fine-tuning (LiT-style training). Parameters:
  • unlocked_groups: Number of layer groups to keep trainable (from the end)
  • freeze_bn_stats: If True, freezes batch normalization statistics

lock_text_tower

def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
Freezes the text encoder for efficient fine-tuning. Parameters:
  • unlocked_layers: Number of transformer layers to keep trainable (from the end)
  • freeze_layer_norm: If True, freezes layer normalization parameters
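The locking pattern can be sketched with a toy stack of layers in plain PyTorch (this is an illustration of the idea, not open_clip's actual implementation): freeze everything, then re-enable gradients for the last unlocked_layers blocks.

```python
import torch.nn as nn

# Toy "text tower": a stack of 4 blocks (Linear stands in for a transformer block).
layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
unlocked_layers = 1

# Freeze all parameters, then unlock the last `unlocked_layers` blocks.
for layer in layers:
    for p in layer.parameters():
        p.requires_grad_(False)
for layer in layers[len(layers) - unlocked_layers:]:
    for p in layer.parameters():
        p.requires_grad_(True)

trainable = sum(p.numel() for p in layers.parameters() if p.requires_grad)
print(trainable)  # one Linear(8, 8): 8*8 weights + 8 biases = 72
```

Frozen parameters still participate in the forward pass; they simply receive no gradient updates, which is what makes LiT-style training cheap.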

set_grad_checkpointing

def set_grad_checkpointing(self, enable: bool = True):
Enables or disables gradient checkpointing to reduce memory usage during training at the cost of computation.
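The underlying mechanism is activation checkpointing as provided by PyTorch: activations are discarded in the forward pass and recomputed during backward. A generic illustration with torch.utils.checkpoint (not the model's internal wiring):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# Direct forward keeps activations; the checkpointed forward recomputes them
# during backward, trading extra compute for lower memory use.
y_direct = layer(x)
y_ckpt = checkpoint(layer, x, use_reentrant=False)

assert torch.allclose(y_direct, y_ckpt)
```

The outputs are identical; only the memory/compute trade-off during backpropagation changes.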

Usage Example

import torch
from open_clip import CLIP, CLIPVisionCfg, CLIPTextCfg

# Define model configuration
vision_cfg = CLIPVisionCfg(
    layers=12,
    width=768,
    patch_size=16,
    image_size=224
)

text_cfg = CLIPTextCfg(
    context_length=77,
    vocab_size=49408,
    width=512,
    heads=8,
    layers=12
)

# Initialize model
model = CLIP(
    embed_dim=512,
    vision_cfg=vision_cfg,
    text_cfg=text_cfg,
    output_dict=True
)

# Forward pass
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))

output = model(images, texts)
image_features = output['image_features']  # (4, 512)
text_features = output['text_features']    # (4, 512)
logit_scale = output['logit_scale']        # scalar

# Compute similarity
similarity = (image_features @ text_features.T) * logit_scale

# Encode separately
image_emb = model.encode_image(images, normalize=True)
text_emb = model.encode_text(texts, normalize=True)
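For zero-shot classification, a row of image-to-text logits is typically converted into probabilities with a softmax; the learned temperature controls how peaked that distribution is. A self-contained sketch with toy logit values (invented for illustration):

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy image-to-text logits for one image scored against three captions.
logits = [11.4, 8.6, 2.1]
probs = softmax(logits)
print([round(p, 4) for p in probs])  # probabilities sum to 1
```

With real model outputs the same step is usually written as `(logit_scale * image_features @ text_features.T).softmax(dim=-1)` in torch.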

Fine-tuning Example

# Lock image tower for text-only fine-tuning (LiT)
model.lock_image_tower(unlocked_groups=0, freeze_bn_stats=True)

# Enable gradient checkpointing for memory efficiency
model.set_grad_checkpointing(enable=True)

# Only text parameters will be updated during training
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

Related

  • CustomTextCLIP - Variant with separately built text tower
  • CoCa - Contrastive Captioner model
  • ClipLoss - Contrastive loss function for CLIP
