PyTorch Dataset Adapters
The CVAT SDK provides PyTorch dataset adapters that allow you to use CVAT tasks directly for training deep learning models.

Installation

Install the SDK with PyTorch support:
pip install "cvat-sdk[pytorch]"
This installs:
  • torch - PyTorch framework
  • torchvision - Computer vision utilities
  • scikit-image - Image processing
  • numpy - Array operations

TaskVisionDataset

The TaskVisionDataset class wraps a CVAT task as a PyTorch VisionDataset.

Basic Usage

from cvat_sdk import Client
from cvat_sdk.pytorch import TaskVisionDataset
import torchvision.transforms as transforms  # not used in this snippet; needed for the transform examples below

# Connect to CVAT
# NOTE(review): a full URL including the scheme (e.g. "https://cvat.example.com")
# may be required depending on your server setup — confirm.
client = Client(url="cvat.example.com")
client.login(("username", "password"))

# Create dataset from task; this downloads (and caches) the task's
# images and annotations on first access.
dataset = TaskVisionDataset(
    client=client,
    task_id=123
)

print(f"Dataset size: {len(dataset)}")

# Get a sample: each item is an (image, target) pair
image, target = dataset[0]
print(f"Image type: {type(image)}")  # PIL.Image.Image
print(f"Target type: {type(target)}")  # Target

Constructor Parameters

dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    transforms=None,  # Joint image+target transforms
    transform=None,   # Image-only transforms
    target_transform=None,  # Target-only transforms
    label_name_to_index=None,  # Custom label mapping
    update_policy=UpdatePolicy.IF_MISSING_OR_STALE  # Cache update policy
)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `client` | `Client` | required | Connected CVAT client instance |
| `task_id` | `int` | required | ID of the task to load |
| `transforms` | `Callable` | `None` | Transform applied to both image and target. Takes `(image, target)` and returns the transformed `(image, target)`. |
| `transform` | `Callable` | `None` | Transform applied only to images |
| `target_transform` | `Callable` | `None` | Transform applied only to targets |
| `label_name_to_index` | `Mapping[str, int]` | `None` | Custom mapping from label names to indices. If not provided, labels are mapped to indices automatically. |
| `update_policy` | `UpdatePolicy` | `IF_MISSING_OR_STALE` | When to update the local cache. Options: `IF_MISSING_OR_STALE`, `NEVER`, `ALWAYS`. |

Working with Samples

Each sample is a tuple of (image, target):
# Get first sample
image, target = dataset[0]

# Image is a PIL Image
from PIL import Image
assert isinstance(image, Image.Image)
print(f"Image size: {image.size}")  # (width, height)
print(f"Image mode: {image.mode}")  # e.g., 'RGB'

# Target contains annotations
print(f"Annotations: {len(target.annotations)}")
print(f"Label mapping: {target.label_id_to_index}")

Target Structure

The Target object contains frame annotations:
from cvat_sdk.pytorch.common import Target

image, target = dataset[0]

# Access annotations
for ann in target.annotations:
    # Common attributes
    print(f"Label ID: {ann.label_id}")
    print(f"Type: {ann.type}")  # 'rectangle', 'polygon', 'points', etc.
    
    # For shapes (not tags)
    if hasattr(ann, 'points'):
        print(f"Points: {ann.points}")
    
    # For shapes with attributes
    if hasattr(ann, 'attributes'):
        for attr in ann.attributes:
            print(f"  Attribute {attr.spec_id}: {attr.value}")

# Label ID to index mapping
label_to_idx = target.label_id_to_index
print(f"Label mapping: {label_to_idx}")

Custom Label Mapping

Control how label IDs map to indices:
# Define custom label mapping
label_mapping = {
    "background": 0,
    "car": 1,
    "person": 2,
    "bicycle": 3
}

dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    label_name_to_index=label_mapping
)

# Now target.label_id_to_index maps to your indices
image, target = dataset[0]
for ann in target.annotations:
    label_idx = target.label_id_to_index[ann.label_id]
    print(f"Label index: {label_idx}")  # 0, 1, 2, or 3

Transforms

Apply torchvision transforms:
import torchvision.transforms as T

# Image-only transforms
image_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    transform=image_transform
)

image, target = dataset[0]
print(f"Image type: {type(image)}")  # torch.Tensor
print(f"Image shape: {image.shape}")  # (3, 224, 224)

Joint Transforms

For transforms that need to modify both image and annotations:
# NOTE(review): verify these names against your cvat_sdk version — the
# published cvat_sdk.pytorch transforms are ExtractSingleLabelIndex and
# ExtractBoundingBoxes; ResizeTransform / RandomHorizontalFlip may not ship.
from cvat_sdk.pytorch.transforms import (
    ResizeTransform,
    RandomHorizontalFlip
)

# Use SDK transforms that handle both image and target
joint_transform = T.Compose([
    ResizeTransform((512, 512)),
    RandomHorizontalFlip(p=0.5)
])

# Passed via `transforms=` (joint), not `transform=` (image-only)
dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    transforms=joint_transform
)

