Overview

SAM 3 performs open-vocabulary instance segmentation on images using text or visual prompts. It can detect and segment all instances of a concept in a single forward pass.
Image Segmentation Example

Segmentation Workflow

The image segmentation pipeline consists of five stages:
  1. Image Encoding: extract multi-scale visual features using the vision encoder
  2. Prompt Encoding: process text and/or geometric prompts into embeddings
  3. Feature Fusion: fuse visual and prompt features via the transformer encoder
  4. Object Detection: predict bounding boxes and scores via the transformer decoder
  5. Mask Generation: generate high-resolution segmentation masks

Basic Usage

from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# Load model
model = build_sam3_image_model()
processor = Sam3Processor(model)

# Load and process image
image = Image.open("example.jpg")
inference_state = processor.set_image(image)

# Segment with text prompt
output = processor.set_text_prompt(
    state=inference_state,
    prompt="a dog"
)

# Extract results
masks = output["masks"]      # Shape: [N, H, W] - Binary masks
boxes = output["boxes"]      # Shape: [N, 4] - Bounding boxes (x, y, w, h)
scores = output["scores"]    # Shape: [N] - Confidence scores

Stage 1: Image Encoding

The vision encoder processes the input image to extract hierarchical features.
source/sam3/model/vl_combiner.py
def forward_image(self, samples: torch.Tensor):
    """Extract vision features from images."""
    # Forward through ViT backbone
    sam3_features, sam3_pos, sam2_features, sam2_pos = \
        self.vision_backbone.forward(samples)
    
    # Return feature pyramid and positional encodings
    return {
        "vision_features": sam3_features[-1],  # Highest level
        "vision_pos_enc": sam3_pos,
        "backbone_fpn": sam3_features,  # All pyramid levels
        "sam2_backbone_out": (sam2_features, sam2_pos),  # For tracker
    }

Multi-Scale Features

The backbone produces a Feature Pyramid Network (FPN) with multiple resolutions:
  • High resolution: fine-grained details for small objects
  • Medium resolution: balanced features for most objects
  • Low resolution: semantic context for large objects
The FPN structure enables detecting objects at different scales, from small details to large regions.
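To make the pyramid concrete, here is a minimal sketch in which the levels are derived from one feature map by pooling; the real backbone produces these levels directly, and the strides and channel counts below are illustrative, not SAM 3's actual configuration.

```python
import torch
import torch.nn.functional as F

def make_pyramid(feat: torch.Tensor, num_levels: int = 3) -> list:
    """Return [high-res, medium-res, low-res] feature maps (illustrative)."""
    levels = [feat]
    for _ in range(num_levels - 1):
        # Each level halves the spatial resolution of the previous one
        levels.append(F.avg_pool2d(levels[-1], kernel_size=2))
    return levels

feat = torch.randn(1, 256, 64, 64)   # [B, C, H, W]
pyramid = make_pyramid(feat)
print([tuple(p.shape[-2:]) for p in pyramid])  # [(64, 64), (32, 32), (16, 16)]
```

Small objects are resolved at the finest level, while the coarsest level summarizes global context.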

Stage 2: Prompt Encoding

Prompts are encoded and prepared for fusion with image features.
source/sam3/model/sam3_image.py
def _encode_prompt(
    self,
    backbone_out,
    find_input,
    geometric_prompt,
    encode_text=True,
):
    # Get text features if using text prompts
    txt_feats = backbone_out["language_features"][:, txt_ids]
    txt_masks = backbone_out["language_mask"][txt_ids]
    
    # Get image features for geometric encoding
    img_feats, img_pos_embeds, vis_feat_sizes = \
        self._get_img_feats(backbone_out, img_ids)
    
    # Encode geometric prompts (boxes, points, masks)
    geo_feats, geo_masks = self.geometry_encoder(
        geo_prompt=geometric_prompt,
        img_feats=img_feats,
        img_sizes=vis_feat_sizes,
        img_pos_embeds=img_pos_embeds,
    )
    
    # Concatenate text and geometric features
    if encode_text:
        prompt = torch.cat([txt_feats, geo_feats], dim=0)
        prompt_mask = torch.cat([txt_masks, geo_masks], dim=1)
    else:
        prompt = geo_feats
        prompt_mask = geo_masks
    
    return prompt, prompt_mask, backbone_out

Text Encoding

Text prompts are processed by the language backbone:
  1. Tokenization of input text
  2. BERT-style encoding with self-attention
  3. Contextualized text embeddings
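The three steps above can be sketched with a toy tokenizer and a generic transformer encoder; the vocabulary, dimensions, and layer counts here are invented for illustration and do not match SAM 3's actual language backbone.

```python
import torch
import torch.nn as nn

# Toy word-level tokenizer (step 1); real models use subword tokenization
vocab = {"a": 0, "dog": 1, "[PAD]": 2}
def tokenize(text: str) -> torch.Tensor:
    return torch.tensor([[vocab[w] for w in text.split()]])

# BERT-style self-attention encoder (step 2)
embed = nn.Embedding(len(vocab), 64)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    num_layers=2,
)

tokens = tokenize("a dog")          # [1, 2] token ids
txt_feats = encoder(embed(tokens))  # [1, 2, 64] contextualized embeddings (step 3)
```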

Geometric Encoding

Geometric prompts (boxes, points) are encoded with:
  1. Direct projection to embedding space
  2. ROI pooling from image features
  3. Positional encoding
  4. Label embeddings (positive/negative)
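As a sketch of the positional-encoding and label-embedding steps, here is a sine encoding for a single point prompt; the frequency bands, dimensions, and helper names are assumptions, not SAM 3's exact geometry encoder.

```python
import torch

def sine_pos_enc(xy: torch.Tensor, dim: int = 64) -> torch.Tensor:
    """Sine/cosine positional encoding for normalized (x, y) coordinates."""
    freqs = 2 ** torch.arange(dim // 4).float()   # frequency bands per axis
    angles = xy[..., None] * freqs                # [..., 2, dim//4]
    enc = torch.cat([angles.sin(), angles.cos()], dim=-1)
    return enc.flatten(-2)                        # [..., dim]

point = torch.tensor([0.4, 0.7])          # normalized (x, y) click location
label_embed = torch.nn.Embedding(2, 64)   # positive / negative click labels
geo_feat = sine_pos_enc(point) + label_embed(torch.tensor(1))  # [64]
```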

Stage 3: Transformer Encoder

The encoder fuses image and prompt features through cross-attention.
source/sam3/model/encoder.py
def forward(
    self,
    src: List[Tensor],  # Multi-level image features
    prompt: Tensor,     # Text + geometric prompts
    src_pos: List[Tensor],  # Positional encodings
    prompt_key_padding_mask: Tensor,
    ...
):
    # Add pooled text to image features for fusion
    if self.add_pooled_text_to_img_feat:
        pooled_text = pool_text_feat(prompt, prompt_key_padding_mask)
        pooled_text = self.text_pooling_proj(pooled_text)[..., None, None]
        src = [x.add_(pooled_text) for x in src]
    
    # Flatten multi-level features
    src_flatten, lvl_pos_embed_flatten = \
        self._prepare_multilevel_features(src, src_pos)
    
    # Process through encoder layers
    output = src_flatten
    for layer in self.layers:
        output = layer(
            tgt=output,
            memory=prompt.transpose(0, 1),  # Cross-attend to prompts
            query_pos=lvl_pos_embed_flatten,
            memory_key_padding_mask=prompt_key_padding_mask,
        )
    
    return {"memory": output, ...}

Cross-Modal Fusion

The encoder performs early fusion by:
  1. Adding mean-pooled text features to all image features
  2. Cross-attending from image tokens to text/geometric tokens
  3. Building a fused representation for detection
Early fusion allows the model to leverage both visual and linguistic context when detecting objects.
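The two fusion mechanisms can be sketched as follows; all dimensions are illustrative rather than SAM 3's actual configuration.

```python
import torch
import torch.nn as nn

D, HW, T = 64, 256, 8
img = torch.randn(1, HW, D)      # flattened multi-level image tokens
prompt = torch.randn(1, T, D)    # text + geometric prompt tokens

# 1) Add mean-pooled prompt features to every image token
img = img + prompt.mean(dim=1, keepdim=True)

# 2) Cross-attend from image tokens (queries) to prompt tokens (keys/values)
attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True)
fused, _ = attn(query=img, key=prompt, value=prompt)   # [1, HW, D]
```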

Stage 4: Transformer Decoder

The decoder uses learned object queries to predict instances.
source/sam3/model/decoder.py
def forward(
    self,
    tgt,  # Object queries
    memory,  # Encoded image features
    memory_text: Tensor,  # Text features
    text_attention_mask: Tensor,
    ...
):
    bs = memory.shape[1]
    
    # Initialize queries
    query_embed = self.query_embed.weight
    tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
    
    # Initialize reference boxes
    reference_boxes = self.reference_points.weight.unsqueeze(1)
    reference_boxes = reference_boxes.repeat(1, bs, 1).sigmoid()
    
    intermediate = []
    intermediate_ref_boxes = [reference_boxes]
    
    # Process through decoder layers
    for layer_idx, layer in enumerate(self.layers):
        # Self-attention among queries
        # Cross-attention to text
        # Cross-attention to image
        tgt, presence_out = layer(
            tgt=tgt,
            memory=memory,
            memory_text=memory_text,
            text_attention_mask=text_attention_mask,
            ...
        )
        
        # Iteratively refine boxes
        delta_boxes = self.bbox_embed(tgt)
        reference_boxes = (inverse_sigmoid(reference_boxes) + delta_boxes).sigmoid()
        
        intermediate.append(self.norm(tgt))
        intermediate_ref_boxes.append(reference_boxes)
    
    return intermediate, intermediate_ref_boxes, ...

Decoder Components

Object queries are learned embeddings that represent potential objects. Each query:
  • Starts as a random embedding
  • Refines through self-attention with other queries
  • Cross-attends to image features to localize objects
  • Cross-attends to text features to match the prompt
  • Outputs detection predictions (box, score, mask)
SAM 3 uses 300 queries by default to handle multiple instances.
Iterative refinement improves localization across decoder layers:
  1. Layer 0: Initialize boxes from learned reference points
  2. Each layer predicts box offsets (deltas)
  3. Boxes are updated: new_box = sigmoid(inverse_sigmoid(old_box) + delta)
  4. Final prediction uses the last layer’s refined boxes
This allows gradual improvement of localization.
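The update rule can be checked in isolation; the random deltas below stand in for the per-layer output of the box head, so this is a toy illustration, not the decoder itself.

```python
import torch

def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    x = x.clamp(eps, 1 - eps)
    return torch.log(x / (1 - x))

box = torch.tensor([[0.5, 0.5, 0.2, 0.2]])   # (cx, cy, w, h) in [0, 1]
for _ in range(6):                            # one update per decoder layer
    delta = 0.01 * torch.randn(1, 4)          # stand-in for predicted offsets
    box = (inverse_sigmoid(box) + delta).sigmoid()
```

Because the update happens in inverse-sigmoid (logit) space, the refined boxes always stay within [0, 1] regardless of the deltas.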
The presence token is a learned embedding that:
  • Participates in self-attention with object queries
  • Predicts whether matching objects exist in the image
  • Helps discriminate similar prompts (“player in white” vs “player in red”)
  • Improves performance on negative prompts (no matching objects)
It’s a key innovation in SAM 3 for handling open-vocabulary prompts.
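Conceptually, the presence token can be sketched as one extra learned embedding that joins the object queries in self-attention, with a small head on its output predicting whether any matching object exists; all dimensions and module names below are invented for illustration.

```python
import torch
import torch.nn as nn

D, Q = 64, 300
queries = torch.randn(1, Q, D)                  # object queries
presence = nn.Parameter(torch.randn(1, 1, D))   # learned presence token
attn = nn.MultiheadAttention(D, num_heads=4, batch_first=True)
head = nn.Linear(D, 1)

tokens = torch.cat([presence, queries], dim=1)  # [1, Q+1, D]
tokens, _ = attn(tokens, tokens, tokens)        # self-attention over all tokens
presence_prob = head(tokens[:, 0]).sigmoid()    # [1, 1]: P(any match exists)
```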

Box-Aware Positional Bias (boxRPB)

SAM 3 uses boxRPB to encode spatial relationships:
source/sam3/model/decoder.py
def _get_rpb_matrix(self, reference_boxes, feat_size):
    """Compute relative positional bias from boxes."""
    H, W = feat_size
    boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
    
    # Compute deltas between boxes and spatial positions
    deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy[..., 1:4:2]
    deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy[..., 0:3:2]
    
    # Apply log-space encoding
    deltas_x_log = torch.sign(deltas_x) * torch.log2(torch.abs(deltas_x) + 1.0)
    deltas_y_log = torch.sign(deltas_y) * torch.log2(torch.abs(deltas_y) + 1.0)
    
    # Project to attention bias
    B_x = self.boxRPB_embed_x(deltas_x_log)  # bs, num_queries, W, n_heads
    B_y = self.boxRPB_embed_y(deltas_y_log)  # bs, num_queries, H, n_heads
    
    # Combine into 2D bias
    B = B_y.unsqueeze(3) + B_x.unsqueeze(2)  # bs, num_queries, H, W, n_heads
    return B.flatten(2, 3).permute(0, 3, 1, 2)  # bs, n_heads, num_queries, HW
BoxRPB helps queries attend to regions inside or near their predicted boxes.

Stage 5: Mask Generation

The segmentation head generates high-resolution masks from decoder outputs.
source/sam3/model/sam3_image.py
def _run_segmentation_heads(
    self,
    out,
    backbone_out,
    img_ids,
    vis_feat_sizes,
    encoder_hidden_states,
    prompt,
    prompt_mask,
    hs,  # Decoder output queries
):
    if self.segmentation_head is not None:
        # Run mask prediction head
        seg_head_outputs = self.segmentation_head(
            backbone_feats=backbone_out["backbone_fpn"],
            obj_queries=hs,  # Use decoder queries
            image_ids=img_ids,
            encoder_hidden_states=encoder_hidden_states,
            prompt=prompt,
            prompt_mask=prompt_mask,
        )
        
        # Store mask predictions
        for k, v in seg_head_outputs.items():
            if k == "pred_masks":
                out[k] = v
The segmentation head:
  1. Takes decoder queries as input
  2. Uses multi-scale FPN features for high resolution
  3. Applies upsampling to match image size
  4. Outputs binary masks for each detected instance
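The four steps above can be sketched with a MaskFormer-style head, in which each decoder query is dotted with a per-pixel embedding from the high-resolution FPN level and the result is upsampled; this is an assumption about the general mechanism, not SAM 3's exact head.

```python
import torch
import torch.nn.functional as F

N, D, h, w = 3, 64, 64, 64       # queries, channels, low-res feature size
H, W = 256, 256                  # target image resolution (illustrative)
queries = torch.randn(N, D)      # decoder output queries
pixel_embed = torch.randn(D, h, w)  # high-res FPN feature map

# Dot each query with every pixel embedding, then upsample to image size
mask_logits = torch.einsum("nd,dhw->nhw", queries, pixel_embed)   # [N, h, w]
mask_logits = F.interpolate(
    mask_logits[None], size=(H, W), mode="bilinear", align_corners=False
)[0]                             # [N, H, W]
binary_masks = mask_logits > 0   # threshold logits at zero
```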

Output Format

The complete output includes:
output = {
    "masks": torch.Tensor,        # [N, H, W] - Binary masks
    "boxes": torch.Tensor,        # [N, 4] - Boxes in [x, y, w, h] format
    "scores": torch.Tensor,       # [N] - Confidence scores (0-1)
    "pred_logits": torch.Tensor,  # [N, 1] - Raw logits before sigmoid
    "queries": torch.Tensor,      # [N, D] - Final query embeddings
}
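As the field comments note, scores are the sigmoid of pred_logits; a quick sanity check of that relation:

```python
import torch

pred_logits = torch.tensor([[2.0], [-1.0]])   # [N, 1] raw logits
scores = pred_logits.sigmoid().squeeze(-1)    # [N] confidence in (0, 1)
# sigmoid(2.0) ≈ 0.881, sigmoid(-1.0) ≈ 0.269
```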

Post-Processing

# Filter by confidence threshold
threshold = 0.3
keep_conf = scores > threshold
filtered_masks = masks[keep_conf]
filtered_boxes = boxes[keep_conf]
filtered_scores = scores[keep_conf]

# Apply NMS to remove duplicates
from sam3.perflib.nms import nms_masks

keep = nms_masks(
    pred_probs=filtered_scores,
    pred_masks=filtered_masks,
    prob_threshold=0.3,
    iou_threshold=0.7,
)

final_masks = filtered_masks[keep]
final_boxes = filtered_boxes[keep]
final_scores = filtered_scores[keep]

Batched Inference

Process multiple images efficiently:
images = [Image.open(f"image_{i}.jpg") for i in range(batch_size)]
prompts = ["a dog", "a cat", "a car"]  # One per image

# Batch processing
outputs = processor.batch_inference(
    images=images,
    prompts=prompts,
    batch_size=4,  # Process 4 at a time
)

# Each output corresponds to one image
for i, output in enumerate(outputs):
    print(f"Image {i}: Found {len(output['masks'])} instances")
Batching significantly improves throughput when processing many images with the same or different prompts.

Performance Optimization

SAM 3 supports torch.compile for faster inference:
model = build_sam3_image_model(compile=True)
This compiles:
  • Vision backbone
  • Transformer encoder/decoder
  • Segmentation head
Expect 1.5-2× speedup after warmup.
Use automatic mixed precision (AMP) to reduce memory and increase speed:
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = processor.set_text_prompt(
        state=inference_state,
        prompt="a dog"
    )
BF16 provides better numerical stability than FP16.
Apply Non-Maximum Suppression (NMS) to remove duplicate detections:
from sam3.perflib.nms import nms_masks

keep = nms_masks(
    pred_probs=scores,
    pred_masks=masks,
    prob_threshold=0.3,  # Confidence threshold
    iou_threshold=0.7,   # IoU threshold for duplicates
)
This removes overlapping predictions that likely refer to the same object.
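A greedy mask-NMS sketch illustrates the idea behind nms_masks; the real implementation in sam3.perflib.nms may differ in its details.

```python
import torch

def mask_iou(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """IoU between two binary masks."""
    inter = (a & b).sum((-2, -1)).float()
    union = (a | b).sum((-2, -1)).float()
    return inter / union.clamp(min=1)

def greedy_nms(masks: torch.Tensor, scores: torch.Tensor,
               iou_threshold: float = 0.7) -> list:
    """Keep highest-scoring masks, dropping near-duplicates."""
    keep = []
    for i in scores.argsort(descending=True).tolist():
        # Keep i only if it does not overlap too much with any kept mask
        if all(mask_iou(masks[i], masks[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep
```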

Multi-Mask Output

For ambiguous prompts, SAM 3 can output multiple mask candidates:
output = processor.set_text_prompt(
    state=inference_state,
    prompt="a dog",
    multimask_output=True  # Enable multiple masks per detection
)

# Each instance now has 3 mask candidates
multi_masks = output["multi_pred_masks"]  # [N, 3, H, W]
multi_scores = output["multi_pred_logits"]  # [N, 3]

# Best mask is automatically selected
best_masks = output["masks"]  # [N, H, W] - highest scoring mask per instance
Multi-mask output is useful when the prompt is ambiguous and could refer to different segmentation granularities.
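Selecting the best candidate per instance can be sketched as an argmax over the per-candidate logits (consistent with the fields shown above; the tensors below are random stand-ins).

```python
import torch

N, H, W = 5, 32, 32
multi_masks = torch.rand(N, 3, H, W) > 0.5   # [N, 3, H, W] mask candidates
multi_logits = torch.randn(N, 3)             # [N, 3] per-candidate logits

best_idx = multi_logits.argmax(dim=1)        # [N] index of best candidate
best_masks = multi_masks[torch.arange(N), best_idx]  # [N, H, W]
```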

Common Use Cases

Instance Segmentation

Detect and segment all instances of a class:
output = processor.set_text_prompt(
    state=inference_state,
    prompt="person"
)

Specific Object

Segment a specific object with attributes:
output = processor.set_text_prompt(
    state=inference_state,
    prompt="red car with open door"
)

Interactive Refinement

Refine segmentation with clicks:
output = processor.set_point_prompt(
    state=inference_state,
    points=[[100, 200]],
    point_labels=[1]
)

Box-Prompted

Segment objects within bounding boxes:
output = processor.set_box_prompt(
    state=inference_state,
    boxes=[[50, 50, 200, 300]]
)

Next Steps

Video Segmentation

Learn about video tracking and propagation

Prompting Guide

Deep dive into different prompting strategies

Image Inference Guide

See complete code examples

API Reference

Explore the full API
