
Overview

The SAM3InteractiveImagePredictor class provides SAM 1-style interactive segmentation with point and box prompts. It uses SAM 3’s tracking module for efficient mask prediction.

Class Initialization

from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor

predictor = SAM3InteractiveImagePredictor(
    sam_model,
    mask_threshold=0.0,
    max_hole_area=256.0,
    max_sprinkle_area=0.0
)

Parameters

sam_model
Sam3TrackerBase
required
SAM 3 tracker model instance.
mask_threshold
float
default: 0.0
Threshold for converting mask logits to binary masks.
max_hole_area
float
default: 256.0
Maximum area (in pixels) of holes to fill in masks. Set to 0 to disable hole filling.
max_sprinkle_area
float
default: 0.0
Maximum area (in pixels) of small disconnected regions to remove. Set to 0 to disable.
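The mask_threshold parameter is applied elementwise to the predicted logits. A minimal NumPy sketch of that conversion (not the library's internal implementation, which additionally applies the hole-filling and small-region removal described above):

```python
import numpy as np

# Toy 4x4 logit map; values above the threshold become foreground
logits = np.array([
    [ 2.0,  1.5, -0.5, -1.0],
    [ 1.0,  0.5, -0.2, -2.0],
    [-1.0, -0.5, -3.0,  0.1],
    [-2.0, -1.0, -0.5, -4.0],
])
mask_threshold = 0.0
binary = logits > mask_threshold  # boolean mask, same shape as logits
print(binary.sum())  # 5 foreground pixels
```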

Methods

set_image

Computes image embeddings for the provided image.
predictor.set_image(image)
image
np.ndarray | PIL.Image.Image
required
Input image in RGB format with pixel values in [0, 255]. NumPy arrays should be in HWC format; PIL images are in WHC format.
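Since set_image() expects RGB input, grayscale or RGBA sources should be converted first. A small sketch of that preprocessing (to_rgb_array is a hypothetical helper, not part of the sam3 API):

```python
import numpy as np
from PIL import Image

def to_rgb_array(img: Image.Image) -> np.ndarray:
    """Normalize any PIL image (RGBA, grayscale, ...) to an RGB HWC uint8 array."""
    return np.array(img.convert("RGB"))

# RGBA input becomes a 3-channel array suitable for predictor.set_image()
rgba = Image.fromarray(np.zeros((8, 8, 4), dtype=np.uint8), mode="RGBA")
arr = to_rgb_array(rgba)
print(arr.shape, arr.dtype)  # (8, 8, 3) uint8
```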

set_image_batch

Computes image embeddings for a batch of images.
predictor.set_image_batch(image_list)
image_list
list[np.ndarray]
required
List of images in RGB format (HWC, uint8, [0, 255]).

predict

Predict masks for the given prompts.
masks, iou_predictions, low_res_masks = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
    normalize_coords=True
)

Parameters

point_coords
np.ndarray | None
default: None
Nx2 array of point prompts in (X, Y) pixel coordinates.
point_labels
np.ndarray | None
default: None
Length-N array of labels: 1 for a foreground point, 0 for a background point.
box
np.ndarray | None
default: None
Box prompt in XYXY format: [x0, y0, x1, y1].
mask_input
np.ndarray | None
default: None
Low-resolution mask from a previous prediction iteration (1xHxW, with H=W=256).
multimask_output
bool
default: True
If True, returns 3 candidate masks with predicted IoU scores. For ambiguous prompts (e.g. a single point), this often produces better results than a single mask.
return_logits
bool
default: False
If True, returns un-thresholded mask logits instead of binary masks.
normalize_coords
bool
default: True
If True, point coordinates are interpreted as pixel coordinates relative to the image dimensions and are normalized to the [0, 1] range internally.
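What normalize_coords does can be sketched in plain NumPy (a conceptual illustration of the coordinate scaling, not the library's internal code):

```python
import numpy as np

# Pixel-space point prompts on an 800x600 (W x H) image
H, W = 600, 800
point_coords = np.array([[400.0, 300.0]])  # (X, Y) in pixels

# Scaling each axis by the image dimensions maps coordinates into [0, 1]
normalized = point_coords / np.array([W, H])
print(normalized)  # [[0.5 0.5]]
```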

Returns

masks
np.ndarray
Output masks in CxHxW format, where C is the number of masks (1 or 3).
iou_predictions
np.ndarray
Length C array of predicted IoU scores for each mask.
low_res_masks
np.ndarray
Low-resolution logits in CxHxW format (H=W=256) for iterative refinement.
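A common pattern is to pick the highest-scoring mask from iou_predictions and keep the matching low-resolution logits for the next refinement step. A sketch with toy stand-ins for the predict() outputs (random arrays of the documented shapes, not real model output):

```python
import numpy as np

# Stand-ins for predict() outputs with multimask_output=True (C=3)
iou_predictions = np.array([0.71, 0.93, 0.64])
low_res_masks = np.random.randn(3, 256, 256).astype(np.float32)

best_idx = int(np.argmax(iou_predictions))          # index of the best mask
mask_input = low_res_masks[best_idx:best_idx + 1]   # slice keeps the 1xHxW shape
print(best_idx, mask_input.shape)  # 1 (1, 256, 256)
```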

predict_batch

Predict masks for a batch of images.
masks_list, ious_list, low_res_list = predictor.predict_batch(
    point_coords_batch=None,
    point_labels_batch=None,
    box_batch=None,
    mask_input_batch=None,
    multimask_output=True,
    return_logits=False,
    normalize_coords=True
)

Parameters

Same as predict(), but each parameter is a list (one entry per image).

Returns

Lists of masks, IoU predictions, and low-resolution masks (one per image).

get_image_embedding

Returns the cached image embeddings.
embedding = predictor.get_image_embedding()
embedding
torch.Tensor
Image embeddings with shape 1xCxHxW (typically C=256, H=W=64).
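For inspection or downstream use, the spatial embedding grid is often flattened into a tokens-by-channels matrix. A sketch on a toy array with the documented shape (not an actual model call):

```python
import numpy as np

# Toy embedding with the documented shape 1xCxHxW (C=256, H=W=64)
embedding = np.zeros((1, 256, 64, 64), dtype=np.float32)

# Flatten the spatial dims into a (H*W) x C token matrix
tokens = embedding.reshape(256, -1).T
print(tokens.shape)  # (4096, 256)
```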

reset_predictor

Resets all cached state (image embeddings, etc.).
predictor.reset_predictor()

Example Usage

Single Point Prompt

import numpy as np
from PIL import Image
from sam3.model_builder import build_tracker
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor

# Build tracker and predictor
tracker = build_tracker(apply_temporal_disambiguation=False)
predictor = SAM3InteractiveImagePredictor(tracker)

# Load and set image
image = np.array(Image.open("image.jpg"))
predictor.set_image(image)

# Predict with single point
point_coords = np.array([[500, 375]])  # x, y
point_labels = np.array([1])  # foreground

masks, scores, _ = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True
)

print(f"Generated {len(masks)} masks")
print(f"Best mask IoU: {scores.max():.3f}")

Box Prompt

# Set image
predictor.set_image(image)

# Predict with box
box = np.array([100, 100, 500, 400])  # x0, y0, x1, y1

masks, scores, _ = predictor.predict(
    box=box,
    multimask_output=False  # Single mask for unambiguous box
)

Combining Points and Boxes

# Both box and points can be used together
box = np.array([100, 100, 500, 400])
points = np.array([[250, 250], [350, 300]])  # Additional points
labels = np.array([1, 1])  # Both foreground

masks, scores, _ = predictor.predict(
    point_coords=points,
    point_labels=labels,
    box=box,
    multimask_output=True
)

Iterative Refinement

# First prediction
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[250, 250]]),
    point_labels=np.array([1]),
    multimask_output=True
)

# Select best mask
best_idx = np.argmax(scores)
best_mask_input = low_res_masks[best_idx:best_idx+1]

# Refine with additional point
masks_refined, scores_refined, _ = predictor.predict(
    point_coords=np.array([[250, 250], [300, 300]]),
    point_labels=np.array([1, 1]),
    mask_input=best_mask_input,
    multimask_output=False
)

Batch Processing

# Load multiple images
images = [np.array(Image.open(f"image_{i}.jpg")) for i in range(3)]
predictor.set_image_batch(images)

# Define prompts for each image
point_coords_batch = [
    np.array([[250, 250]]),
    np.array([[300, 300]]),
    np.array([[400, 400]])
]
point_labels_batch = [
    np.array([1]),
    np.array([1]),
    np.array([1])
]

# Predict on all images
masks_list, scores_list, _ = predictor.predict_batch(
    point_coords_batch=point_coords_batch,
    point_labels_batch=point_labels_batch,
    multimask_output=True
)

# Process results
for i, (masks, scores) in enumerate(zip(masks_list, scores_list)):
    print(f"Image {i}: {len(masks)} masks, best IoU: {scores.max():.3f}")

Negative Points

# Use negative points to exclude regions
points = np.array([
    [250, 250],  # Include this region
    [450, 450]   # Exclude this region
])
labels = np.array([1, 0])  # 1=foreground, 0=background

masks, scores, _ = predictor.predict(
    point_coords=points,
    point_labels=labels,
    multimask_output=True
)

Notes

  • Call set_image() or set_image_batch() before prediction
  • Images must be in RGB format with pixel values in [0, 255]
  • Point coordinates are in (X, Y) format (width, height)
  • Box format is XYXY: [x_min, y_min, x_max, y_max]
  • For ambiguous prompts (single point), use multimask_output=True
  • For clear prompts (box, multiple points), multimask_output=False often works better
  • Low-resolution masks can be reused for iterative refinement
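To visualize a predicted mask, a simple alpha blend over the input image works well. A minimal sketch (overlay_mask is a hypothetical helper, not part of the sam3 API; shown on tiny synthetic data):

```python
import numpy as np

def overlay_mask(image: np.ndarray, mask: np.ndarray,
                 color=(30, 144, 255), alpha=0.5) -> np.ndarray:
    """Blend a solid color onto the masked pixels of an RGB uint8 image."""
    out = image.astype(np.float32).copy()
    out[mask] = (1 - alpha) * out[mask] + alpha * np.array(color, dtype=np.float32)
    return out.astype(np.uint8)

# Tiny gray image with a 2x2 mask in the middle
image = np.full((4, 4, 3), 200, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

blended = overlay_mask(image, mask)
print(blended[1, 1], blended[0, 0])  # masked pixel is tinted, unmasked is unchanged
```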
