SAM 3 supports interactive refinement, allowing you to iteratively improve segmentation results by adding positive and negative prompts. This is especially useful when initial results aren’t perfect.

What is Interactive Refinement?

Interactive refinement lets you:
  • Add regions: Click to include missed areas (positive prompts)
  • Remove regions: Click to exclude wrongly included areas (negative prompts)
  • Adjust boundaries: Draw boxes to fine-tune segmentation edges
  • Iterate: Apply multiple rounds of refinement

Jupyter Widget Interface

The easiest way to refine segmentations interactively is using the Jupyter widget:
Step 1: Enable matplotlib widget backend

%matplotlib widget
Step 2: Import and setup

import torch
import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
import os

# Configure PyTorch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enter autocast and inference mode for the whole notebook session
# (in a script, prefer `with` blocks instead)
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
torch.inference_mode().__enter__()

# Build model
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
bpe_path = f"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz"
model = build_sam3_image_model(bpe_path=bpe_path)
processor = Sam3Processor(model)
Step 3: Launch the widget

# Widget code would be imported from sam3
from sam3.interactive_widget import Sam3SegmentationWidget

widget = Sam3SegmentationWidget(processor)
widget.display()

Widget Features

Image Source

  • Upload local images
  • Load from URL
  • Drag and drop support

Text Prompts

  • Enter natural language queries
  • See results in real-time
  • Combine with box prompts

Box Prompts

  • Draw positive boxes (green)
  • Draw negative boxes (red)
  • Toggle between modes

Controls

  • Adjust confidence threshold
  • Resize display
  • Clear all prompts

Programmatic Refinement

For non-interactive workflows, refine segmentations programmatically:

Adding Positive Boxes

Include regions that were missed:
from PIL import Image
from sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.visualization_utils import normalize_bbox
import torch

# Load image
image = Image.open("path/to/image.jpg")
width, height = image.size

# Initial text prompt
inference_state = processor.set_image(image)
inference_state = processor.set_text_prompt(state=inference_state, prompt="person")

# Add positive box to include a missed person
box_xywh = torch.tensor([480.0, 290.0, 110.0, 360.0]).view(-1, 4)
box_cxcywh = box_xywh_to_cxcywh(box_xywh)
norm_box = normalize_bbox(box_cxcywh, width, height).flatten().tolist()

inference_state = processor.add_geometric_prompt(
    state=inference_state,
    box=norm_box,
    label=True  # Positive prompt
)
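The conversion chain above goes from a pixel-space `[x, y, w, h]` box (top-left corner plus size) to a center-based `cxcywh` box, then normalizes by the image dimensions. As a minimal sketch of the same arithmetic in plain Python (the helper name here is illustrative, not part of the SAM 3 API):

```python
def xywh_to_norm_cxcywh(box, width, height):
    """Convert a pixel [x, y, w, h] box to normalized [cx, cy, w, h] in [0, 1]."""
    x, y, w, h = box
    cx = x + w / 2.0  # top-left x -> center x
    cy = y + h / 2.0  # top-left y -> center y
    return [cx / width, cy / height, w / width, h / height]

# The positive box from the example, assuming a 1280x720 image
norm = xywh_to_norm_cxcywh([480.0, 290.0, 110.0, 360.0], 1280, 720)
```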

Adding Negative Boxes

Exclude wrongly segmented regions:
# Remove false positive
negative_box_xywh = torch.tensor([370.0, 280.0, 115.0, 375.0]).view(-1, 4)
negative_box_cxcywh = box_xywh_to_cxcywh(negative_box_xywh)
norm_negative_box = normalize_bbox(
    negative_box_cxcywh, width, height
).flatten().tolist()

inference_state = processor.add_geometric_prompt(
    state=inference_state,
    box=norm_negative_box,
    label=False  # Negative prompt
)

Multiple Refinement Rounds

# Round 1: Initial text prompt
inference_state = processor.reset_all_prompts(inference_state)
inference_state = processor.set_text_prompt(
    state=inference_state, prompt="shoe"
)

# Round 2: Add positive box for missed shoe
box1 = normalize_bbox(
    box_xywh_to_cxcywh(torch.tensor([100, 200, 50, 80]).view(-1, 4)),
    width, height
).flatten().tolist()
inference_state = processor.add_geometric_prompt(
    state=inference_state, box=box1, label=True
)

# Round 3: Remove false positive
box2 = normalize_bbox(
    box_xywh_to_cxcywh(torch.tensor([500, 100, 60, 90]).view(-1, 4)),
    width, height
).flatten().tolist()
inference_state = processor.add_geometric_prompt(
    state=inference_state, box=box2, label=False
)

# Visualize final result
from sam3.visualization_utils import plot_results
plot_results(image, inference_state)

Video Refinement

Refine video segmentations by adding prompts on specific frames:
from sam3.model_builder import build_sam3_video_predictor

predictor = build_sam3_video_predictor()

# Start session
response = predictor.handle_request(
    request=dict(type="start_session", resource_path="video.mp4")
)
session_id = response["session_id"]

# Initial text prompt on frame 0
response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=0,
        text="person",
    )
)

# Refine with positive point on frame 0
points = torch.tensor([[0.5, 0.6]])  # Normalized coordinates
point_labels = torch.tensor([1])  # 1 = positive

response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=0,
        points=points,
        point_labels=point_labels,
        obj_id=1,  # Refine specific object
    )
)

# Propagate refined segmentation
for response in predictor.handle_stream_request(
    request=dict(type="propagate_in_video", session_id=session_id)
):
    frame_idx = response["frame_index"]
    outputs = response["outputs"]
    # Process outputs...
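Point prompts use normalized coordinates, so raw pixel clicks must be divided by the image dimensions before being passed as `points`. A minimal sketch of that mapping (the helper name is an assumption for illustration, not a sam3 function):

```python
def pixel_to_normalized(points_px, width, height):
    """Map pixel (x, y) click positions into the [0, 1] range used by point prompts."""
    return [[x / width, y / height] for x, y in points_px]

# A click at pixel (640, 432) on a 1280x720 frame becomes [0.5, 0.6],
# matching the normalized point used in the example above.
clicks = pixel_to_normalized([[640, 432]], 1280, 720)
```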

Refinement Strategies

Use tight boxes around edges:
# For rough edges, draw a tight box around the correct boundary
# (x, y, w, h are placeholders in pixel coordinates)
boundary_box_xywh = torch.tensor([x, y, w, h]).view(-1, 4)
norm_boundary_box = normalize_bbox(
    box_xywh_to_cxcywh(boundary_box_xywh), width, height
).flatten().tolist()
inference_state = processor.add_geometric_prompt(
    state=inference_state,
    box=norm_boundary_box,
    label=True
)

Resetting Prompts

Start fresh while keeping the image loaded:
# Clear all prompts and results
inference_state = processor.reset_all_prompts(inference_state)

# Image remains loaded, ready for new prompts
inference_state = processor.set_text_prompt(
    state=inference_state, prompt="new object"
)

Note: resetting prompts clears all segmentation results. Save any important masks before resetting.

Confidence Threshold Adjustment

Sometimes refinement isn’t needed—just adjust the threshold:
# Lower threshold to include more candidates
inference_state = processor.set_confidence_threshold(0.3, inference_state)

# Higher threshold for stricter filtering
inference_state = processor.set_confidence_threshold(0.7, inference_state)
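Conceptually, the threshold acts as a filter over per-candidate confidence scores: lowering it keeps more candidate masks, raising it keeps fewer. A minimal sketch of that filtering, assuming a simple list-of-dicts representation of candidates (not the actual SAM 3 internal format):

```python
def filter_by_confidence(candidates, threshold):
    """Keep only candidates whose score meets or exceeds the threshold."""
    return [c for c in candidates if c["score"] >= threshold]

candidates = [
    {"label": "person", "score": 0.92},
    {"label": "person", "score": 0.55},
    {"label": "person", "score": 0.31},
]
# At threshold 0.3 all three survive; at 0.7 only the strongest remains.
loose = filter_by_confidence(candidates, 0.3)
strict = filter_by_confidence(candidates, 0.7)
```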

Workflow Example: Perfect Segmentation

Step 1: Initial prompt

inference_state = processor.set_image(image)
inference_state = processor.set_text_prompt(
    state=inference_state, prompt="dog"
)
plot_results(image, inference_state)
Step 2: Add missed region

# Dog's tail was missed
tail_box = normalize_bbox(
    box_xywh_to_cxcywh(torch.tensor([650.0, 450.0, 80.0, 120.0]).view(-1, 4)),
    width, height
).flatten().tolist()
inference_state = processor.add_geometric_prompt(
    state=inference_state, box=tail_box, label=True
)
plot_results(image, inference_state)
Step 3: Remove false positive

# Toy dog was incorrectly included
toy_box = normalize_bbox(
    box_xywh_to_cxcywh(torch.tensor([200.0, 300.0, 60.0, 90.0]).view(-1, 4)),
    width, height
).flatten().tolist()
inference_state = processor.add_geometric_prompt(
    state=inference_state, box=toy_box, label=False
)
plot_results(image, inference_state)
Step 4: Final result

# Perfect segmentation achieved!
plot_results(image, inference_state)

Tips for Effective Refinement

Start broad, then refine:
  1. Begin with a text prompt
  2. Review initial results
  3. Add positive boxes for missed regions
  4. Add negative boxes for false positives
Use minimal prompts: Each additional prompt increases computation time. Try adjusting the confidence threshold before adding more boxes.
Box placement matters:
  • Positive boxes: Cover the entire missed region
  • Negative boxes: Cover only the wrongly included area
  • Avoid overlapping positive and negative boxes
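A quick sanity check for the last point is to test box pairs for intersection before submitting them. This is a minimal sketch over pixel-space `[x, y, w, h]` boxes; the helper is illustrative and not part of the SAM 3 API:

```python
def boxes_overlap(a, b):
    """Return True if two [x, y, w, h] boxes intersect (pixel coordinates)."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    # Boxes intersect when each one's left edge lies before the other's right
    # edge, and likewise vertically; touching edges do not count as overlap.
    return ax < bx + bw and bx < ax + aw and ay < by + bh and by < ay + ah

positive_box = [480.0, 290.0, 110.0, 360.0]
negative_box = [370.0, 280.0, 115.0, 375.0]
# If this returns True, reposition one of the boxes before prompting.
conflict = boxes_overlap(positive_box, negative_box)
```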

Next Steps

  • Image Inference: Master basic prompting before refinement
  • SAM 3 Agent: Use MLLMs for complex queries that reduce manual refinement
  • Video Inference: Apply refinement techniques to video segmentation
