This guide covers the complete inference pipeline for OpenCLIP models, from loading to computing image-text similarities.
Quick Start
# Quick-start: zero-shot image/text similarity with an OpenCLIP model.
import torch
from PIL import Image
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32',
    pretrained='laion2b_s34b_b79k'
)
model.eval()  # Set to eval mode (disables dropout / frozen BatchNorm stats)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# Preprocess one image (unsqueeze adds the batch dim) and tokenize captions.
image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalize so the matrix product below is cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scale by 100 (CLIP's logit scale) before softmax over the captions.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # [[0.9927, 0.0022, 0.0051]]
Complete Inference Pipeline
Load Model and Preprocessing
Load the model with appropriate precision and device settings:

import open_clip
import torch
# Pick GPU when available; the fp16 precision below only helps on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
'ViT-L-14',
pretrained='datacomp_xl_s13b_b90k',
device=device,
precision='fp16' # Use fp16 for faster inference on GPU
)
model.eval()  # Inference mode: disables dropout / BatchNorm updates
# The tokenizer must match the model architecture name.
tokenizer = open_clip.get_tokenizer('ViT-L-14')
Prepare Inputs
Preprocess images and tokenize text:

from PIL import Image
# Load and preprocess image
image = Image.open('example.jpg')
# preprocess() resizes, center-crops, converts to a tensor and normalizes;
# unsqueeze(0) adds the batch dimension expected by encode_image.
image_input = preprocess(image).unsqueeze(0).to(device)
# Tokenize text (each string is padded/truncated to the model's context length)
text_inputs = tokenizer([
"a photo of a cat",
"a photo of a dog",
"a photo of a bird"
]).to(device)
Encode Image and Text
Extract features using the model:

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)
Normalize and Compute Similarity
Normalize features and compute cosine similarity:

# Normalize features
# After L2 normalization, a dot product equals cosine similarity.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Compute similarity (cosine similarity scaled by 100)
similarity = (100.0 * image_features @ text_features.T)
# Softmax over the text axis turns scores into a probability distribution.
probs = similarity.softmax(dim=-1)
# Get predictions
print("Probabilities:", probs)
print("Top prediction:", probs.argmax())
Encoding Methods
Image Encoding
# Single image
image_features = model.encode_image(image_tensor) # shape: [1, embed_dim]
# Batch of images
# Stack preprocessed CHW tensors into one [batch, 3, H, W] batch tensor.
images = torch.stack([preprocess(img) for img in image_list])
image_features = model.encode_image(images) # shape: [batch_size, embed_dim]
The encode_image() method:
- Accepts tensors of shape
[batch, 3, height, width]
- Returns raw (unnormalized) embeddings of shape
[batch, embed_dim]
- embed_dim varies by model (512 for ViT-B, 768 for ViT-L, etc.)
Text Encoding
# Single or multiple texts
text_tokens = tokenizer(["a cat", "a dog", "a bird"])
text_features = model.encode_text(text_tokens) # shape: [3, embed_dim]
# With custom context length
# context_length controls padding/truncation (77 is the CLIP default).
text_tokens = tokenizer(["a very long description..."], context_length=77)
text_features = model.encode_text(text_tokens)
The encode_text() method:
- Accepts tokenized text tensors of shape
[batch, context_length]
- Returns embeddings of shape
[batch, embed_dim]
- Texts longer than context_length are truncated
Batch Processing
Processing Multiple Images Efficiently
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
class ImageDataset(Dataset):
    """Dataset that lazily loads images from disk and applies a preprocess transform.

    Args:
        image_paths: sequence of paths to image files.
        preprocess: callable mapping a PIL image to a model-ready tensor.
    """

    def __init__(self, image_paths, preprocess):
        self.image_paths = image_paths
        self.preprocess = preprocess

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Images are opened lazily so memory stays bounded on large datasets.
        image = Image.open(self.image_paths[idx])
        return self.preprocess(image)
# Create dataset and dataloader
image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg', ...]
dataset = ImageDataset(image_paths, preprocess)
# num_workers > 0 overlaps image decoding with GPU compute.
loader = DataLoader(dataset, batch_size=32, num_workers=4)

# Extract features in batches
all_features = []
with torch.no_grad(), torch.cuda.amp.autocast():
    for batch in loader:
        batch = batch.to(device)
        features = model.encode_image(batch)
        # Normalize so later dot products are cosine similarities.
        features /= features.norm(dim=-1, keepdim=True)
        # Move to CPU so the full feature matrix doesn't accumulate on the GPU.
        all_features.append(features.cpu())

