Batch Inference
Batch processing allows you to run inference on multiple images simultaneously, improving throughput and efficiency. This guide shows you how to create batched datapoints with multiple queries per image.

Setup

1

Import dependencies

from PIL import Image
import requests
from io import BytesIO
import sam3
from sam3.train.data.collator import collate_fn_api as collate
from sam3.model.utils.misc import copy_data_to_device
import os

sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
2

Configure PyTorch

import torch

# Turn on tfloat32 for Ampere GPUs
# (TF32 trades a little matmul precision for a large speedup on Ampere+)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Use bfloat16 for the entire notebook
# NOTE: calling __enter__() without a matching __exit__() deliberately leaves
# the autocast context active for the rest of the session (notebook-style hack).
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

# Inference mode for the whole notebook
# (disables autograd tracking globally; same open-ended context-manager hack)
torch.inference_mode().__enter__()
3

Load model and transforms

from sam3 import build_sam3_image_model
from sam3.train.transforms.basic_for_api import (
    ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
)

# Byte-pair-encoding vocabulary file used by the model's text tokenizer.
bpe_path = f"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz"
model = build_sam3_image_model(bpe_path=bpe_path)

# Preprocessing pipeline applied to each datapoint before inference:
# resize to a fixed 1008x1008 square, convert to tensor, then normalize
# (mean/std of 0.5 presumably maps pixel values to roughly [-1, 1] — confirm
# against NormalizeAPI's implementation).
transform = ComposeAPI(
    transforms=[
        RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
        ToTensorAPI(),
        NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
4

Setup postprocessor

from sam3.eval.postprocessors import PostProcessImage

# Converts raw model outputs into per-query detection/segmentation results.
postprocessor = PostProcessImage(
    max_dets_per_img=-1,  # -1 presumably means "no cap on detections" — confirm
    iou_type="segm",
    use_original_sizes_box=True,   # report boxes in original-image coordinates
    use_original_sizes_mask=True,  # report masks at the original image size
    convert_mask_to_rle=False,     # keep dense masks (no COCO RLE encoding)
    detection_threshold=0.5,       # drop detections scored below 0.5
    to_cpu=False,                  # keep results on the GPU
)

Creating Datapoints

Datapoints are the fundamental unit for batched inference. Each datapoint represents a single image with one or more queries.

Utility Functions

from sam3.train.data.sam3_image_dataset import (
    InferenceMetadata, FindQueryLoaded, Image as SAMImage, Datapoint
)
from typing import List

GLOBAL_COUNTER = 1

def create_empty_datapoint():
    """Build a fresh, empty Datapoint.

    A datapoint is a single image on which we can apply several queries;
    the image and queries are attached afterwards via set_image / add_*_prompt.
    """
    empty = Datapoint(images=[], find_queries=[])
    return empty

def set_image(datapoint, pil_image):
    """Attach the image that every query on this datapoint will run against.

    PIL reports size as (width, height); SAMImage.size is stored as
    [height, width].
    """
    width, height = pil_image.size
    datapoint.images = [SAMImage(data=pil_image, objects=[], size=[height, width])]

def add_text_prompt(datapoint, text_query):
    """Append a text query to the datapoint and return its unique query id.

    The returned id is the key under which this query's results appear in
    the postprocessed output.
    """
    global GLOBAL_COUNTER
    assert len(datapoint.images) == 1, "please set the image first"

    # SAMImage.size is stored as [height, width] (see set_image), so
    # original_size below receives [height, width] of the original image.
    height, width = datapoint.images[0].size

    # Reserve a globally unique id for this query before building the metadata.
    query_id = GLOBAL_COUNTER
    GLOBAL_COUNTER += 1

    datapoint.find_queries.append(
        FindQueryLoaded(
            query_text=text_query,
            image_id=0,
            object_ids_output=[],
            is_exhaustive=True,
            query_processing_order=0,
            inference_metadata=InferenceMetadata(
                coco_image_id=query_id,
                original_image_id=query_id,
                original_category_id=1,
                original_size=[height, width],
                object_id=0,
                frame_index=0,
            ),
        )
    )
    return query_id

Text Prompts in Batches

Create multiple queries for a single image:
# Image 1 with two text prompts
# Download a COCO validation image and decode it with PIL.
img1 = Image.open(BytesIO(requests.get(
    "http://images.cocodataset.org/val2017/000000077595.jpg"
).content))

datapoint1 = create_empty_datapoint()
set_image(datapoint1, img1)
id1 = add_text_prompt(datapoint1, "cat")     # query id, used to look up results later
id2 = add_text_prompt(datapoint1, "laptop")

# Apply the resize/tensor/normalize pipeline to the fully-built datapoint.
datapoint1 = transform(datapoint1)

Visual Prompts in Batches

Combine text and visual prompts:
def add_visual_prompt(datapoint, boxes: List[List[float]], labels: List[bool], text_prompt="visual"):
    """Append a box-based (visual) query to the datapoint and return its query id.

    Args:
        boxes: Bounding boxes in XYXY format (top left and bottom right corners)
        labels: True for positive boxes, False for negative boxes
        text_prompt: Optional text hint
    """
    global GLOBAL_COUNTER
    assert len(datapoint.images) == 1, "please set the image first"
    assert len(boxes) > 0, "please provide at least one box"
    assert len(boxes) == len(labels), "Expecting one label per box"

    # One row per box; one boolean label per box.
    box_tensor = torch.tensor(boxes, dtype=torch.float).view(-1, 4)
    label_tensor = torch.tensor(labels, dtype=torch.bool).view(-1)

    # SAMImage.size is stored as [height, width] (see set_image), so
    # original_size below receives [height, width] of the original image.
    height, width = datapoint.images[0].size

    # Reserve a globally unique id for this query before building the metadata.
    query_id = GLOBAL_COUNTER
    GLOBAL_COUNTER += 1

    datapoint.find_queries.append(
        FindQueryLoaded(
            query_text=text_prompt,
            image_id=0,
            object_ids_output=[],
            is_exhaustive=True,
            query_processing_order=0,
            input_bbox=box_tensor,
            input_bbox_label=label_tensor,
            inference_metadata=InferenceMetadata(
                coco_image_id=query_id,
                original_image_id=query_id,
                original_category_id=1,
                original_size=[height, width],
                object_id=0,
                frame_index=0,
            ),
        )
    )
    return query_id

# Image 2 with mixed prompts
# Download a second COCO validation image.
img2 = Image.open(BytesIO(requests.get(
    "http://images.cocodataset.org/val2017/000000136466.jpg"
).content))

datapoint2 = create_empty_datapoint()
set_image(datapoint2, img2)
id3 = add_text_prompt(datapoint2, "pot")
# Positive box prompt (XYXY, original-image pixel coordinates).
id4 = add_visual_prompt(datapoint2, boxes=[[59, 144, 76, 163]], labels=[True])

# Apply the resize/tensor/normalize pipeline to the fully-built datapoint.
datapoint2 = transform(datapoint2)
Boxes must be in XYXY format (top-left and bottom-right corners). Each box needs a corresponding label indicating whether it’s a positive (include) or negative (exclude) prompt.

Running Batch Inference

1

Collate datapoints into a batch

# Collate both datapoints into a single batch (the collator returns a dict
# keyed by dict_key; "dummy" is just a placeholder key) and move it to the GPU.
batch = collate([datapoint1, datapoint2], dict_key="dummy")["dummy"]
batch = copy_data_to_device(batch, torch.device("cuda"), non_blocking=True)
2

Forward pass through model

# Note: first forward will be slow due to compilation
# (subsequent batches run much faster).
output = model(batch)
3

Process results

# Map raw model outputs back to per-query results, keyed by the ids returned
# by add_text_prompt / add_visual_prompt.
processed_results = postprocessor.process_results(output, batch.find_metadatas)

Retrieving Individual Results

Use the query IDs returned when adding prompts to access results:
from sam3.visualization_utils import plot_results

# Each query id returned earlier indexes its own entry in processed_results;
# plot each query's detections over the corresponding original image.

# Results for "cat" query on image 1
plot_results(img1, processed_results[id1])

# Results for "laptop" query on image 1
plot_results(img1, processed_results[id2])

# Results for "pot" query on image 2
plot_results(img2, processed_results[id3])

# Results for visual prompt on image 2
plot_results(img2, processed_results[id4])

Negative Prompts Example

Refine results by excluding unwanted objects:
# First, try with just text prompt "handle"
id6 = add_text_prompt(datapoint2, "handle")

# The model finds both pot handles AND oven handles
# Exclude the oven handles with a negative box:
id7 = add_visual_prompt(
    datapoint2, 
    boxes=[[40, 183, 318, 204]],  # Oven handle region
    labels=[False],  # Negative prompt
    text_prompt="handle"
)

# Re-run inference
# NOTE(review): datapoint2 was already passed through `transform` above;
# re-applying it here may resize/normalize the image a second time — verify
# the transform is safe to re-apply, or rebuild the datapoint from the raw
# image before adding the new prompts.
datapoint2 = transform(datapoint2)
batch = collate([datapoint2], dict_key="dummy")["dummy"]
batch = copy_data_to_device(batch, torch.device("cuda"), non_blocking=True)
output = model(batch)
processed_results = postprocessor.process_results(output, batch.find_metadatas)

# Now only pot handles are segmented
plot_results(img2, processed_results[id7])

Best Practices

  • Batch similar-sized images together
  • Use transform to normalize all images to the same size (1008x1008)
  • First forward pass will be slow due to compilation - subsequent batches are much faster
  • Use bfloat16 precision for speed without significant quality loss

Next Steps

Image Inference

Learn single image segmentation basics

Video Inference

Apply batching concepts to video processing

Build docs developers (and LLMs) love