
Overview

SAM 3 extends image segmentation to videos with temporal tracking and propagation. It can segment objects across frames, maintain temporal consistency, and handle occlusions.
(Figure: video segmentation example)

Video Architecture

SAM 3 uses a decoupled detector-tracker design for video:

Detector: Detects objects on prompted frames using text or visual cues.

Tracker: Propagates detections across frames with memory attention.

Both components share the same vision encoder for efficiency.

Video Workflow

The video segmentation pipeline:
  1. Initialize Session: Load the video and create an inference state
  2. Add Prompts: Specify objects on one or more frames
  3. Detect Objects: Run the detector on prompted frames
  4. Encode Memory: Convert detections to memory representations
  5. Propagate: Track objects across all frames using memory attention

Basic Usage

from sam3.model_builder import build_sam3_video_predictor

# Initialize predictor
predictor = build_sam3_video_predictor()

# Start session with video path (MP4 or JPEG folder)
response = predictor.handle_request({
    "type": "start_session",
    "resource_path": "path/to/video.mp4"
})

session_id = response["session_id"]

# Add text prompt on frame 0
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 0,
    "text": "person wearing red shirt"
})

obj_id = response["outputs"]["obj_ids"][0]

# Propagate to all frames
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "propagation_direction": "both"  # Forward and backward
}):
    frame_idx = frame_output["frame_index"]
    masks = frame_output["outputs"]["masks"]  # Segmentation masks
    print(f"Frame {frame_idx}: {len(masks)} objects tracked")

Session Management

SAM 3 uses a session-based API for video processing:
source/sam3/model/sam3_video_predictor.py
class Sam3VideoPredictor:
    # Global dictionary holding all inference states
    _ALL_INFERENCE_STATES = {}
    
    def start_session(self, resource_path, session_id=None):
        """Start a new inference session on a video."""
        # Initialize inference state from model
        inference_state = self.model.init_state(
            resource_path=resource_path,
            async_loading_frames=self.async_loading_frames,
            video_loader_type=self.video_loader_type,
        )
        
        if not session_id:
            session_id = str(uuid.uuid4())
        
        self._ALL_INFERENCE_STATES[session_id] = {
            "state": inference_state,
            "session_id": session_id,
            "start_time": time.time(),
        }
        
        return {"session_id": session_id}

Session Lifecycle

  1. Start Session: Create a session with a unique ID for the video
  2. Add Prompts: Add prompts on specific frames to define objects
  3. Propagate: Track objects across frames in either direction
  4. Refine (Optional): Add correction prompts on any frame
  5. Close Session: Clean up resources when done

Sessions maintain state including the video frames, memory bank, and tracking results. Always close sessions to free GPU memory.
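The cleanup step can be sketched as a request in the same shape as the other session calls. The exact `"close_session"` request type name is an assumption based on the API shape above, so check the API reference before relying on it:

```python
# Hedged sketch: the "close_session" request type name is an assumption.
def make_close_request(session_id):
    # Build the payload that would be passed to predictor.handle_request(...)
    # to tear down a session and free its GPU memory
    return {"type": "close_session", "session_id": session_id}

request = make_close_request("my-session-id")
# predictor.handle_request(request)  # with a live predictor and real session_id
```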

Adding Prompts

Prompts can be added on any frame to initialize or refine tracking:
# Text prompt on first frame
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 0,
    "text": "dog",
    "obj_id": 1  # Optional: specify object ID
})

# Point refinement on frame 10
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 10,
    "points": [[200, 300]],
    "point_labels": [1],  # 1 = positive click
    "obj_id": 1  # Refine existing object
})

# Box prompt on frame 5
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 5,
    "bounding_boxes": [[100, 100, 200, 300]],  # [x, y, w, h]
    "bounding_box_labels": [1],
    "obj_id": 2  # New object
})

Prompt Types

| Prompt Type | Use Case                          | Example                    |
| ----------- | --------------------------------- | -------------------------- |
| Text        | Initialize tracking with concepts | "person in blue jacket"    |
| Points      | Refine mask boundaries            | Click on missed regions    |
| Boxes       | Specify instance locations        | Bounding box around object |
| Masks       | Provide exact segmentation        | Upload pre-segmented mask  |
Prompts can be added on conditioning frames (for initialization) or non-conditioning frames (for correction).

Detection on Prompted Frames

When you add a prompt, the detector runs on that frame:
source/sam3/model/sam3_tracker_base.py
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,  # True if this is a prompted frame
    current_vision_feats,
    feat_sizes,  # Spatial sizes of the feature maps
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    run_mem_encoder=True,
):
    # Get high-res features for SAM decoder
    if len(current_vision_feats) > 1:
        high_res_features = [
            x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
            for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
        ]
    
    if mask_inputs is not None:
        # Use mask directly as output
        sam_outputs = self._use_mask_as_output(
            pix_feat, high_res_features, mask_inputs
        )
    else:
        # Condition on memory from previous frames
        pix_feat_with_mem = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
        )
        
        # Run SAM-style mask decoder
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat_with_mem,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            high_res_features=high_res_features,
            multimask_output=self._use_multimask(is_init_cond_frame, point_inputs),
        )
    
    # Extract masks and object pointers
    low_res_masks, high_res_masks, obj_ptr = sam_outputs[3:6]
    
    # Store results in the per-frame output dict
    current_out = {
        "pred_masks": low_res_masks,
        "pred_masks_high_res": high_res_masks,
        "obj_ptr": obj_ptr,  # Used for memory
    }
    
    return current_out

Memory Encoding

Predictions are encoded into memory for tracking:
source/sam3/model/sam3_tracker_base.py
def _encode_new_memory(
    self,
    image,
    current_vision_feats,
    feat_sizes,
    pred_masks_high_res,
    object_score_logits,
    is_mask_from_pts,
):
    """Encode current frame prediction into memory."""
    B = current_vision_feats[-1].size(1)
    C = self.hidden_dim
    H, W = feat_sizes[-1]
    
    # Get pixel features
    pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
    
    # Apply sigmoid to mask logits
    mask_for_mem = torch.sigmoid(pred_masks_high_res)
    mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
    mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
    
    # Encode mask + image features
    maskmem_out = self.maskmem_backbone(image, pix_feat, mask_for_mem)
    
    maskmem_features = maskmem_out["vision_features"]
    maskmem_pos_enc = maskmem_out["vision_pos_enc"]
    
    # Add no-object embedding for occluded frames
    is_obj_appearing = (object_score_logits > 0).float()
    maskmem_features += (
        1 - is_obj_appearing[..., None, None]
    ) * self.no_obj_embed_spatial[..., None, None]
    
    return maskmem_features, maskmem_pos_enc

Memory Bank

The memory bank stores:
  1. Spatial Memory: Mask-conditioned image features
  2. Object Pointers: Compact representations from decoder output tokens
  3. Temporal Encoding: Time offsets from current frame

Spatial Memory: Dense features at 1/4 image resolution encoding mask and visual context.

Object Pointers: Compact vectors extracted from SAM decoder output tokens.
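The temporal encoding can be illustrated with a standard sinusoidal embedding of the frame offset. This is a simplified sketch of the idea, not the library's exact `_get_tpos_enc` implementation:

```python
import math

def tpos_encoding(offset, dim=8):
    # Sinusoidal embedding of a temporal offset (in frames from the current
    # frame), so nearby and distant memories get distinguishable encodings
    enc = []
    for i in range(dim // 2):
        freq = 1.0 / (10000 ** (2 * i / dim))
        enc.append(math.sin(offset * freq))
        enc.append(math.cos(offset * freq))
    return enc

# An offset of 0 frames yields alternating sin(0)=0 and cos(0)=1 entries
print(tpos_encoding(0))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```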

Propagation and Tracking

The tracker propagates masks across frames using memory attention:
source/sam3/model/sam3_tracker_base.py
def _prepare_memory_conditioned_features(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    feat_sizes,  # Spatial sizes of the feature maps
    output_dict,
    num_frames,
    track_in_reverse=False,
):
    """Condition current frame on memories from previous frames."""
    B = current_vision_feats[-1].size(1)
    C = self.hidden_dim
    H, W = feat_sizes[-1]
    device = current_vision_feats[-1].device
    
    if is_init_cond_frame:
        # First frame: no memory, add no_mem_embed
        pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
        return pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
    
    # Retrieve memories from conditioning frames
    to_cat_prompt, to_cat_prompt_pos_embed = [], []
    
    # Add conditioning frame outputs
    for t, out in output_dict["cond_frame_outputs"].items():
        t_pos = (frame_idx - t) * tpos_sign_mul
        
        # Spatial memory
        maskmem_features = out["maskmem_features"].cuda()
        to_cat_prompt.append(maskmem_features.flatten(2).permute(2, 0, 1))
        
        # Positional encoding
        maskmem_pos_enc = out["maskmem_pos_enc"][-1].cuda()
        maskmem_pos_enc = maskmem_pos_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
        to_cat_prompt_pos_embed.append(maskmem_pos_enc.flatten(2).permute(2, 0, 1))
    
    # Add recent non-conditioning frames (last num_maskmem-1 frames)
    for t_pos in range(1, self.num_maskmem):
        prev_frame_idx = frame_idx - t_pos if not track_in_reverse else frame_idx + t_pos
        out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx)
        if out is not None:
            # Add to memory
            ...
    
    # Collect object pointers
    for t, out in ptr_cond_outputs.items():
        obj_ptrs.append(out["obj_ptr"])
        obj_pos.append(self._get_tpos_enc([frame_idx - t], device))
    
    # Concatenate all prompts
    prompt = torch.cat(to_cat_prompt, dim=0)
    prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0)
    
    # Cross-attention: current frame attends to memories
    encoder_out = self.transformer.encoder(
        src=current_vision_feats,
        prompt=prompt,
        prompt_pos=prompt_pos_embed,
        num_obj_ptr_tokens=num_obj_ptr_tokens,
    )
    
    # Return memory-conditioned features
    pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W)
    return pix_feat_with_mem

Memory Attention Mechanism

  1. Gather Memories: Collect spatial memories and object pointers from:
     • Conditioning frames (prompted frames)
     • Last N-1 non-conditioning frames
  2. Add Temporal Encoding: Encode how far each memory is from the current frame
  3. Cross-Attention: Current frame features attend to memories:
     • Spatial memory via spatial cross-attention
     • Object pointers as additional tokens
  4. Fused Features: Output memory-conditioned features for mask prediction

By default, SAM 3 maintains num_maskmem=7 memory frames:
  • Conditioning frames (prompted)
  • Last 6 non-conditioning frames
This balances context and efficiency.
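Which frames feed memory attention can be sketched in a few lines. This simplification ignores stride and occlusion edge cases handled in the real `_prepare_memory_conditioned_features`:

```python
def memory_frames(frame_idx, cond_frames, num_maskmem=7, track_in_reverse=False):
    # Conditioning (prompted) frames are always kept in the memory bank...
    selected = sorted(cond_frames)
    # ...plus the last num_maskmem - 1 non-conditioning frames
    for t_pos in range(1, num_maskmem):
        prev = frame_idx + t_pos if track_in_reverse else frame_idx - t_pos
        if prev >= 0 and prev not in selected:
            selected.append(prev)
    return selected

# Tracking forward at frame 20 with a prompt on frame 0:
print(memory_frames(20, cond_frames=[0]))  # [0, 19, 18, 17, 16, 15, 14]
```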

Temporal Stride

For long videos, use temporal stride to reduce memory:
# Keep only every r-th frame in the memory bank (here r=5)
model = build_sam3_video_predictor(
    checkpoint_path=checkpoint,
    memory_temporal_stride_for_eval=5,
)

Propagation Directions

Track forward, backward, or both:
# Forward only (from prompt frame to end)
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "start_frame_index": 0,
    "propagation_direction": "forward"
}):
    ...

# Backward only (from prompt frame to start)
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "start_frame_index": 50,
    "propagation_direction": "backward"
}):
    ...

# Both directions
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "start_frame_index": 25,
    "propagation_direction": "both"  # Forward then backward
}):
    ...
Forward: When you prompt on early frames
  • Start from frame 0 or the beginning
  • Track objects entering the scene

Backward: When you prompt on later frames
  • Start from a frame where the object is clearly visible
  • Track backwards to find when the object appeared

Both: For middle-frame prompting
  • Prompt on a clear frame in the middle
  • Track in both directions for full coverage

Multi-GPU Video Processing

For faster inference, use multi-GPU processing:
from sam3.model_builder import build_sam3_video_predictor_multigpu

# Use GPUs 0, 1, 2, 3
predictor = build_sam3_video_predictor_multigpu(
    gpus_to_use=[0, 1, 2, 3]
)

# Each GPU processes different frames in parallel
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "propagation_direction": "both"
}):
    ...

How Multi-GPU Works

source/sam3/model/sam3_image.py
class Sam3ImageOnVideoMultiGPU(Sam3Image):
    def forward_video_grounding_multigpu(
        self,
        frame_idx,
        num_frames,
        multigpu_buffer,  # Cache for detector outputs
        ...
    ):
        """
        Process frames in chunks across GPUs:
        - GPU 0 processes frame 0, 4, 8, ...
        - GPU 1 processes frame 1, 5, 9, ...
        - GPU 2 processes frame 2, 6, 10, ...
        - GPU 3 processes frame 3, 7, 11, ...
        
        Results are gathered via NCCL all-gather.
        """
        # Compute detection on local GPU's frame
        frame_idx_local = frame_idx_begin + self.rank
        out_local = self.forward_grounding(
            find_input=find_inputs[frame_idx_local],
            geometric_prompt=geometric_prompt,
        )
        
        # Gather results from all GPUs
        out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
        
        # Cache in multigpu_buffer for later access
        for rank in range(self.world_size):
            frame_idx_to_save = frame_idx_begin + rank
            multigpu_buffer[frame_idx_to_save] = out_gathered[rank]
        
        return multigpu_buffer
Multi-GPU processing provides near-linear speedup. With 4 GPUs, expect ~3.5× faster propagation.
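The round-robin frame assignment described in the docstring above boils down to:

```python
def assign_frames(num_frames, world_size):
    # GPU `rank` handles frames rank, rank + world_size, rank + 2*world_size, ...
    return {rank: list(range(rank, num_frames, world_size))
            for rank in range(world_size)}

print(assign_frames(8, 4))  # {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}
```

Each GPU runs the detector on its own frames, and the per-frame results are then gathered so every rank sees the full set of detections.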

Memory Management

Memory usage depends on:
  • Video resolution: Higher resolution = more memory
  • Video length: Longer videos need more frame storage
  • Number of objects: More objects = larger memory bank
  • Memory frames: More memory frames = higher quality but more memory
Typical usage:
  • 720p video, 100 frames, 1 object: ~4GB
  • 1080p video, 500 frames, 5 objects: ~16GB
Several strategies:
  1. Temporal stride: Skip frames in memory (memory_temporal_stride_for_eval=5)
  2. Offload to CPU: Store outputs on CPU (offload_output_to_cpu_for_eval=True)
  3. Per-frame backbone: Process the backbone one frame at a time (forward_backbone_per_frame_for_eval=True)
  4. Trim past frames: Don't keep old non-conditioning frames (trim_past_non_cond_mem_for_eval=True)
For videos with 1000+ frames:
  1. Process in chunks with separate sessions
  2. Use temporal stride (e.g., r=5)
  3. Enable CPU offloading
  4. Process backbone per frame
  5. Consider splitting video into segments
predictor = build_sam3_video_predictor(
    forward_backbone_per_frame_for_eval=True,
    memory_temporal_stride_for_eval=5,
    offload_output_to_cpu_for_eval=True,
)
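For the chunked-processing strategy, a simple splitter might look like this. The chunk size and overlap are illustrative choices, not library defaults:

```python
def chunk_ranges(num_frames, chunk_size=300, overlap=10):
    # Overlapping [start, end) frame ranges; each chunk gets its own session,
    # with the overlap available for re-prompting objects at chunk boundaries
    ranges, start = [], 0
    while start < num_frames:
        end = min(start + chunk_size, num_frames)
        ranges.append((start, end))
        if end == num_frames:
            break
        start = end - overlap
    return ranges

print(chunk_ranges(1000))  # [(0, 300), (290, 590), (580, 880), (870, 1000)]
```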

Object Management

Track multiple objects independently:
# Add first object
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 0,
    "text": "person in red",
    "obj_id": 1
})

# Add second object
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 0,
    "text": "person in blue",
    "obj_id": 2
})

# Remove an object
response = predictor.handle_request({
    "type": "remove_object",
    "session_id": session_id,
    "obj_id": 1
})
Each object has its own memory bank and is tracked independently. Object IDs must be unique within a session.
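Since each object is tracked independently, its per-frame masks can be merged into a single label map for visualization. A small sketch, assuming boolean HxW masks keyed by `obj_id`:

```python
import numpy as np

def merge_masks(masks_by_obj, height, width):
    # 0 = background; on overlap, higher obj_ids overwrite lower ones
    label_map = np.zeros((height, width), dtype=np.int32)
    for obj_id in sorted(masks_by_obj):
        label_map[masks_by_obj[obj_id]] = obj_id
    return label_map

m1 = np.zeros((4, 4), dtype=bool); m1[0, :] = True  # object 1: top row
m2 = np.zeros((4, 4), dtype=bool); m2[:, 0] = True  # object 2: left column
labels = merge_masks({1: m1, 2: m2}, 4, 4)
print(labels[0, 1], labels[1, 0])  # 1 2
```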

Best Practices

Prompt on Clear Frames: Add prompts on frames where objects are clearly visible and unoccluded.

Use Multiple Prompts: For difficult cases, add prompts on multiple frames to improve tracking.

Refine When Needed: Add correction prompts on frames where tracking fails.

Close Sessions: Always close sessions to free GPU memory when done.

Next Steps

  • Architecture: Understand the detector-tracker design
  • Prompting: Learn about different prompt types
  • Video Inference Guide: See complete code examples
  • API Reference: Explore the full video API