all_features = torch.cat(all_features)
print("All features shape:", all_features.shape)
Processing Large Text Collections
texts = ["text 1", "text 2", ..., "text 10000"]
batch_size = 256

all_text_features = []
for i in range(0, len(texts), batch_size):
    # Tokenize and encode one slice at a time to bound GPU memory use.
    batch_texts = texts[i:i+batch_size]
    text_tokens = tokenizer(batch_texts).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        features = model.encode_text(text_tokens)
        features /= features.norm(dim=-1, keepdim=True)
    all_text_features.append(features.cpu())

all_text_features = torch.cat(all_text_features)
Computing Similarities
Image-to-Text Similarity
# Compute similarity matrix (assumes both feature sets are L2-normalized).
similarity_matrix = image_features @ text_features.T
# Shape: [num_images, num_texts]

# Get top-k matches for each image
top_k = 5
values, indices = similarity_matrix.topk(top_k, dim=-1)

for img_idx in range(len(image_features)):
    print(f"Image {img_idx} top matches:")
    for k in range(top_k):
        # .item() converts the 0-dim tensors to plain Python numbers
        # before list indexing / string formatting.
        text_idx = indices[img_idx, k].item()
        score = values[img_idx, k].item()
        print(f" {texts[text_idx]}: {score:.3f}")
Image-to-Image Similarity
# Find similar images
query_features = model.encode_image(query_image)
# Normalize so the dot products below are cosine similarities.
query_features /= query_features.norm(dim=-1, keepdim=True)
# Compare against database
# database_features is assumed [N, embed_dim] and already normalized.
similarity = query_features @ database_features.T
top_indices = similarity.topk(10, dim=-1).indices
print("Most similar images:", top_indices)
Zero-Shot Classification
from PIL import Image

# Define class labels
class_labels = ["cat", "dog", "bird", "fish", "horse"]

# Create text prompts with templates ("a photo of a ..." typically beats
# bare labels for zero-shot classification).
text_prompts = [f"a photo of a {label}" for label in class_labels]
text_tokens = tokenizer(text_prompts).to(device)

# Encode text
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = model.encode_text(text_tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Classify image
image = preprocess(Image.open('test.jpg')).unsqueeze(0).to(device)
with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)

# The 100x scaling matches CLIP's learned logit scale before the softmax.
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
predicted_idx = similarity.argmax().item()
print(f"Predicted class: {class_labels[predicted_idx]}")
print(f"Confidence: {similarity[0, predicted_idx]:.2%}")
Optimizations
Mixed Precision Inference
# Use automatic mixed precision for faster inference
with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(images)
    text_features = model.encode_text(texts)
Automatic mixed precision (AMP) can provide 2-3x speedup on modern GPUs with minimal accuracy loss.
Disabling Gradient Computation
Always use torch.no_grad() during inference:
# no_grad() skips autograd bookkeeping: less memory, faster inference.
with torch.no_grad():
    features = model.encode_image(images)
This:
- Reduces memory usage by ~50%
- Speeds up computation
- Prevents accidental gradient computation
Model Compilation (PyTorch 2.0+)
# Compile model for faster inference (PyTorch 2.0+)
model = torch.compile(model)

# First run triggers (slow) graph compilation; warm up with a dummy input.
with torch.no_grad():
    _ = model.encode_image(dummy_input)

# Subsequent runs reuse the compiled graph and are faster.
with torch.no_grad():
    features = model.encode_image(images)
Full Example: Image Search
# Full example: text-to-image search over a small indexed image set.
import torch
import open_clip
from PIL import Image
import numpy as np

# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32',
    pretrained='laion2b_s34b_b79k',
    device=device
)
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# Index images (build database)
image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']
images = torch.stack([preprocess(Image.open(p)) for p in image_paths]).to(device)
with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(images)
    image_features /= image_features.norm(dim=-1, keepdim=True)

# Search by text query
query = "a cute cat playing"
text_tokens = tokenizer([query]).to(device)
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = model.encode_text(text_tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Find matches: cosine similarity of the query against every indexed image.
similarity = (text_features @ image_features.T).squeeze(0)
top_idx = similarity.argmax().item()
print(f"Best match: {image_paths[top_idx]}")
print(f"Similarity score: {similarity[top_idx]:.3f}")
Remember to call model.eval() before inference. Some models use BatchNorm or stochastic depth which behave differently in training vs eval mode.