Caching Behavior

The dataset caches task data locally:
from cvat_sdk.datasets.caching import UpdatePolicy

# Always use cache (never update from server)
dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    update_policy=UpdatePolicy.NEVER
)

# Always fetch fresh data from server
dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    update_policy=UpdatePolicy.ALWAYS
)

# Default: update if cache is missing or stale
dataset = TaskVisionDataset(
    client=client,
    task_id=123,
    update_policy=UpdatePolicy.IF_MISSING_OR_STALE
)
Cache location is controlled by the client config:
from cvat_sdk import Config
from pathlib import Path

config = Config(cache_dir=Path("/custom/cache/dir"))
client = Client(url="cvat.example.com", config=config)

dataset = TaskVisionDataset(client, task_id=123)
# Cache stored in /custom/cache/dir/

Training Example: Object Detection

import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

from cvat_sdk import Client
from cvat_sdk.pytorch import TaskVisionDataset

def collate_fn(batch):
    """Collate detection samples without stacking.

    Detection images may differ in size and each frame has a
    variable number of annotations, so the batch is returned as two
    parallel lists rather than stacked tensors.
    """
    if not batch:
        return [], []
    images, targets = zip(*batch)
    return list(images), list(targets)

def train():
    """Fine-tune a Faster R-CNN detector on the annotations of a CVAT task.

    Connects to the CVAT server, wraps task 123 as a PyTorch vision
    dataset, and runs a simple SGD training loop over its rectangle
    annotations.
    """
    # Connect to CVAT
    client = Client(url="cvat.example.com")
    client.login(("username", "password"))

    # ImageNet normalization to match the pretrained backbone statistics.
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = TaskVisionDataset(
        client=client,
        task_id=123,
        transform=transform,
        label_name_to_index={"background": 0, "object": 1}
    )

    # collate_fn keeps variable-size images/targets as parallel lists.
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn
    )

    # Load the detection model from torchvision directly: the torch.hub
    # 'pytorch/vision' entry points do not reliably expose the detection
    # builders across torchvision versions.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(10):
        for images, targets in dataloader:
            # Convert targets to the detection-model format, keeping the
            # image and target lists ALIGNED: frames with no rectangle
            # annotations are dropped from both lists, so the model never
            # receives more images than targets.
            kept_images = []
            model_targets = []
            for image, target in zip(images, targets):
                boxes = []
                labels = []
                for ann in target.annotations:
                    if ann.type.value == 'rectangle':
                        # CVAT rectangle points are [x1, y1, x2, y2]
                        boxes.append(ann.points)
                        labels.append(target.label_id_to_index[ann.label_id])

                if boxes:
                    kept_images.append(image)
                    model_targets.append({
                        'boxes': torch.tensor(boxes, dtype=torch.float32),
                        'labels': torch.tensor(labels, dtype=torch.int64)
                    })

            # Forward pass (skip batches with no annotated frames at all)
            if model_targets:
                loss_dict = model(kept_images, model_targets)
                losses = sum(loss for loss in loss_dict.values())

                # Backward pass
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                print(f"Epoch {epoch}, Loss: {losses.item():.4f}")

if __name__ == "__main__":
    train()

Training Example: Classification

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from cvat_sdk import Client
from cvat_sdk.pytorch import TaskVisionDataset

def extract_label(target):
    """Reduce a frame Target to a single class index for classification.

    The first annotation's label decides the class; frames without any
    annotations fall back to index 0 (background).
    """
    if not target.annotations:
        return 0  # Default background
    first_ann = target.annotations[0]
    return target.label_id_to_index[first_ann.label_id]

def train_classifier():
    """Train a 3-class ResNet-50 classifier from a CVAT task.

    Each frame is reduced to a single label index via extract_label,
    so the task can be consumed like a standard classification dataset.
    """
    client = Client(url="cvat.example.com")
    client.login(("username", "password"))

    # Resize + ImageNet normalization to match the pretrained backbone.
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = TaskVisionDataset(
        client=client,
        task_id=456,
        transform=preprocess,
        target_transform=extract_label,  # Convert target to single label
        label_name_to_index={"cat": 0, "dog": 1, "bird": 2},
    )

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    # Pretrained ResNet-50 with a fresh 3-way classification head.
    model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 3)  # 3 classes

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(10):
        running_loss = 0.0
        for batch_images, batch_labels in dataloader:
            optimizer.zero_grad()

            loss = criterion(model(batch_images), batch_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader):.4f}")

if __name__ == "__main__":
    train_classifier()

Limitations

The current PyTorch adapter has some limitations:
  • Only tasks with image data are supported (not video)
  • Track annotations are not accessible (only shapes, tags)
  • Deleted frames are automatically omitted

ProjectVisionDataset

For working with entire projects:
from cvat_sdk.pytorch import ProjectVisionDataset

# Load all tasks in a project
dataset = ProjectVisionDataset(
    client=client,
    project_id=789
)

print(f"Total samples across all project tasks: {len(dataset)}")

Next Steps

See the CVAT SDK documentation for the related auto-annotation and datasets modules.