Process multiple images efficiently with batch processing in SAM 3
Batch processing lets you run inference on multiple images in a single forward pass, improving throughput. This guide shows you how to create batched datapoints with multiple queries per image.
import torch# Turn on tfloat32 for Ampere GPUstorch.backends.cuda.matmul.allow_tf32 = Truetorch.backends.cudnn.allow_tf32 = True# Use bfloat16 for the entire notebooktorch.autocast("cuda", dtype=torch.bfloat16).__enter__()# Inference mode for the whole notebooktorch.inference_mode().__enter__()
from sam3.train.data.sam3_image_dataset import (
    InferenceMetadata,
    FindQueryLoaded,
    Image as SAMImage,
    Datapoint,
)
from typing import List

# Monotonically increasing id handed out to each registered query; the
# returned ids index into the postprocessed results later on.
GLOBAL_COUNTER = 1


def create_empty_datapoint():
    """A datapoint is a single image on which we can apply several queries."""
    return Datapoint(find_queries=[], images=[])


def set_image(datapoint, pil_image):
    """Add the image to be processed to the datapoint."""
    width, height = pil_image.size  # PIL reports (width, height)
    # SAMImage stores size as [height, width].
    datapoint.images = [SAMImage(data=pil_image, objects=[], size=[height, width])]


def add_text_prompt(datapoint, text_query):
    """Register a text query on the datapoint and return its query id."""
    global GLOBAL_COUNTER
    assert len(datapoint.images) == 1, "please set the image first"
    # NOTE(review): images[0].size is [height, width] (see set_image), so the
    # first element unpacked here is the height. Confirm that
    # InferenceMetadata.original_size really expects this order -- the
    # original code unpacked it as `w, h`, which swaps the two.
    size_first, size_second = datapoint.images[0].size
    query_id = GLOBAL_COUNTER
    metadata = InferenceMetadata(
        coco_image_id=query_id,
        original_image_id=query_id,
        original_category_id=1,
        original_size=[size_first, size_second],
        object_id=0,
        frame_index=0,
    )
    datapoint.find_queries.append(
        FindQueryLoaded(
            query_text=text_query,
            image_id=0,
            object_ids_output=[],
            is_exhaustive=True,
            query_processing_order=0,
            inference_metadata=metadata,
        )
    )
    GLOBAL_COUNTER += 1
    return query_id
def add_visual_prompt(datapoint, boxes: List[List[float]], labels: List[bool], text_prompt="visual"):
    """Add a visual query to the datapoint.

    Args:
        boxes: Bounding boxes in XYXY format (top left and bottom right corners)
        labels: True for positive boxes, False for negative boxes
        text_prompt: Optional text hint

    Returns:
        The query id to use when reading back the processed results.
    """
    global GLOBAL_COUNTER
    assert len(datapoint.images) == 1, "please set the image first"
    assert len(boxes) > 0, "please provide at least one box"
    assert len(boxes) == len(labels), "Expecting one label per box"
    # Tensorize the prompts: one (N, 4) float box tensor plus one (N,) bool
    # label tensor marking positive/negative boxes.
    box_tensor = torch.tensor(boxes, dtype=torch.float).view(-1, 4)
    label_tensor = torch.tensor(labels, dtype=torch.bool).view(-1)
    # NOTE(review): images[0].size is [height, width] (see set_image); verify
    # the order InferenceMetadata.original_size expects -- the original code
    # unpacked it as `w, h`, which swaps the two.
    size_first, size_second = datapoint.images[0].size
    query_id = GLOBAL_COUNTER
    datapoint.find_queries.append(
        FindQueryLoaded(
            query_text=text_prompt,
            image_id=0,
            object_ids_output=[],
            is_exhaustive=True,
            query_processing_order=0,
            input_bbox=box_tensor,
            input_bbox_label=label_tensor,
            inference_metadata=InferenceMetadata(
                coco_image_id=query_id,
                original_image_id=query_id,
                original_category_id=1,
                original_size=[size_first, size_second],
                object_id=0,
                frame_index=0,
            ),
        )
    )
    GLOBAL_COUNTER += 1
    return query_id


# --- Image 2: mix a text prompt with a visual (box) prompt ---
img2 = Image.open(BytesIO(requests.get(
    "http://images.cocodataset.org/val2017/000000136466.jpg").content))
datapoint2 = create_empty_datapoint()
set_image(datapoint2, img2)
id3 = add_text_prompt(datapoint2, "pot")
id4 = add_visual_prompt(datapoint2, boxes=[[59, 144, 76, 163]], labels=[True])
datapoint2 = transform(datapoint2)
Boxes must be in XYXY format (top-left and bottom-right corners). Each box needs a corresponding label indicating whether it’s a positive (include) or negative (exclude) prompt.
Use the query IDs returned when adding prompts to access results:
from sam3.visualization_utils import plot_results

# Plot each query's results against the image it was issued on:
# "cat" and "laptop" on image 1, "pot" and the visual prompt on image 2.
for image, query_id in ((img1, id1), (img1, id2), (img2, id3), (img2, id4)):
    plot_results(image, processed_results[query_id])
# First, try with just text prompt "handle"id6 = add_text_prompt(datapoint2, "handle")# The model finds both pot handles AND oven handles# Exclude the oven handles with a negative box:id7 = add_visual_prompt( datapoint2, boxes=[[40, 183, 318, 204]], # Oven handle region labels=[False], # Negative prompt text_prompt="handle")# Re-run inferencedatapoint2 = transform(datapoint2)batch = collate([datapoint2], dict_key="dummy")["dummy"]batch = copy_data_to_device(batch, torch.device("cuda"), non_blocking=True)output = model(batch)processed_results = postprocessor.process_results(output, batch.find_metadatas)# Now only pot handles are segmentedplot_results(img2, processed_results[id7])