The CVAT SDK provides PyTorch dataset adapters that allow you to use CVAT tasks directly for training deep learning models.
Installation
Install the SDK with PyTorch support:
pip install "cvat-sdk[pytorch]"
This installs:
torch - PyTorch framework
torchvision - Computer vision utilities
scikit-image - Image processing
numpy - Array operations
TaskVisionDataset
The TaskVisionDataset class wraps a CVAT task as a PyTorch VisionDataset.
Basic Usage
from cvat_sdk import Client
from cvat_sdk.pytorch import TaskVisionDataset
import torchvision.transforms as transforms
# Connect to CVAT
client = Client(url="cvat.example.com")
client.login(("username", "password"))
# Create dataset from task
dataset = TaskVisionDataset(
client=client,
task_id=123
)
print(f"Dataset size: {len(dataset)}")
# Get a sample
image, target = dataset[0]
print(f"Image type: {type(image)}") # PIL.Image.Image
print(f"Target type: {type(target)}") # Target
Constructor Parameters
dataset = TaskVisionDataset(
client=client,
task_id=123,
transforms=None, # Joint image+target transforms
transform=None, # Image-only transforms
target_transform=None, # Target-only transforms
label_name_to_index=None, # Custom label mapping
update_policy=UpdatePolicy.IF_MISSING_OR_STALE # Cache update policy
)
client (Client, required): Connected CVAT client instance.
transforms (Callable, optional): Transform applied to both image and target. Takes (image, target) and returns the transformed (image, target).
transform (Callable, optional): Transform applied only to the image.
target_transform (Callable, optional): Transform applied only to the target.
label_name_to_index
Mapping[str, int]
default: None
Custom mapping from label names to indices. If not provided, labels are mapped to indices automatically.
update_policy
UpdatePolicy
default:"IF_MISSING_OR_STALE"
When to update the local cache. Options: IF_MISSING_OR_STALE, NEVER, ALWAYS.
Working with Samples
Each sample is a tuple of (image, target):
# Get first sample
image, target = dataset[0]
# Image is a PIL Image
from PIL import Image
assert isinstance(image, Image.Image)
print(f"Image size: {image.size}") # (width, height)
print(f"Image mode: {image.mode}") # e.g., 'RGB'
# Target contains annotations
print(f"Annotations: {len(target.annotations)}")
print(f"Label mapping: {target.label_id_to_index}")
Target Structure
The Target object contains frame annotations:
from cvat_sdk.pytorch.common import Target
image, target = dataset[0]
# Access annotations
for ann in target.annotations:
# Common attributes
print(f"Label ID: {ann.label_id}")
print(f"Type: {ann.type}") # 'rectangle', 'polygon', 'points', etc.
# For shapes (not tags)
if hasattr(ann, 'points'):
print(f"Points: {ann.points}")
# For shapes with attributes
if hasattr(ann, 'attributes'):
for attr in ann.attributes:
print(f" Attribute {attr.spec_id}: {attr.value}")
# Label ID to index mapping
label_to_idx = target.label_id_to_index
print(f"Label mapping: {label_to_idx}")
Custom Label Mapping
Control how label IDs map to indices:
# Define custom label mapping
label_mapping = {
"background": 0,
"car": 1,
"person": 2,
"bicycle": 3
}
dataset = TaskVisionDataset(
client=client,
task_id=123,
label_name_to_index=label_mapping
)
# Now target.label_id_to_index maps to your indices
image, target = dataset[0]
for ann in target.annotations:
label_idx = target.label_id_to_index[ann.label_id]
print(f"Label index: {label_idx}") # 0, 1, 2, or 3
Image-only transforms: apply standard torchvision transforms to the image (annotations are left untouched):
import torchvision.transforms as T
# Image-only transforms
image_transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = TaskVisionDataset(
client=client,
task_id=123,
transform=image_transform
)
image, target = dataset[0]
print(f"Image type: {type(image)}") # torch.Tensor
print(f"Image shape: {image.shape}") # (3, 224, 224)
For transforms that need to modify both image and annotations:
from cvat_sdk.pytorch.transforms import (
ResizeTransform,
RandomHorizontalFlip
)
# Use SDK transforms that handle both image and target
joint_transform = T.Compose([
ResizeTransform((512, 512)),
RandomHorizontalFlip(p=0.5)
])
dataset = TaskVisionDataset(
client=client,
task_id=123,
transforms=joint_transform
)
Caching Behavior
The dataset caches task data locally:
from cvat_sdk.datasets.caching import UpdatePolicy
# Always use cache (never update from server)
dataset = TaskVisionDataset(
client=client,
task_id=123,
update_policy=UpdatePolicy.NEVER
)
# Always fetch fresh data from server
dataset = TaskVisionDataset(
client=client,
task_id=123,
update_policy=UpdatePolicy.ALWAYS
)
# Default: update if cache is missing or stale
dataset = TaskVisionDataset(
client=client,
task_id=123,
update_policy=UpdatePolicy.IF_MISSING_OR_STALE
)
Cache location is controlled by the client config:
from cvat_sdk import Config
from pathlib import Path
config = Config(cache_dir=Path("/custom/cache/dir"))
client = Client(url="cvat.example.com", config=config)
dataset = TaskVisionDataset(client, task_id=123)
# Cache stored in /custom/cache/dir/
Training Example: Object Detection
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from cvat_sdk import Client
from cvat_sdk.pytorch import TaskVisionDataset
def collate_fn(batch):
    """Collate (image, target) pairs into two parallel lists.

    Detection targets vary in size per sample, so the default stacking
    collate cannot be used; images and targets stay as plain lists.
    """
    images, targets = [], []
    for image, target in batch:
        images.append(image)
        targets.append(target)
    return images, targets
