Overview
The CLIP class implements the core CLIP (Contrastive Language-Image Pre-Training) model architecture. It combines a vision encoder and text encoder to learn aligned multimodal representations through contrastive learning.
The model outputs normalized image and text features that can be compared in a shared embedding space using cosine similarity.
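Because the features are L2-normalized, comparison in the shared space reduces to a dot product. A minimal sketch with stand-in embeddings (in practice they come from the encoders with normalize=True):

```python
import torch
import torch.nn.functional as F

# Stand-in embeddings; real ones come from encode_image / encode_text.
image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)

# Cosine similarity of unit vectors is just a dot product.
similarity = image_features @ text_features.T  # shape (4, 4), values in [-1, 1]
```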
Class Definition
from open_clip import CLIP
Initialization Parameters
embed_dim
int
Dimensionality of the joint embedding space for image and text features.
vision_cfg
CLIPVisionCfg
Configuration object for the vision encoder. Controls architecture (ViT, ResNet, or timm model), layer depth, width, and attention settings.
text_cfg
CLIPTextCfg
Configuration object for the text encoder. Specifies transformer architecture, vocabulary size, context length, and pooling strategy.
quick_gelu
bool
default:"False"
Use QuickGELU activation (as in the original OpenAI models) instead of standard GELU. QuickGELU is less memory efficient but maintains compatibility with OpenAI checkpoints.
init_logit_scale
float
default:"np.log(1 / 0.07)"
Initial value for the learned temperature parameter (logit scale) that controls the sharpness of the similarity distribution.
init_logit_bias
Optional[float]
default:"None"
Optional learnable bias term added to logits. When None, no bias is used.
nonscalar_logit_scale
bool
default:"False"
If True, logit_scale has shape [1] instead of []. Some training frameworks require explicit dimensions.
cast_dtype
Optional[torch.dtype]
default:"None"
Precision for model computations (e.g., torch.float16, torch.bfloat16). Used for mixed precision training.
output_dict
bool
default:"False"
If True, forward() returns a dictionary with named outputs. If False, returns a tuple.
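The stored logit scale is a log-temperature: the default init gives exp(log(1/0.07)) = 1/0.07 ≈ 14.29 after exponentiation. A small numeric sketch (the similarity values are made up) of how that scaling sharpens the softmax over candidates:

```python
import math
import torch

# Default init_logit_scale is stored in log space.
init = math.log(1 / 0.07)
scale = math.exp(init)  # ≈ 14.29

sims = torch.tensor([0.30, 0.28, 0.10])        # illustrative cosine similarities
soft_unscaled = torch.softmax(sims, dim=0)      # nearly uniform
soft_scaled = torch.softmax(scale * sims, dim=0)  # sharpened toward the best match
```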
Attributes
- visual: Vision encoder module (VisionTransformer, ModifiedResNet, or TimmModel)
- transformer: Text transformer encoder
- token_embedding: Text token embedding layer
- positional_embedding: Learned positional embeddings for text
- ln_final: Final layer normalization for text features
- text_projection: Projection matrix from text features to joint embedding space
- logit_scale: Learned temperature parameter, stored as a log value (outputs expose its exponential)
- logit_bias: Optional learned bias (if init_logit_bias is not None)
- context_length: Maximum text sequence length
- vocab_size: Size of text vocabulary
Key Methods
encode_image
def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
Encodes images into the joint embedding space.
Parameters:
image: Image tensor of shape (batch_size, channels, height, width)
normalize: If True, L2-normalizes the output features
Returns: Image features of shape (batch_size, embed_dim)
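The normalize flag applies L2 normalization to the output features; a sketch of the equivalent operation (not the internal implementation):

```python
import torch
import torch.nn.functional as F

# Stand-in for un-normalized encoder output.
feats = torch.randn(4, 512)
normed = F.normalize(feats, dim=-1)  # each row now has unit L2 norm
```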
encode_text
def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
Encodes tokenized text into the joint embedding space.
Parameters:
text: Tokenized text tensor of shape (batch_size, context_length)
normalize: If True, L2-normalizes the output features
Returns: Text features of shape (batch_size, embed_dim)
get_logits
def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Computes similarity logits between image and text features.
Parameters:
image: Image tensor
text: Tokenized text tensor
Returns: Tuple of (image_logits, text_logits) representing similarity scores scaled by temperature
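A hedged sketch of the computation behind get_logits, using stand-in features: temperature-scaled cosine similarities, with text_logits the transpose of image_logits.

```python
import torch
import torch.nn.functional as F

image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # exp of the default init_logit_scale

image_logits = logit_scale * image_features @ text_features.T  # (4, 4)
text_logits = image_logits.T
probs = image_logits.softmax(dim=-1)  # per-image distribution over texts
```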
forward
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None
) -> Union[Dict, Tuple]:
Forward pass through the model.
Parameters:
image: Optional image tensor
text: Optional tokenized text tensor
Returns:
- If output_dict=True: Dictionary with keys image_features, text_features, logit_scale, and optionally logit_bias
- If output_dict=False: Tuple of (image_features, text_features, logit_scale) or (image_features, text_features, logit_scale, logit_bias)
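Downstream code that must work with either convention can branch on the return type. A sketch with simulated outputs standing in for a real model call:

```python
import torch

# Simulated forward() outputs in both conventions (stand-ins for a model call).
dict_out = {
    "image_features": torch.randn(4, 512),
    "text_features": torch.randn(4, 512),
    "logit_scale": torch.tensor(14.29),
}
tuple_out = (
    dict_out["image_features"],
    dict_out["text_features"],
    dict_out["logit_scale"],
)

def unpack(output):
    # Handles output_dict=True (dict) and output_dict=False (tuple) alike.
    if isinstance(output, dict):
        return output["image_features"], output["text_features"], output["logit_scale"]
    return output[0], output[1], output[2]
```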
forward_intermediates
def forward_intermediates(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
image_indices: Optional[Union[int, List[int]]] = None,
text_indices: Optional[Union[int, List[int]]] = None,
stop_early: bool = False,
normalize: bool = True,
normalize_intermediates: bool = False,
intermediates_only: bool = False,
image_output_fmt: str = 'NCHW',
image_output_extra_tokens: bool = False,
text_output_fmt: str = 'NLC',
text_output_extra_tokens: bool = False,
output_logits: bool = False,
output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
Forward pass that returns intermediate layer features. Useful for feature extraction, analysis, and distillation.
Returns: Dictionary with intermediate features, final features, and optionally logits
lock_image_tower
def lock_image_tower(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
Freezes the image encoder for efficient fine-tuning (LiT-style training).
Parameters:
unlocked_groups: Number of layer groups to keep trainable (from the end)
freeze_bn_stats: If True, freezes batch normalization statistics
lock_text_tower
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
Freezes the text encoder for efficient fine-tuning.
Parameters:
unlocked_layers: Number of transformer layers to keep trainable (from the end)
freeze_layer_norm: If True, freezes layer normalization parameters
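The locking mechanism amounts to setting requires_grad=False on all but the last few layers. A minimal sketch of the idea on a toy layer stack (not the open_clip internals):

```python
import torch.nn as nn

# Toy stack standing in for transformer layers.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])
unlocked = 2  # keep the last 2 layers trainable, as with unlocked_layers=2

for layer in layers[: len(layers) - unlocked]:
    for p in layer.parameters():
        p.requires_grad = False

trainable = sum(p.requires_grad for layer in layers for p in layer.parameters())
```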
set_grad_checkpointing
def set_grad_checkpointing(self, enable: bool = True):
Enables or disables gradient checkpointing to reduce memory usage during training at the cost of computation.
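The trade-off can be seen with torch's generic checkpointing utility, which is the same mechanism in miniature: activations inside the checkpointed block are recomputed during backward instead of stored. A sketch, not the open_clip implementation:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(2, 16, requires_grad=True)

out_plain = block(x)
# Same forward result, but intermediate activations are recomputed in backward.
out_ckpt = checkpoint(block, x, use_reentrant=False)
same = torch.allclose(out_plain, out_ckpt)
```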
Usage Example
import torch
from open_clip import CLIP, CLIPVisionCfg, CLIPTextCfg
# Define model configuration
vision_cfg = CLIPVisionCfg(
layers=12,
width=768,
patch_size=16,
image_size=224
)
text_cfg = CLIPTextCfg(
context_length=77,
vocab_size=49408,
width=512,
heads=8,
layers=12
)
# Initialize model
model = CLIP(
embed_dim=512,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
output_dict=True
)
# Forward pass
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))
output = model(images, texts)
image_features = output['image_features'] # (4, 512)
text_features = output['text_features'] # (4, 512)
logit_scale = output['logit_scale'] # scalar
# Compute similarity
similarity = (image_features @ text_features.T) * logit_scale
# Encode separately
image_emb = model.encode_image(images, normalize=True)
text_emb = model.encode_text(texts, normalize=True)
Fine-tuning Example
# Lock image tower for text-only fine-tuning (LiT)
model.lock_image_tower(unlocked_groups=0, freeze_bn_stats=True)
# Enable gradient checkpointing for memory efficiency
model.set_grad_checkpointing(enable=True)
# Only text parameters will be updated during training
for name, param in model.named_parameters():
print(f"{name}: requires_grad={param.requires_grad}")
Related
- CustomTextCLIP - Variant with separately built text tower
- CoCa - Contrastive Captioner model
- ClipLoss - Contrastive loss function for CLIP