Overview

The CustomTextCLIP class is a variant of the CLIP model that builds the text encoder as a separate module (self.text) rather than directly incorporating its components. This architecture provides:
  • Modularity: Easier to swap different text encoder architectures
  • Flexibility: Supports custom text encoders including HuggingFace models
  • Consistency: Parallel structure to the vision tower
  • Intermediate features: Better support for extracting text intermediate layer features
The main difference from standard CLIP is architectural organization, with identical training and inference capabilities.

Class Definition

from open_clip import CustomTextCLIP

Initialization Parameters

embed_dim (int, required)
    Dimensionality of the joint embedding space for image and text features.

vision_cfg (CLIPVisionCfg, required)
    Configuration object for the vision encoder.

text_cfg (CLIPTextCfg, required)
    Configuration object for the text encoder. Supports both custom transformers and HuggingFace models via the hf_model_name parameter.

quick_gelu (bool, default: False)
    Use QuickGELU activation (as in the original OpenAI models) instead of standard GELU.

init_logit_scale (float, default: np.log(1 / 0.07))
    Initial value for the learned temperature parameter (logit scale).

init_logit_bias (Optional[float], default: None)
    Optional learnable bias term added to logits. When None, no bias is used.

nonscalar_logit_scale (bool, default: False)
    If True, logit_scale has shape [1] instead of [] (a scalar).

cast_dtype (Optional[torch.dtype], default: None)
    Precision for model computations (e.g., torch.float16, torch.bfloat16).

output_dict (bool, default: False)
    If True, forward() returns a dictionary with named outputs; if False, it returns a tuple.

Attributes

  • visual: Vision encoder module (VisionTransformer, ModifiedResNet, or TimmModel)
  • text: Text encoder module (TextTransformer or HFTextEncoder)
  • logit_scale: Learned log-temperature parameter (exponentiated when scaling logits)
  • logit_bias: Optional learned bias (if init_logit_bias is not None)
  • context_length: Maximum text sequence length
  • vocab_size: Size of text vocabulary
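
Note that logit_scale stores the log of the inverse temperature; the model exponentiates it when scaling logits. A quick check of the default value in plain Python:

```python
import math

# init_logit_scale defaults to the log of the inverse temperature (0.07)
init_logit_scale = math.log(1 / 0.07)

# The model exponentiates the stored value when scaling logits
effective_scale = math.exp(init_logit_scale)
print(round(effective_scale, 2))  # 14.29
```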

Key Methods

encode_image

def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
Encodes images into the joint embedding space. Parameters:
  • image: Image tensor of shape (batch_size, channels, height, width)
  • normalize: If True, L2-normalizes the output features
Returns: Image features of shape (batch_size, embed_dim)

encode_text

def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
Encodes tokenized text into the joint embedding space. Parameters:
  • text: Tokenized text tensor of shape (batch_size, context_length)
  • normalize: If True, L2-normalizes the output features
Returns: Text features of shape (batch_size, embed_dim)
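
The normalize flag on both encode methods applies L2 normalization along the feature dimension. A minimal sketch of the equivalent operation in plain PyTorch, using a stand-in tensor rather than real encoder output:

```python
import torch
import torch.nn.functional as F

features = torch.randn(4, 512)              # stand-in for raw encoder output
normalized = F.normalize(features, dim=-1)  # what normalize=True applies

# Every row now has unit L2 norm, ready for cosine-similarity logits
print(normalized.norm(dim=-1))
```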

get_logits

def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Computes similarity logits between image and text features. Parameters:
  • image: Image tensor
  • text: Tokenized text tensor
Returns: Tuple of (image_logits, text_logits) representing similarity scores scaled by temperature
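
Conceptually, get_logits encodes and normalizes both inputs, then scales their pairwise cosine similarities by the exponentiated temperature. A sketch of that final step on hypothetical pre-encoded features (not a call into the model itself):

```python
import torch
import torch.nn.functional as F

# Hypothetical normalized features for a batch of 4 image/text pairs
image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # exponentiated temperature

image_logits = logit_scale * image_features @ text_features.T  # shape (4, 4)
text_logits = image_logits.T  # text-to-image view of the same scores
```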

forward

def forward(
    self,
    image: Optional[torch.Tensor] = None,
    text: Optional[torch.Tensor] = None
) -> Union[Dict, Tuple]:
Forward pass through the model. Parameters:
  • image: Optional image tensor
  • text: Optional tokenized text tensor
Returns:
  • If output_dict=True: Dictionary with keys image_features, text_features, logit_scale, and optionally logit_bias
  • If output_dict=False: Tuple of (image_features, text_features, logit_scale) or (image_features, text_features, logit_scale, logit_bias)

forward_intermediates

def forward_intermediates(
    self,
    image: Optional[torch.Tensor] = None,
    text: Optional[torch.Tensor] = None,
    image_indices: Optional[Union[int, List[int]]] = None,
    text_indices: Optional[Union[int, List[int]]] = None,
    stop_early: bool = False,
    normalize: bool = True,
    normalize_intermediates: bool = False,
    intermediates_only: bool = False,
    image_output_fmt: str = 'NCHW',
    image_output_extra_tokens: bool = False,
    text_output_fmt: str = 'NLC',
    text_output_extra_tokens: bool = False,
    output_logits: bool = False,
    output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
Forward pass that returns intermediate layer features from both vision and text encoders. This method provides better text intermediate support compared to standard CLIP. Returns: Dictionary with intermediate features, final features, and optionally logits

lock_image_tower

def lock_image_tower(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
Freezes the image encoder for efficient fine-tuning. Parameters:
  • unlocked_groups: Number of layer groups to keep trainable (from the end)
  • freeze_bn_stats: If True, freezes batch normalization statistics

lock_text_tower

def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
Freezes the text encoder for efficient fine-tuning. Delegates to the text module’s lock method. Parameters:
  • unlocked_layers: Number of transformer layers to keep trainable (from the end)
  • freeze_layer_norm: If True, freezes layer normalization parameters

set_grad_checkpointing

def set_grad_checkpointing(self, enable: bool = True):
Enables or disables gradient checkpointing for both vision and text encoders to reduce memory usage.

Usage Example

import torch
from open_clip import CustomTextCLIP, CLIPVisionCfg, CLIPTextCfg

# Define model configuration
vision_cfg = CLIPVisionCfg(
    layers=12,
    width=768,
    patch_size=16,
    image_size=224
)

text_cfg = CLIPTextCfg(
    context_length=77,
    vocab_size=49408,
    width=512,
    heads=8,
    layers=12
)

# Initialize model
model = CustomTextCLIP(
    embed_dim=512,
    vision_cfg=vision_cfg,
    text_cfg=text_cfg,
    output_dict=True
)

# Forward pass
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))

output = model(images, texts)
image_features = output['image_features']
text_features = output['text_features']
logit_scale = output['logit_scale']

# Compute similarity
similarity = (image_features @ text_features.T) * logit_scale

Using HuggingFace Text Encoders

from open_clip import CustomTextCLIP, CLIPVisionCfg, CLIPTextCfg

# Use a HuggingFace model as text encoder
text_cfg = CLIPTextCfg(
    hf_model_name="roberta-base",
    hf_model_pretrained=True,
    hf_proj_type="mlp",
    hf_pooler_type="mean_pooler"
)

vision_cfg = CLIPVisionCfg(
    layers=12,
    width=768,
    patch_size=16,
    image_size=224
)

model = CustomTextCLIP(
    embed_dim=512,
    vision_cfg=vision_cfg,
    text_cfg=text_cfg
)

# The text tower is now a HuggingFace model
print(type(model.text))  # HFTextEncoder

Extracting Intermediate Features

# Extract intermediate features from both encoders
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))

output = model.forward_intermediates(
    image=images,
    text=texts,
    image_indices=[3, 6, 9, 11],  # Extract specific layer features
    text_indices=[3, 6, 9, 11],
    normalize=True,
    output_logits=True
)

image_intermediates = output['image_intermediates']  # List of 4 tensors
text_intermediates = output['text_intermediates']    # List of 4 tensors
image_features = output['image_features']            # Final features
text_features = output['text_features']              # Final features
image_logits = output['image_logits']                # Similarity logits

Selective Fine-tuning

# Freeze vision, fine-tune text encoder
model.lock_image_tower(unlocked_groups=0, freeze_bn_stats=True)

# Only train last 2 text layers
model.lock_text_tower(unlocked_layers=2, freeze_layer_norm=False)

# Check which parameters are trainable
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable: {name}")

Differences from Standard CLIP

Aspect               | CLIP                          | CustomTextCLIP
---------------------|-------------------------------|-----------------------------------------------
Text architecture    | Components directly in model  | Separate self.text module
Text encoder access  | Via individual attributes     | Via model.text
HF model support     | Limited                       | Native support
Text intermediates   | Basic support                 | Full support via text.forward_intermediates()
encode_text          | Manually composes layers      | Delegates to self.text()
State dict           | Flat structure                | Prefixed with text.
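
The state-dict difference amounts to a key prefix: text-tower parameters that sit at the top level in CLIP are nested under text. in CustomTextCLIP. An illustration with a hypothetical parameter name:

```python
# Hypothetical CLIP state-dict key for a text transformer weight
clip_key = "transformer.resblocks.0.attn.in_proj_weight"

# The same parameter in CustomTextCLIP lives under the text. prefix
custom_key = "text." + clip_key
print(custom_key)  # text.transformer.resblocks.0.attn.in_proj_weight
```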

Migration from CLIP

To convert existing CLIP state dictionaries to CustomTextCLIP format:
import torch
from open_clip.model import convert_to_custom_text_state_dict

# Load a CLIP checkpoint (on CPU to avoid device mismatches)
state_dict = torch.load('clip_model.pt', map_location='cpu')

# Convert to CustomTextCLIP format
custom_state_dict = convert_to_custom_text_state_dict(state_dict)

# Load into CustomTextCLIP model
model = CustomTextCLIP(...)
model.load_state_dict(custom_state_dict)
