Overview
SAM 3 performs open-vocabulary instance segmentation on images using text or visual prompts. It can detect and segment all instances of a concept in a single forward pass.
Segmentation Workflow
The image segmentation pipeline consists of several stages:
1. Image Encoding: extract multi-scale visual features with the vision encoder
2. Prompt Encoding: process text and/or geometric prompts into embeddings
3. Feature Fusion: fuse visual and prompt features via the transformer encoder
4. Object Detection: predict bounding boxes and scores via the transformer decoder
5. Mask Generation: generate high-resolution segmentation masks
Basic Usage
```python
from PIL import Image

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# Load model
model = build_sam3_image_model()
processor = Sam3Processor(model)

# Load and process image
image = Image.open("example.jpg")
inference_state = processor.set_image(image)

# Segment with text prompt
output = processor.set_text_prompt(
    state=inference_state,
    prompt="a dog",
)

# Extract results
masks = output["masks"]    # [N, H, W] binary masks
boxes = output["boxes"]    # [N, 4] bounding boxes (x, y, w, h)
scores = output["scores"]  # [N] confidence scores
```
Stage 1: Image Encoding
The vision encoder processes the input image to extract hierarchical features.
source/sam3/model/vl_combiner.py
```python
def forward_image(self, samples: torch.Tensor):
    """Extract vision features from images."""
    # Forward through ViT backbone
    sam3_features, sam3_pos, sam2_features, sam2_pos = \
        self.vision_backbone.forward(samples)

    # Collect the SAM 2 branch outputs for the tracker (keys assumed)
    sam2_output = {"backbone_fpn": sam2_features, "vision_pos_enc": sam2_pos}

    # Return feature pyramid and positional encodings
    return {
        "vision_features": sam3_features[-1],  # Highest level
        "vision_pos_enc": sam3_pos,
        "backbone_fpn": sam3_features,   # All pyramid levels
        "sam2_backbone_out": sam2_output,  # For tracker
    }
```
Multi-Scale Features
The backbone produces a Feature Pyramid Network (FPN) with multiple resolutions:
- High resolution: fine-grained details for small objects
- Medium resolution: balanced features for most objects
- Low resolution: semantic context for large objects
The FPN structure enables detecting objects at different scales, from small details to large regions.
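As a rough sketch of why the pyramid matters, here are the feature-map sizes at a few assumed strides (the input size and strides are illustrative; SAM 3's actual values may differ):

```python
# Assumed input size and FPN strides, for illustration only
image_size = 1024
strides = [4, 8, 16, 32]

# Spatial size of each pyramid level
fpn_shapes = {s: (image_size // s, image_size // s) for s in strides}

# Finer levels carry many more tokens, which is what preserves small objects
token_counts = {s: h * w for s, (h, w) in fpn_shapes.items()}
```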
Stage 2: Prompt Encoding
Prompts are encoded and prepared for fusion with image features.
source/sam3/model/sam3_image.py
```python
def _encode_prompt(
    self,
    backbone_out,
    find_input,
    geometric_prompt,
    encode_text=True,
):
    # Get text features if using text prompts
    txt_feats = backbone_out["language_features"][:, txt_ids]
    txt_masks = backbone_out["language_mask"][txt_ids]

    # Get image features for geometric encoding
    img_feats, img_pos_embeds, vis_feat_sizes = \
        self._get_img_feats(backbone_out, img_ids)

    # Encode geometric prompts (boxes, points, masks)
    geo_feats, geo_masks = self.geometry_encoder(
        geo_prompt=geometric_prompt,
        img_feats=img_feats,
        img_sizes=vis_feat_sizes,
        img_pos_embeds=img_pos_embeds,
    )

    # Concatenate text and geometric features
    if encode_text:
        prompt = torch.cat([txt_feats, geo_feats], dim=0)
        prompt_mask = torch.cat([txt_masks, geo_masks], dim=1)
    else:
        prompt = geo_feats
        prompt_mask = geo_masks
    return prompt, prompt_mask, backbone_out
```
Text Encoding
Text prompts are processed by the language backbone:
- Tokenization of the input text
- BERT-style encoding with self-attention
- Contextualized text embeddings
Geometric Encoding
Geometric prompts (boxes, points) are encoded with:
- Direct projection to the embedding space
- ROI pooling from image features
- Positional encoding
- Label embeddings (positive/negative)
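A toy sketch of the last two ingredients, a positional encoding plus a label embedding for a click (the encoding, dimensions, and embeddings here are illustrative assumptions, not SAM 3's actual implementation):

```python
import numpy as np

def sine_pos_encoding(xy, dim=8, temperature=100.0):
    """Sinusoidal encoding of a normalized (x, y) point into `dim` features."""
    freqs = temperature ** (np.arange(dim // 4) / (dim // 4))
    parts = []
    for coord in xy:
        parts.append(np.sin(coord / freqs))
        parts.append(np.cos(coord / freqs))
    return np.concatenate(parts)

# Positive (label 1) and negative (label 0) clicks get distinct embeddings
label_embed = {1: np.ones(8), 0: -np.ones(8)}

# A positive click at normalized coordinates (0.25, 0.6)
point_feat = sine_pos_encoding((0.25, 0.6)) + label_embed[1]
```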
Stage 3: Feature Fusion
The encoder fuses image and prompt features through cross-attention.
source/sam3/model/encoder.py
```python
def forward(
    self,
    src: List[Tensor],      # Multi-level image features
    prompt: Tensor,         # Text + geometric prompts
    src_pos: List[Tensor],  # Positional encodings
    prompt_key_padding_mask: Tensor,
    ...
):
    # Add pooled text to image features for fusion
    if self.add_pooled_text_to_img_feat:
        pooled_text = pool_text_feat(prompt, prompt_key_padding_mask)
        pooled_text = self.text_pooling_proj(pooled_text)[..., None, None]
        src = [x.add_(pooled_text) for x in src]

    # Flatten multi-level features
    src_flatten, lvl_pos_embed_flatten = \
        self._prepare_multilevel_features(src, src_pos)

    # Process through encoder layers
    output = src_flatten
    for layer in self.layers:
        output = layer(
            tgt=output,
            memory=prompt.transpose(0, 1),  # Cross-attend to prompts
            query_pos=lvl_pos_embed_flatten,
            memory_key_padding_mask=prompt_key_padding_mask,
        )
    return {"memory": output, ...}
```
Cross-Modal Fusion
The encoder performs early fusion by:
- Adding mean-pooled text features to all image features
- Cross-attending from image tokens to text/geometric tokens
- Building a fused representation for detection
Early fusion allows the model to leverage both visual and linguistic context when detecting objects.
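The pooling-and-broadcast step can be sketched in plain NumPy (toy shapes, assumed for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
C = 4  # channel dim (toy value)

# Three pyramid levels of image features, each [C, H, W]
src = [rng.standard_normal((C, s, s)) for s in (8, 4, 2)]
# Five text tokens, each a C-dim embedding
text_tokens = rng.standard_normal((5, C))

# Mean-pool the text, then broadcast-add it to every spatial position
pooled = text_tokens.mean(axis=0)[:, None, None]  # [C, 1, 1]
fused = [level + pooled for level in src]
```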
Stage 4: Object Detection
The decoder uses learned object queries to predict instances.
source/sam3/model/decoder.py
```python
def forward(
    self,
    tgt,                  # Object queries
    memory,               # Encoded image features
    memory_text: Tensor,  # Text features
    text_attention_mask: Tensor,
    ...
):
    bs = memory.shape[1]

    # Initialize queries
    query_embed = self.query_embed.weight
    tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)

    # Initialize reference boxes
    reference_boxes = self.reference_points.weight.unsqueeze(1)
    reference_boxes = reference_boxes.repeat(1, bs, 1).sigmoid()

    intermediate = []
    intermediate_ref_boxes = [reference_boxes]

    # Process through decoder layers
    for layer_idx, layer in enumerate(self.layers):
        # Self-attention among queries, cross-attention to text,
        # cross-attention to image
        tgt, presence_out = layer(
            tgt=tgt,
            memory=memory,
            memory_text=memory_text,
            text_attention_mask=text_attention_mask,
            ...
        )

        # Iteratively refine boxes
        delta_boxes = self.bbox_embed(tgt)
        reference_boxes = (inverse_sigmoid(reference_boxes) + delta_boxes).sigmoid()

        intermediate.append(self.norm(tgt))
        intermediate_ref_boxes.append(reference_boxes)

    return intermediate, intermediate_ref_boxes, ...
```
Decoder Components
Object queries are learned embeddings that represent potential objects. Each query:
- Starts from a learned embedding (randomly initialized before training)
- Refines through self-attention with the other queries
- Cross-attends to image features to localize objects
- Cross-attends to text features to match the prompt
- Outputs detection predictions (box, score, mask)
SAM 3 uses 300 queries by default to handle multiple instances.
What is iterative box refinement?
Iterative refinement improves localization across decoder layers:
- Layer 0 initializes boxes from learned reference points
- Each subsequent layer predicts box offsets (deltas)
- Boxes are updated as new_box = sigmoid(inverse_sigmoid(old_box) + delta)
- The final prediction uses the last layer's refined boxes
This allows gradual improvement of localization.
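The update rule can be checked numerically; this NumPy sketch mirrors the `inverse_sigmoid` helper used in the decoder:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def inverse_sigmoid(x, eps=1e-6):
    """Logit function, clamped away from 0 and 1 for stability."""
    x = np.clip(x, eps, 1.0 - eps)
    return np.log(x / (1.0 - x))

box = np.array([0.5, 0.5, 0.2, 0.2])    # normalized (cx, cy, w, h)
delta = np.array([0.4, 0.0, 0.0, 0.0])  # one layer's predicted offset

# Working in logit space keeps all coordinates inside (0, 1) after the update
new_box = sigmoid(inverse_sigmoid(box) + delta)
```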
What is the presence token?
The presence token is a learned embedding that:
- Participates in self-attention with the object queries
- Predicts whether matching objects exist in the image
- Helps discriminate similar prompts (“player in white” vs “player in red”)
- Improves performance on negative prompts (no matching objects)
It’s a key innovation in SAM 3 for handling open-vocabulary prompts.
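One way to see its effect is to gate per-query scores with a global presence probability; this is a simplified sketch of the idea, not the exact scoring SAM 3 uses:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Per-query match logits and a single presence logit for the whole image
query_logits = np.array([2.0, 0.5, -1.0])
presence_logit = -4.0  # the model is confident no matching object exists

# Gating by presence suppresses every query on a negative prompt
scores = sigmoid(query_logits) * sigmoid(presence_logit)
```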
Box-Aware Positional Bias (boxRPB)
SAM 3 uses boxRPB to encode spatial relationships:
source/sam3/model/decoder.py
```python
def _get_rpb_matrix(self, reference_boxes, feat_size):
    """Compute relative positional bias from boxes."""
    H, W = feat_size
    boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)

    # Compute deltas between box edges and spatial positions
    deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy[..., 1:4:2]
    deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy[..., 0:3:2]

    # Apply log-space encoding
    deltas_x_log = torch.sign(deltas_x) * torch.log2(torch.abs(deltas_x) + 1.0)
    deltas_y_log = torch.sign(deltas_y) * torch.log2(torch.abs(deltas_y) + 1.0)

    # Project the log-space deltas to per-head attention biases
    B_x = self.boxRPB_embed_x(deltas_x_log)  # bs, num_queries, W, n_heads
    B_y = self.boxRPB_embed_y(deltas_y_log)  # bs, num_queries, H, n_heads

    # Combine into a 2D bias
    B = B_y.unsqueeze(3) + B_x.unsqueeze(2)  # bs, num_queries, H, W, n_heads
    return B.flatten(2, 3).permute(0, 3, 1, 2)  # bs, n_heads, num_queries, HW
```
BoxRPB helps queries attend to regions inside or near their predicted boxes.
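The log-space encoding compresses large offsets while preserving sign and ordering, which is easy to verify directly:

```python
import numpy as np

deltas = np.array([-8.0, -1.0, 0.0, 1.0, 8.0])

# Signed log-space encoding, as used for the boxRPB deltas
encoded = np.sign(deltas) * np.log2(np.abs(deltas) + 1.0)
# Zero maps to zero, the encoding is odd-symmetric, and an 8x larger
# offset grows the encoded value by only ~3x
```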
Stage 5: Mask Generation
The segmentation head generates high-resolution masks from decoder outputs.
source/sam3/model/sam3_image.py
```python
def _run_segmentation_heads(
    self,
    out,
    backbone_out,
    img_ids,
    vis_feat_sizes,
    encoder_hidden_states,
    prompt,
    prompt_mask,
    hs,  # Decoder output queries
):
    if self.segmentation_head is not None:
        # Run mask prediction head
        seg_head_outputs = self.segmentation_head(
            backbone_feats=backbone_out["backbone_fpn"],
            obj_queries=hs,  # Use decoder queries
            image_ids=img_ids,
            encoder_hidden_states=encoder_hidden_states,
            prompt=prompt,
            prompt_mask=prompt_mask,
        )
        # Store mask predictions
        for k, v in seg_head_outputs.items():
            if k == "pred_masks":
                out[k] = v
```
The segmentation head:
- Takes decoder queries as input
- Uses multi-scale FPN features for high resolution
- Applies upsampling to match the image size
- Outputs binary masks for each detected instance
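A common way such heads turn queries into masks is a dot product between query embeddings and per-pixel features; the sketch below illustrates that general form (an assumption about the head's structure, not SAM 3's exact code):

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, H, W = 3, 4, 8, 8  # queries, embed dim, mask height/width (toy values)

pixel_feats = rng.standard_normal((D, H, W))  # upsampled FPN features
queries = rng.standard_normal((N, D))         # final decoder queries

# Each query's mask logits are its similarity to every pixel embedding
mask_logits = np.einsum("nd,dhw->nhw", queries, pixel_feats)
binary_masks = mask_logits > 0
```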
The complete output includes:
```python
output = {
    "masks": torch.Tensor,        # [N, H, W] binary masks
    "boxes": torch.Tensor,        # [N, 4] boxes in (x, y, w, h) format
    "scores": torch.Tensor,       # [N] confidence scores in [0, 1]
    "pred_logits": torch.Tensor,  # [N, 1] raw logits before sigmoid
    "queries": torch.Tensor,      # [N, D] final query embeddings
}
```
Post-Processing
```python
# Filter by confidence threshold
threshold = 0.3
keep_mask = scores > threshold
filtered_masks = masks[keep_mask]
filtered_boxes = boxes[keep_mask]
filtered_scores = scores[keep_mask]

# Apply NMS to remove duplicates
from sam3.perflib.nms import nms_masks

keep = nms_masks(
    pred_probs=filtered_scores,
    pred_masks=filtered_masks,
    prob_threshold=0.3,
    iou_threshold=0.7,
)
final_masks = filtered_masks[keep]
final_boxes = filtered_boxes[keep]
final_scores = filtered_scores[keep]
```
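Conceptually, mask NMS is greedy suppression by mask IoU; this minimal NumPy stand-in illustrates the idea (not the actual `nms_masks` implementation):

```python
import numpy as np

def mask_iou(a, b):
    """IoU between two boolean masks."""
    union = np.logical_or(a, b).sum()
    return np.logical_and(a, b).sum() / union if union else 0.0

def greedy_mask_nms(masks, scores, iou_threshold=0.7):
    """Keep masks in descending score order, dropping near-duplicates."""
    keep = []
    for i in np.argsort(-scores):
        if all(mask_iou(masks[i], masks[j]) <= iou_threshold for j in keep):
            keep.append(int(i))
    return keep

# Two nearly identical detections: only the higher-scoring one survives
masks = np.zeros((2, 4, 4), dtype=bool)
masks[:, :2, :] = True
kept = greedy_mask_nms(masks, np.array([0.9, 0.8]))
```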
Batched Inference
Process multiple images efficiently:
```python
images = [Image.open(f"image_{i}.jpg") for i in range(batch_size)]
prompts = ["a dog", "a cat", "a car"]  # One per image

# Batch processing
outputs = processor.batch_inference(
    images=images,
    prompts=prompts,
    batch_size=4,  # Process 4 at a time
)

# Each output corresponds to one image
for i, output in enumerate(outputs):
    print(f"Image {i}: Found {len(output['masks'])} instances")
```
Batching significantly improves throughput when processing many images with the same or different prompts.
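Under the hood, batching amounts to slicing the inputs into fixed-size chunks; a generic sketch of that pattern (not the actual `batch_inference` implementation):

```python
def chunked(items, size):
    """Yield successive fixed-size chunks from a list."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

# Ten inputs with batch_size=4 become batches of 4, 4, and 2
batches = list(chunked(list(range(10)), 4))
```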
SAM 3 supports torch.compile for faster inference:

```python
model = build_sam3_image_model(compile=True)
```

This compiles:
- Vision backbone
- Transformer encoder/decoder
- Segmentation head

Expect a 1.5-2× speedup after warmup.
Use automatic mixed precision (AMP) to reduce memory use and increase speed:

```python
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = processor.set_text_prompt(
        state=inference_state,
        prompt="a dog",
    )
```
BF16 provides better numerical stability than FP16.
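The stability claim comes down to dynamic range: FP16 tops out near 65504, while BF16 keeps FP32's exponent range (~3.4e38) at reduced precision. NumPy has no bfloat16 type, but the FP16 overflow is easy to demonstrate:

```python
import numpy as np

fp16_max = float(np.finfo(np.float16).max)  # largest finite FP16 value

# A modest product overflows FP16 to infinity...
overflowed = np.float16(300.0) * np.float16(300.0)
# ...while the same value is tiny relative to the FP32/BF16 exponent range
fine_in_fp32 = np.float32(300.0) * np.float32(300.0)
```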
Apply Non-Maximum Suppression (NMS) to remove duplicate detections:

```python
from sam3.perflib.nms import nms_masks

keep = nms_masks(
    pred_probs=scores,
    pred_masks=masks,
    prob_threshold=0.3,  # Confidence threshold
    iou_threshold=0.7,   # IoU threshold for duplicates
)
```
This removes overlapping predictions that likely refer to the same object.
Multi-Mask Output
For ambiguous prompts, SAM 3 can output multiple mask candidates:
```python
output = processor.set_text_prompt(
    state=inference_state,
    prompt="a dog",
    multimask_output=True,  # Enable multiple masks per detection
)

# Each instance now has 3 mask candidates
multi_masks = output["multi_pred_masks"]    # [N, 3, H, W]
multi_scores = output["multi_pred_logits"]  # [N, 3]

# The best mask is selected automatically
best_masks = output["masks"]  # [N, H, W] highest-scoring mask per instance
```
Multi-mask output is useful when the prompt is ambiguous and could refer to different segmentation granularities.
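Selecting the best candidate per instance is a gather along the candidate axis; a NumPy sketch with toy shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, H, W = 2, 3, 4, 4  # instances, candidates per instance, mask size (toy)

multi_masks = rng.random((N, K, H, W)) > 0.5
multi_scores = rng.standard_normal((N, K))

# For each instance, pick the candidate with the highest score
best_idx = multi_scores.argmax(axis=1)            # [N]
best_masks = multi_masks[np.arange(N), best_idx]  # [N, H, W]
```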
Common Use Cases
Instance Segmentation
Detect and segment all instances of a class:

```python
output = processor.set_text_prompt(
    state=inference_state,
    prompt="person",
)
```
Specific Object
Segment a specific object with attributes:

```python
output = processor.set_text_prompt(
    state=inference_state,
    prompt="red car with open door",
)
```
Interactive Refinement
Refine segmentation with clicks:

```python
output = processor.set_point_prompt(
    state=inference_state,
    points=[[100, 200]],
    point_labels=[1],
)
```
Box-Prompted
Segment objects within bounding boxes:

```python
output = processor.set_box_prompt(
    state=inference_state,
    boxes=[[50, 50, 200, 300]],
)
```
Next Steps
Video Segmentation Learn about video tracking and propagation
Prompting Guide Deep dive into different prompting strategies
Image Inference Guide See complete code examples
API Reference Explore the full API