Overview
SAM 3 extends image segmentation to videos with temporal tracking and propagation. It can segment objects across frames, maintain temporal consistency, and handle occlusions.
Video Architecture
SAM 3 uses a decoupled detector-tracker design for video:
Detector: detects objects on prompted frames using text or visual cues
Tracker: propagates detections across frames with memory attention
Both components share the same vision encoder for efficiency.
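The decoupled design can be summarized in a minimal sketch. The class and method names below are illustrative stand-ins, not SAM 3's actual modules; the point is that the detector and tracker consume features from one shared encoder, so each frame is encoded only once.

```python
class VisionEncoder:
    """Stand-in for the shared image backbone."""
    def encode(self, frame):
        return {"features": frame}

class Detector:
    """Finds objects on prompted frames from text or visual cues."""
    def detect(self, features, prompt):
        return [{"mask": features["features"], "prompt": prompt}]

class Tracker:
    """Propagates detections across frames (memory attention, simplified)."""
    def __init__(self):
        self.memory = []
    def track(self, features):
        # Condition the current frame on stored memories.
        return {"masks": [m["mask"] for m in self.memory]}

class VideoModel:
    """Illustrative wiring: one encoder feeds both components."""
    def __init__(self):
        self.encoder = VisionEncoder()   # shared by detector and tracker
        self.detector = Detector()
        self.tracker = Tracker()

    def process_prompted_frame(self, frame, prompt):
        feats = self.encoder.encode(frame)      # encode once
        detections = self.detector.detect(feats, prompt)
        self.tracker.memory.extend(detections)  # seed tracker memory
        return detections

    def process_frame(self, frame):
        feats = self.encoder.encode(frame)      # same encoder reused
        return self.tracker.track(feats)
```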
Video Workflow
The video segmentation pipeline:
1. Initialize Session: load video and create inference state
2. Add Prompts: specify objects on one or more frames
3. Detect Objects: run detector on prompted frames
4. Encode Memory: convert detections to memory representations
5. Propagate: track objects across all frames using memory attention
Basic Usage
from sam3.model_builder import build_sam3_video_predictor

# Initialize predictor
predictor = build_sam3_video_predictor()

# Start session with video path (MP4 or JPEG folder)
response = predictor.handle_request({
    "type": "start_session",
    "resource_path": "path/to/video.mp4",
})
session_id = response["session_id"]

# Add text prompt on frame 0
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 0,
    "text": "person wearing red shirt",
})
obj_id = response["outputs"]["obj_ids"][0]

# Propagate to all frames
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "propagation_direction": "both",  # Forward and backward
}):
    frame_idx = frame_output["frame_index"]
    masks = frame_output["outputs"]["masks"]  # Segmentation masks
    print(f"Frame {frame_idx}: {len(masks)} objects tracked")
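The masks streamed during propagation can be post-processed with ordinary array operations. A small sketch, assuming each mask is a 2-D grid of probabilities or logits; the predictor's actual output format (NumPy array, tensor, RLE) may differ, so adapt accordingly:

```python
def mask_areas(masks, threshold=0.5):
    """Pixel area of each predicted mask.

    Assumes each mask is a 2-D list/array of probabilities or logits;
    this is an illustrative helper, not part of the SAM 3 API.
    """
    return [sum(1 for row in m for v in row if v > threshold) for m in masks]
```

For example, calling `mask_areas(frame_output["outputs"]["masks"])` inside the propagation loop gives a quick per-object size signal for filtering tiny spurious detections.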
Session Management
SAM 3 uses a session-based API for video processing:
source/sam3/model/sam3_video_predictor.py
class Sam3VideoPredictor:
    # Global dictionary holding all inference states
    _ALL_INFERENCE_STATES = {}

    def start_session(self, resource_path, session_id=None):
        """Start a new inference session on a video."""
        # Initialize inference state from model
        inference_state = self.model.init_state(
            resource_path=resource_path,
            async_loading_frames=self.async_loading_frames,
            video_loader_type=self.video_loader_type,
        )
        if not session_id:
            session_id = str(uuid.uuid4())
        self._ALL_INFERENCE_STATES[session_id] = {
            "state": inference_state,
            "session_id": session_id,
            "start_time": time.time(),
        }
        return {"session_id": session_id}
Session Lifecycle
1. Start Session: create a session with a unique ID for the video
2. Add Prompts: add prompts on specific frames to define objects
3. Propagate: track objects across frames in either direction
4. Refine (optional): add correction prompts on any frame
5. Close Session: clean up resources when done
Sessions maintain state including video frames, memory bank, and tracking results. Always close sessions to free GPU memory.
Adding Prompts
Prompts can be added on any frame to initialize or refine tracking:
# Text prompt on first frame
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 0,
    "text": "dog",
    "obj_id": 1,  # Optional: specify object ID
})

# Point refinement on frame 10
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 10,
    "points": [[200, 300]],
    "point_labels": [1],  # 1 = positive click
    "obj_id": 1,  # Refine existing object
})

# Box prompt on frame 5
response = predictor.handle_request({
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": 5,
    "bounding_boxes": [[100, 100, 200, 300]],  # [x, y, w, h]
    "bounding_box_labels": [1],
    "obj_id": 2,  # New object
})
Prompt Types
| Prompt Type | Use Case | Example |
| --- | --- | --- |
| Text | Initialize tracking with concepts | "person in blue jacket" |
| Points | Refine mask boundaries | Click on missed regions |
| Boxes | Specify instance locations | Bounding box around object |
| Masks | Provide exact segmentation | Upload pre-segmented mask |
Prompts can be added on conditioning frames (for initialization) or non-conditioning frames (for correction).
Detection on Prompted Frames
When you add a prompt, the detector runs on that frame:
source/sam3/model/sam3_tracker_base.py
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,  # True if this is a prompted frame
    current_vision_feats,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    run_mem_encoder=True,
):
    # Get high-res features for SAM decoder
    if len(current_vision_feats) > 1:
        high_res_features = [
            x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
            for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
        ]
    if mask_inputs is not None:
        # Use mask directly as output
        sam_outputs = self._use_mask_as_output(
            pix_feat, high_res_features, mask_inputs
        )
    else:
        # Condition on memory from previous frames
        pix_feat_with_mem = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
        )
        # Run SAM-style mask decoder
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat_with_mem,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            high_res_features=high_res_features,
            multimask_output=self._use_multimask(is_init_cond_frame, point_inputs),
        )
    # Extract masks and object pointers
    low_res_masks, high_res_masks, obj_ptr = sam_outputs[3:6]
    # Store in output
    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr  # Used for memory
    return current_out
Memory Encoding
Predictions are encoded into memory for tracking:
source/sam3/model/sam3_tracker_base.py
def _encode_new_memory(
    self,
    image,
    current_vision_feats,
    feat_sizes,
    pred_masks_high_res,
    object_score_logits,
    is_mask_from_pts,
):
    """Encode current frame prediction into memory."""
    B = current_vision_feats[-1].size(1)
    C = self.hidden_dim
    H, W = feat_sizes[-1]
    # Get pixel features
    pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
    # Apply sigmoid to mask logits, then scale and shift
    mask_for_mem = torch.sigmoid(pred_masks_high_res)
    mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
    mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
    # Encode mask + image features
    maskmem_out = self.maskmem_backbone(image, pix_feat, mask_for_mem)
    maskmem_features = maskmem_out["vision_features"]
    maskmem_pos_enc = maskmem_out["vision_pos_enc"]
    # Add no-object embedding for occluded frames
    is_obj_appearing = (object_score_logits > 0).float()
    maskmem_features += (
        1 - is_obj_appearing[..., None, None]
    ) * self.no_obj_embed_spatial[..., None, None]
    return maskmem_features, maskmem_pos_enc
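The scale-and-bias step above maps sigmoid probabilities into the value range the memory encoder expects. A numeric sketch of that transform, using scale=20 and bias=-10 as illustrative defaults (the real values are the model's `sigmoid_scale_for_mem_enc` and `sigmoid_bias_for_mem_enc` hyperparameters):

```python
import math

def mask_for_memory(logit, scale=20.0, bias=-10.0):
    """Sigmoid a mask logit, then scale and shift, as in _encode_new_memory.

    scale/bias values here are illustrative, not guaranteed SAM 3 defaults.
    """
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid of the raw mask logit
    return p * scale + bias
```

With these values an uncertain pixel (logit 0, probability 0.5) maps to 0, while confident foreground and background saturate near +10 and -10 respectively, giving the memory encoder a roughly symmetric, well-scaled input.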
Memory Bank
The memory bank stores:
Spatial Memory : Mask-conditioned image features
Object Pointers : Compact representations from decoder output tokens
Temporal Encoding : Time offsets from current frame
Spatial Memory: dense features at 1/4 image resolution encoding mask and visual context
Object Pointers: compact vectors extracted from SAM decoder output tokens
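The memory bank's shape can be sketched as a small data structure. The two dictionaries mirror the `cond_frame_outputs` / `non_cond_frame_outputs` keys used in the tracker code below; the class and field groupings themselves are illustrative, not SAM 3's internal types:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class FrameMemory:
    """What the bank stores per frame (illustrative grouping)."""
    maskmem_features: Any  # spatial memory: mask-conditioned image features
    maskmem_pos_enc: Any   # positional encoding for the spatial memory
    obj_ptr: Any           # compact object pointer from decoder output tokens

@dataclass
class MemoryBank:
    """Conditioning (prompted) frames vs. recent non-conditioning frames."""
    cond_frame_outputs: Dict[int, FrameMemory] = field(default_factory=dict)
    non_cond_frame_outputs: Dict[int, FrameMemory] = field(default_factory=dict)

    def add(self, frame_idx, mem, is_conditioning):
        target = (self.cond_frame_outputs if is_conditioning
                  else self.non_cond_frame_outputs)
        target[frame_idx] = mem
```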
Propagation and Tracking
The tracker propagates masks across frames using memory attention :
source/sam3/model/sam3_tracker_base.py
def _prepare_memory_conditioned_features(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    output_dict,
    num_frames,
    track_in_reverse=False,
):
    """Condition current frame on memories from previous frames."""
    B = current_vision_feats[-1].size(1)
    C = self.hidden_dim
    H, W = feat_sizes[-1]
    device = current_vision_feats[-1].device
    if is_init_cond_frame:
        # First frame: no memory, add no_mem_embed
        pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
        return pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
    # Retrieve memories from conditioning frames
    to_cat_prompt, to_cat_prompt_pos_embed = [], []
    # Add conditioning frame outputs
    for t, out in output_dict["cond_frame_outputs"].items():
        t_pos = (frame_idx - t) * tpos_sign_mul
        # Spatial memory
        maskmem_features = out["maskmem_features"].cuda()
        to_cat_prompt.append(maskmem_features.flatten(2).permute(2, 0, 1))
        # Positional encoding
        maskmem_pos_enc = out["maskmem_pos_enc"][-1].cuda()
        maskmem_pos_enc = maskmem_pos_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
        to_cat_prompt_pos_embed.append(maskmem_pos_enc.flatten(2).permute(2, 0, 1))
    # Add recent non-conditioning frames (last num_maskmem-1 frames)
    for t_pos in range(1, self.num_maskmem):
        prev_frame_idx = frame_idx - t_pos if not track_in_reverse else frame_idx + t_pos
        out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx)
        if out is not None:
            # Add to memory
            ...
    # Collect object pointers
    for t, out in ptr_cond_outputs.items():
        obj_ptrs.append(out["obj_ptr"])
        obj_pos.append(self._get_tpos_enc([frame_idx - t], device))
    # Concatenate all prompts
    prompt = torch.cat(to_cat_prompt, dim=0)
    prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0)
    # Cross-attention: current frame attends to memories
    encoder_out = self.transformer.encoder(
        src=current_vision_feats,
        prompt=prompt,
        prompt_pos=prompt_pos_embed,
        num_obj_ptr_tokens=num_obj_ptr_tokens,
    )
    # Return memory-conditioned features
    pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W)
    return pix_feat_with_mem
Memory Attention Mechanism
1. Gather Memories: collect spatial memories and object pointers from the conditioning (prompted) frames and the last N-1 non-conditioning frames
2. Add Temporal Encoding: encode how far each memory is from the current frame
3. Cross-Attention: current frame features attend to memories, with spatial memory via spatial cross-attention and object pointers as additional tokens
4. Fused Features: output memory-conditioned features for mask prediction
By default, SAM 3 maintains num_maskmem=7 memory frames: the conditioning (prompted) frames plus the last 6 non-conditioning frames.
This balances context and efficiency.
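The frame-selection rule above can be expressed as a small pure function. This is a simplified sketch for forward tracking with no temporal stride; the helper name and return shape are ours, not SAM 3's API:

```python
def select_memory_frames(frame_idx, cond_frames, num_maskmem=7):
    """Which frames feed memory attention at `frame_idx` (simplified).

    Returns (conditioning frames, recent non-conditioning frames):
    all prompted frames plus the last num_maskmem-1 previous frames.
    """
    recent = [frame_idx - t for t in range(1, num_maskmem)]
    recent = sorted(t for t in recent if t >= 0 and t not in cond_frames)
    return sorted(set(cond_frames)), recent
```

For example, at frame 10 with a prompt on frame 0, memory attention draws on frame 0 plus frames 4 through 9.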
Temporal Stride
For long videos, use temporal stride to reduce memory:
# Skip frames in memory (use every r-th frame)
model = build_sam3_video_predictor(
    checkpoint_path=checkpoint,
    memory_temporal_stride_for_eval=5,  # Use every 5th frame
)
Propagation Directions
Track forward, backward, or both:
# Forward only (from prompt frame to end)
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "start_frame_index": 0,
    "propagation_direction": "forward",
}):
    ...

# Backward only (from prompt frame to start)
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "start_frame_index": 50,
    "propagation_direction": "backward",
}):
    ...

# Both directions
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "start_frame_index": 25,
    "propagation_direction": "both",  # Forward then backward
}):
    ...
When to use different directions?
Forward: when you prompt on early frames. Start from frame 0 or near the beginning and track objects entering the scene.
Backward: when you prompt on later frames. Start from a frame where the object is clearly visible and track backwards to find when it appeared.
Both: for middle-frame prompting. Prompt on a clear frame in the middle and track in both directions for full coverage.
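The frame visit order for each direction can be sketched as a pure function (simplified; the real API streams per-frame outputs rather than returning an index list, and "both" here follows the forward-then-backward order described above):

```python
def propagation_order(start, num_frames, direction="both"):
    """Frame visit order for each propagation direction (sketch).

    "both" runs forward from the start frame to the end, then backward
    from the start frame to frame 0, without revisiting the start.
    """
    forward = list(range(start, num_frames))
    backward = list(range(start, -1, -1))
    if direction == "forward":
        return forward
    if direction == "backward":
        return backward
    return forward + backward[1:]  # "both"
```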
Multi-GPU Video Processing
For faster inference, use multi-GPU processing:
from sam3.model_builder import build_sam3_video_predictor_multigpu

# Use GPUs 0, 1, 2, 3
predictor = build_sam3_video_predictor_multigpu(
    gpus_to_use=[0, 1, 2, 3]
)

# Each GPU processes different frames in parallel
for frame_output in predictor.handle_stream_request({
    "type": "propagate_in_video",
    "session_id": session_id,
    "propagation_direction": "both",
}):
    ...
How Multi-GPU Works
source/sam3/model/sam3_image.py
class Sam3ImageOnVideoMultiGPU(Sam3Image):
    def forward_video_grounding_multigpu(
        self,
        frame_idx,
        num_frames,
        multigpu_buffer,  # Cache for detector outputs
        ...
    ):
        """
        Process frames in chunks across GPUs:
        - GPU 0 processes frames 0, 4, 8, ...
        - GPU 1 processes frames 1, 5, 9, ...
        - GPU 2 processes frames 2, 6, 10, ...
        - GPU 3 processes frames 3, 7, 11, ...
        Results are gathered via NCCL all-gather.
        """
        # Compute detection on local GPU's frame
        frame_idx_local = frame_idx_begin + self.rank
        out_local = self.forward_grounding(
            find_input=find_inputs[frame_idx_local],
            geometric_prompt=geometric_prompt,
        )
        # Gather results from all GPUs
        out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
        # Cache in multigpu_buffer for later access
        for rank in range(self.world_size):
            frame_idx_to_save = frame_idx_begin + rank
            multigpu_buffer[frame_idx_to_save] = out_gathered[rank]
        return multigpu_buffer
Multi-GPU processing provides near-linear speedup. With 4 GPUs, expect ~3.5× faster propagation.
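The round-robin assignment described in the docstring above reduces to a one-line helper. This is an illustrative sketch of the scheduling pattern, not a function from the SAM 3 codebase:

```python
def frames_for_rank(rank, world_size, num_frames):
    """Round-robin frame assignment: GPU `rank` handles frames
    rank, rank + world_size, rank + 2*world_size, ... (sketch).
    """
    return list(range(rank, num_frames, world_size))
```

With 4 GPUs and 12 frames, rank 1 processes frames 1, 5, and 9, matching the pattern in the docstring; results are then exchanged so every rank sees all frames.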
Memory Management
How much GPU memory is needed?
Memory usage depends on:
Video resolution : Higher resolution = more memory
Video length : Longer videos need more frame storage
Number of objects : More objects = larger memory bank
Memory frames : More memory frames = higher quality but more memory
Typical usage:
720p video, 100 frames, 1 object: ~4GB
1080p video, 500 frames, 5 objects: ~16GB
How to reduce memory usage?
Several strategies:
Temporal stride: skip frames in memory
memory_temporal_stride_for_eval=5
Offload to CPU: store outputs on CPU
offload_output_to_cpu_for_eval=True
Per-frame backbone: process the backbone one frame at a time
forward_backbone_per_frame_for_eval=True
Trim past frames: don't keep old non-conditioning frame memories
trim_past_non_cond_mem_for_eval=True
How to handle very long videos?
For videos with 1000+ frames:
Process in chunks with separate sessions
Use temporal stride (e.g., r=5)
Enable CPU offloading
Process backbone per frame
Consider splitting video into segments
predictor = build_sam3_video_predictor(
    forward_backbone_per_frame_for_eval=True,
    memory_temporal_stride_for_eval=5,
    offload_output_to_cpu_for_eval=True,
)
Object Management
Track multiple objects independently:
# Add first object
response = predictor.handle_request({
"type" : "add_prompt" ,
"session_id" : session_id,
"frame_index" : 0 ,
"text" : "person in red" ,
"obj_id" : 1
})
# Add second object
response = predictor.handle_request({
"type" : "add_prompt" ,
"session_id" : session_id,
"frame_index" : 0 ,
"text" : "person in blue" ,
"obj_id" : 2
})
# Remove an object
response = predictor.handle_request({
"type" : "remove_object" ,
"session_id" : session_id,
"obj_id" : 1
})
Each object has its own memory bank and is tracked independently. Object IDs must be unique within a session.
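When streaming propagation outputs with several objects, it is often convenient to reorganize results into per-object tracks. A sketch, assuming each frame output carries parallel `obj_ids` and `masks` lists as in the examples above (verify this against the actual output schema):

```python
def collect_tracks(frame_outputs):
    """Organize streamed frame outputs into per-object tracks:
    {obj_id: {frame_index: mask}}. Illustrative helper; assumes each
    output has parallel "obj_ids" and "masks" lists under "outputs".
    """
    tracks = {}
    for out in frame_outputs:
        fidx = out["frame_index"]
        for obj_id, mask in zip(out["outputs"]["obj_ids"],
                                out["outputs"]["masks"]):
            tracks.setdefault(obj_id, {})[fidx] = mask
    return tracks
```

Passing the `handle_stream_request` iterator straight into `collect_tracks` yields one mask timeline per object ID, which makes it easy to see on which frames an object was lost and add correction prompts there.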
Best Practices
Prompt on Clear Frames: add prompts on frames where objects are clearly visible and unoccluded
Use Multiple Prompts: for difficult cases, add prompts on multiple frames to improve tracking
Refine When Needed: add correction prompts on frames where tracking fails
Close Sessions: always close sessions to free GPU memory when done
Next Steps
Architecture: understand the detector-tracker design
Prompting: learn about different prompt types
Video Inference Guide: see complete code examples
API Reference: explore the full video API