def train():
    """Train a Faster R-CNN detector on a CVAT task.

    Connects to the CVAT server, wraps task 123 as a PyTorch vision
    dataset, converts CVAT rectangle annotations into the torchvision
    detection target format, and runs a standard training loop.
    """
    # Connect to CVAT
    client = Client(url="cvat.example.com")
    client.login(("username", "password"))

    # Per-image transform: tensor conversion + ImageNet normalization
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    dataset = TaskVisionDataset(
        client=client,
        task_id=123,
        transform=transform,
        label_name_to_index={"background": 0, "object": 1}
    )

    # Detection samples cannot be stacked into a single tensor,
    # hence the custom collate function.
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn
    )

    model = torch.hub.load('pytorch/vision', 'fasterrcnn_resnet50_fpn', pretrained=True)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(10):
        for images, targets in dataloader:
            # Convert CVAT targets to the torchvision detection format.
            # Torchvision detection models require len(images) to equal
            # len(targets), so frames with no rectangle annotations are
            # dropped from BOTH lists (the original example dropped only
            # the target, leaving the lists misaligned).
            batch_images = []
            model_targets = []
            for image, target in zip(images, targets):
                boxes = []
                labels = []
                for ann in target.annotations:
                    if ann.type.value == 'rectangle':
                        # CVAT rectangle points are [x1, y1, x2, y2]
                        boxes.append(ann.points)
                        labels.append(target.label_id_to_index[ann.label_id])
                if boxes:
                    batch_images.append(image)
                    model_targets.append({
                        'boxes': torch.tensor(boxes, dtype=torch.float32),
                        'labels': torch.tensor(labels, dtype=torch.int64)
                    })
            # Forward/backward only when the batch contains annotated frames
            if model_targets:
                loss_dict = model(batch_images, model_targets)
                losses = sum(loss for loss in loss_dict.values())
                # Backward pass
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
                print(f"Epoch {epoch}, Loss: {losses.item():.4f}")

if __name__ == "__main__":
    train()
Training Example: Classification
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from cvat_sdk import Client
from cvat_sdk.pytorch import TaskVisionDataset
def extract_label(target):
    """Reduce a frame Target to a single integer class index.

    Uses the label of the first annotation (tag or shape); frames
    without annotations fall back to class 0.
    """
    if not target.annotations:
        return 0  # unannotated frame -> default background class
    first = target.annotations[0]
    return target.label_id_to_index[first.label_id]
def train_classifier():
    """Fine-tune a ResNet-50 classifier on a CVAT task.

    Pulls task 456 through TaskVisionDataset, reduces each frame's
    annotations to one class index via extract_label, and runs a
    standard cross-entropy training loop.
    """
    client = Client(url="cvat.example.com")
    client.login(("username", "password"))

    # Standard ImageNet preprocessing
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = TaskVisionDataset(
        client=client,
        task_id=456,
        transform=preprocess,
        target_transform=extract_label,  # Target -> int class index
        label_name_to_index={"cat": 0, "dog": 1, "bird": 2},
    )
    dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
    )

    # Pretrained backbone with a fresh 3-way classification head
    model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 3)  # 3 classes
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        running_loss = 0.0
        for batch_images, batch_labels in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(batch_images), batch_labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader):.4f}")

if __name__ == "__main__":
    train_classifier()
Limitations
The current PyTorch adapter has some limitations:
- Only tasks with image data are supported (not video)
- Track annotations are not accessible (only shapes, tags)
- Deleted frames are automatically omitted
ProjectVisionDataset
For working with entire projects:
from cvat_sdk.pytorch import ProjectVisionDataset
# Load all tasks in a project
dataset = ProjectVisionDataset(
client=client,
project_id=789
)
print(f"Total samples across all project tasks: {len(dataset)}")
Next Steps