Overview
The SAM3InteractiveImagePredictor class provides SAM 1-style interactive segmentation with point and box prompts. It uses SAM 3’s tracking module for efficient mask prediction.
Class Initialization
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
predictor = SAM3InteractiveImagePredictor(
sam_model,
mask_threshold=0.0,
max_hole_area=256.0,
max_sprinkle_area=0.0
)
Parameters
sam_model
required
SAM 3 tracker model instance.
mask_threshold
float
default:"0.0"
Threshold for converting mask logits to binary masks.
max_hole_area
float
default:"256.0"
Maximum area of holes to fill in masks. Set to 0 to disable hole filling.
max_sprinkle_area
float
default:"0.0"
Maximum area of small spurious regions to remove. Set to 0 to disable.
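The effect of mask_threshold can be illustrated with plain NumPy. This is a sketch of the thresholding concept, not the predictor's internal code:

```python
import numpy as np

# Hypothetical low-resolution logits; in practice these come from the model.
logits = np.array([[-2.5, 0.1],
                   [3.0, -0.4]], dtype=np.float32)

mask_threshold = 0.0
binary_mask = logits > mask_threshold  # pixels above the threshold become foreground

print(binary_mask)
# [[False  True]
#  [ True False]]
```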
Methods
set_image
Computes image embeddings for the provided image.
predictor.set_image(image)
image
np.ndarray | PIL.Image.Image
required
Input image in RGB format. NumPy arrays should be in HWC format; PIL images use WHC (PIL's native (width, height) size ordering). Pixel values in [0, 255].
set_image_batch
Computes image embeddings for a batch of images.
predictor.set_image_batch(image_list)
image_list
list[np.ndarray]
required
List of images in RGB format (HWC, uint8, pixel values in [0, 255]).
predict
Predict masks for the given prompts.
masks, iou_predictions, low_res_masks = predictor.predict(
point_coords=None,
point_labels=None,
box=None,
mask_input=None,
multimask_output=True,
return_logits=False,
normalize_coords=True
)
Parameters
point_coords
np.ndarray | None
default:"None"
Nx2 array of point prompts in (X, Y) pixel coordinates.
point_labels
np.ndarray | None
default:"None"
Length N array of labels: 1 for foreground point, 0 for background point.
box
np.ndarray | None
default:"None"
Box prompt in XYXY format: [x0, y0, x1, y1].
mask_input
np.ndarray | None
default:"None"
Low-resolution mask from previous iteration (1xHxW, H=W=256).
multimask_output
bool
default:"True"
If True, returns 3 masks instead of 1. For ambiguous prompts (such as a single point), this often produces better results.
return_logits
bool
default:"False"
If True, returns un-thresholded mask logits instead of binary masks.
normalize_coords
bool
default:"True"
If True, point coordinates are normalized to the [0, 1] range before being passed to the model.
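Conceptually, normalize_coords=True scales pixel coordinates by the image size. The sketch below shows this with plain NumPy; the exact internal behavior is an assumption, not the predictor's own code:

```python
import numpy as np

# Hypothetical image size and point prompt in (X, Y) pixel coordinates.
H, W = 768, 1024
point_coords = np.array([[512.0, 384.0]])

# Normalize X by width and Y by height into the [0, 1] range.
normalized = point_coords / np.array([W, H], dtype=np.float32)
print(normalized)  # [[0.5 0.5]]
```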
Returns
masks
np.ndarray
Output masks in CxHxW format, where C is the number of masks (1 or 3).
iou_predictions
np.ndarray
Length C array of predicted IoU scores for each mask.
low_res_masks
np.ndarray
Low-resolution logits in CxHxW format (H=W=256) for iterative refinement.
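A common pattern is to pick the mask with the highest predicted IoU and, when return_logits=True was used, threshold it manually. The arrays below are hypothetical stand-ins for the values predict() would return:

```python
import numpy as np

# Stand-ins for predictor.predict(..., return_logits=True) outputs.
masks_logits = np.random.randn(3, 8, 8).astype(np.float32)  # CxHxW logits
iou_predictions = np.array([0.71, 0.93, 0.85], dtype=np.float32)

best_idx = int(np.argmax(iou_predictions))   # index of the highest-IoU mask
binary_mask = masks_logits[best_idx] > 0.0   # apply mask_threshold manually

print(best_idx)           # 1
print(binary_mask.shape)  # (8, 8)
```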
predict_batch
Predict masks for a batch of images.
masks_list, ious_list, low_res_list = predictor.predict_batch(
point_coords_batch=None,
point_labels_batch=None,
box_batch=None,
mask_input_batch=None,
multimask_output=True,
return_logits=False,
normalize_coords=True
)
Parameters
Same as predict(), but each parameter is a list (one entry per image).
Returns
Lists of masks, IoU predictions, and low-resolution masks (one per image).
get_image_embedding
Returns the cached image embeddings.
embedding = predictor.get_image_embedding()
Image embeddings with shape 1xCxHxW (typically C=256, H=W=64).
reset_predictor
Resets all cached state (image embeddings, etc.).
predictor.reset_predictor()
Example Usage
Single Point Prompt
import numpy as np
from PIL import Image
from sam3.model_builder import build_tracker
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
# Build tracker and predictor
tracker = build_tracker(apply_temporal_disambiguation=False)
predictor = SAM3InteractiveImagePredictor(tracker)
# Load and set image
image = np.array(Image.open("image.jpg"))
predictor.set_image(image)
# Predict with single point
point_coords = np.array([[500, 375]]) # x, y
point_labels = np.array([1]) # foreground
masks, scores, _ = predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=True
)
print(f"Generated {len(masks)} masks")
print(f"Best mask IoU: {scores.max():.3f}")
Box Prompt
# Set image
predictor.set_image(image)
# Predict with box
box = np.array([100, 100, 500, 400]) # x0, y0, x1, y1
masks, scores, _ = predictor.predict(
box=box,
multimask_output=False # Single mask for unambiguous box
)
Combining Points and Boxes
# Both box and points can be used together
box = np.array([100, 100, 500, 400])
points = np.array([[250, 250], [350, 300]]) # Additional points
labels = np.array([1, 1]) # Both foreground
masks, scores, _ = predictor.predict(
point_coords=points,
point_labels=labels,
box=box,
multimask_output=True
)
Iterative Refinement
# First prediction
masks, scores, low_res_masks = predictor.predict(
point_coords=np.array([[250, 250]]),
point_labels=np.array([1]),
multimask_output=True
)
# Select best mask
best_idx = np.argmax(scores)
best_mask_input = low_res_masks[best_idx:best_idx+1]
# Refine with additional point
masks_refined, scores_refined, _ = predictor.predict(
point_coords=np.array([[250, 250], [300, 300]]),
point_labels=np.array([1, 1]),
mask_input=best_mask_input,
multimask_output=False
)
Batch Processing
# Load multiple images
images = [np.array(Image.open(f"image_{i}.jpg")) for i in range(3)]
predictor.set_image_batch(images)
# Define prompts for each image
point_coords_batch = [
np.array([[250, 250]]),
np.array([[300, 300]]),
np.array([[400, 400]])
]
point_labels_batch = [
np.array([1]),
np.array([1]),
np.array([1])
]
# Predict on all images
masks_list, scores_list, _ = predictor.predict_batch(
point_coords_batch=point_coords_batch,
point_labels_batch=point_labels_batch,
multimask_output=True
)
# Process results
for i, (masks, scores) in enumerate(zip(masks_list, scores_list)):
print(f"Image {i}: {len(masks)} masks, best IoU: {scores.max():.3f}")
Negative Points
# Use negative points to exclude regions
points = np.array([
[250, 250], # Include this region
[450, 450] # Exclude this region
])
labels = np.array([1, 0]) # 1=foreground, 0=background
masks, scores, _ = predictor.predict(
point_coords=points,
point_labels=labels,
multimask_output=True
)
Notes
- Call set_image() or set_image_batch() before prediction
- Images must be in RGB format with pixel values in [0, 255]
- Point coordinates are in (X, Y) format (width, height)
- Box format is XYXY: [x_min, y_min, x_max, y_max]
- For ambiguous prompts (single point), use multimask_output=True
- For clear prompts (box, multiple points), multimask_output=False often works better
- Low-resolution masks can be reused for iterative refinement
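For visualizing results, a binary mask returned by predict() can be blended onto the input image with plain NumPy. overlay_mask is a hypothetical helper, not part of the predictor API:

```python
import numpy as np

def overlay_mask(image: np.ndarray, mask: np.ndarray,
                 color=(30, 144, 255), alpha=0.5) -> np.ndarray:
    """Blend a boolean HxW mask onto an HxWx3 uint8 RGB image."""
    out = image.astype(np.float32).copy()
    out[mask] = (1 - alpha) * out[mask] + alpha * np.array(color, dtype=np.float32)
    return out.astype(np.uint8)

# Tiny synthetic example; real inputs would be an image and predictor output.
image = np.zeros((4, 4, 3), dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

overlaid = overlay_mask(image, mask)
print(overlaid[1, 1])  # masked pixel takes on half the overlay color
print(overlaid[0, 0])  # unmasked pixel is unchanged
```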