Overview
The CustomTextCLIP class is a variant of the CLIP model that builds the text encoder as a separate module (self.text) rather than directly incorporating its components. This architecture provides:
- Modularity: Easier to swap different text encoder architectures
- Flexibility: Supports custom text encoders including HuggingFace models
- Consistency: Parallel structure to the vision tower
- Intermediate features: Better support for extracting intermediate-layer features from the text encoder
The main difference from standard CLIP is architectural organization, with identical training and inference capabilities.
Class Definition
```python
from open_clip import CustomTextCLIP
```
Initialization Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| embed_dim | int | required | Dimensionality of the joint embedding space for image and text features. |
| vision_cfg | CLIPVisionCfg | required | Configuration object for the vision encoder. |
| text_cfg | CLIPTextCfg | required | Configuration object for the text encoder. Supports both custom transformers and HuggingFace models via the hf_model_name parameter. |
| quick_gelu | bool | False | Use QuickGELU activation (as in the original OpenAI models) instead of standard GELU. |
| init_logit_scale | float | np.log(1 / 0.07) | Initial value for the learned temperature parameter (logit scale). |
| init_logit_bias | Optional[float] | None | Optional learnable bias term added to logits. When None, no bias is used. |
| nonscalar_logit_scale | bool | False | If True, logit_scale has shape [1] instead of []. |
| cast_dtype | Optional[torch.dtype] | None | Precision for model computations (e.g., torch.float16, torch.bfloat16). |
| output_dict | bool | False | If True, forward() returns a dictionary with named outputs. If False, returns a tuple. |
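The default logit-scale initialization corresponds to the CLIP paper's temperature of 0.07. Since the parameter is stored in log space and exponentiated in the forward pass, the numbers work out as follows (a standalone sketch, not library code):

```python
import math

# init_logit_scale default: log of the inverse temperature (1 / 0.07).
init_logit_scale = math.log(1 / 0.07)   # about 2.659
# forward() returns the exponentiated value, i.e. the multiplicative scale.
scale = math.exp(init_logit_scale)      # about 14.286
```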
Attributes
- visual: Vision encoder module (VisionTransformer, ModifiedResNet, or TimmModel)
- text: Text encoder module (TextTransformer or HFTextEncoder)
- logit_scale: Learned temperature parameter, stored in log space (forward() returns its exponential)
- logit_bias: Optional learned bias (if init_logit_bias is not None)
- context_length: Maximum text sequence length
- vocab_size: Size of text vocabulary
Key Methods
encode_image
```python
def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
```
Encodes images into the joint embedding space.
Parameters:
image: Image tensor of shape (batch_size, channels, height, width)
normalize: If True, L2-normalizes the output features
Returns: Image features of shape (batch_size, embed_dim)
encode_text
```python
def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
```
Encodes tokenized text into the joint embedding space.
Parameters:
text: Tokenized text tensor of shape (batch_size, context_length)
normalize: If True, L2-normalizes the output features
Returns: Text features of shape (batch_size, embed_dim)
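With normalize=True, both encoders L2-normalize their outputs along the feature dimension, so dot products between image and text features become cosine similarities. A standalone sketch of that normalization (random tensors stand in for real encoder outputs):

```python
import torch
import torch.nn.functional as F

# Stand-ins for encoder outputs; in practice these come from
# model.encode_image(images) / model.encode_text(texts).
raw = torch.randn(4, 512)
normed = F.normalize(raw, dim=-1)  # what normalize=True applies

# Every row now has unit L2 norm.
norms = normed.norm(dim=-1)
```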
get_logits
```python
def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```
Computes similarity logits between image and text features.
Parameters:
image: Image tensor
text: Tokenized text tensor
Returns: Tuple of (image_logits, text_logits) representing similarity scores scaled by temperature
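Conceptually, get_logits amounts to encoding both modalities with normalization and scaling the cosine similarities by the exponentiated temperature. A sketch of that computation with placeholder features (not the library's internals verbatim):

```python
import torch
import torch.nn.functional as F

# Placeholder normalized features, as encode_image/encode_text with
# normalize=True would produce.
image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale = 1 / 0.07  # exp of the default init_logit_scale

image_logits = logit_scale * image_features @ text_features.T
text_logits = image_logits.T  # the text-to-image view is the transpose
```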
forward
```python
def forward(
    self,
    image: Optional[torch.Tensor] = None,
    text: Optional[torch.Tensor] = None,
) -> Union[Dict, Tuple]:
```
Forward pass through the model.
Parameters:
image: Optional image tensor
text: Optional tokenized text tensor
Returns:
- If output_dict=True: Dictionary with keys image_features, text_features, logit_scale, and optionally logit_bias
- If output_dict=False: Tuple of (image_features, text_features, logit_scale) or (image_features, text_features, logit_scale, logit_bias)
forward_intermediates
```python
def forward_intermediates(
    self,
    image: Optional[torch.Tensor] = None,
    text: Optional[torch.Tensor] = None,
    image_indices: Optional[Union[int, List[int]]] = None,
    text_indices: Optional[Union[int, List[int]]] = None,
    stop_early: bool = False,
    normalize: bool = True,
    normalize_intermediates: bool = False,
    intermediates_only: bool = False,
    image_output_fmt: str = 'NCHW',
    image_output_extra_tokens: bool = False,
    text_output_fmt: str = 'NLC',
    text_output_extra_tokens: bool = False,
    output_logits: bool = False,
    output_logit_scale_bias: bool = False,
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
```
Forward pass that returns intermediate layer features from both vision and text encoders. This method provides better text intermediate support compared to standard CLIP.
Returns: Dictionary with intermediate features, final features, and optionally logits
lock_image_tower
```python
def lock_image_tower(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
```
Freezes the image encoder for efficient fine-tuning.
Parameters:
unlocked_groups: Number of layer groups to keep trainable (from the end)
freeze_bn_stats: If True, freezes batch normalization statistics
lock_text_tower
```python
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
```
Freezes the text encoder for efficient fine-tuning. Delegates to the text module’s lock method.
Parameters:
unlocked_layers: Number of transformer layers to keep trainable (from the end)
freeze_layer_norm: If True, freezes layer normalization parameters
set_grad_checkpointing
```python
def set_grad_checkpointing(self, enable: bool = True):
```
Enables or disables gradient checkpointing for both vision and text encoders to reduce memory usage.
Usage Example
```python
import torch
from open_clip import CustomTextCLIP, CLIPVisionCfg, CLIPTextCfg

# Define model configuration
vision_cfg = CLIPVisionCfg(
    layers=12,
    width=768,
    patch_size=16,
    image_size=224,
)
text_cfg = CLIPTextCfg(
    context_length=77,
    vocab_size=49408,
    width=512,
    heads=8,
    layers=12,
)

# Initialize model
model = CustomTextCLIP(
    embed_dim=512,
    vision_cfg=vision_cfg,
    text_cfg=text_cfg,
    output_dict=True,
)

# Forward pass
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))
output = model(images, texts)

image_features = output['image_features']
text_features = output['text_features']
logit_scale = output['logit_scale']

# Compute similarity
similarity = (image_features @ text_features.T) * logit_scale
```
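For zero-shot classification, the scaled similarities are typically converted to per-image probabilities over the candidate texts with a softmax. A self-contained sketch with placeholder features standing in for model outputs:

```python
import torch
import torch.nn.functional as F

# Placeholder normalized features in place of real encoder outputs.
image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale = 1 / 0.07  # exp of the default init_logit_scale

similarity = (image_features @ text_features.T) * logit_scale
probs = similarity.softmax(dim=-1)  # each row sums to 1
```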
Using HuggingFace Text Encoders
```python
import torch
from open_clip import CustomTextCLIP, CLIPVisionCfg, CLIPTextCfg

# Use a HuggingFace model as text encoder
text_cfg = CLIPTextCfg(
    hf_model_name="roberta-base",
    hf_model_pretrained=True,
    hf_proj_type="mlp",
    hf_pooler_type="mean_pooler",
)
vision_cfg = CLIPVisionCfg(
    layers=12,
    width=768,
    patch_size=16,
    image_size=224,
)
model = CustomTextCLIP(
    embed_dim=512,
    vision_cfg=vision_cfg,
    text_cfg=text_cfg,
)

# The text tower is now a HuggingFace model
print(type(model.text))  # HFTextEncoder

# Extract intermediate features from both encoders
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(0, 49408, (4, 77))
output = model.forward_intermediates(
    image=images,
    text=texts,
    image_indices=[3, 6, 9, 11],  # Extract specific layer features
    text_indices=[3, 6, 9, 11],
    normalize=True,
    output_logits=True,
)

image_intermediates = output['image_intermediates']  # List of 4 tensors
text_intermediates = output['text_intermediates']    # List of 4 tensors
image_features = output['image_features']            # Final features
text_features = output['text_features']              # Final features
image_logits = output['image_logits']                # Similarity logits
```
Selective Fine-tuning
```python
# Freeze vision, fine-tune text encoder
model.lock_image_tower(unlocked_groups=0, freeze_bn_stats=True)

# Only train last 2 text layers
model.lock_text_tower(unlocked_layers=2, freeze_layer_norm=False)

# Check which parameters are trainable
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable: {name}")
```
Differences from Standard CLIP
| Aspect | CLIP | CustomTextCLIP |
|---|---|---|
| Text architecture | Components directly in model | Separate self.text module |
| Text encoder access | Via individual attributes | Via model.text |
| HF model support | Limited | Native support |
| Text intermediates | Basic support | Full support via text.forward_intermediates() |
| encode_text | Manually composes layers | Delegates to self.text() |
| State dict | Flat structure | Prefixed with text. |
Migration from CLIP
To convert existing CLIP state dictionaries to CustomTextCLIP format:
```python
import torch
from open_clip.model import convert_to_custom_text_state_dict

# Load CLIP checkpoint
state_dict = torch.load('clip_model.pt')

# Convert to CustomTextCLIP format
custom_state_dict = convert_to_custom_text_state_dict(state_dict)

# Load into CustomTextCLIP model
model = CustomTextCLIP(...)
model.load_state_dict(custom_state_dict)