All parameters on this page are passed as keyword arguments to `model.train()`. They map to fields on `TrainConfig` in `rfdetr.config`.
Basic example
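A minimal sketch of a training call, assuming the `RFDETRMedium` model class from the `rfdetr` package (any model variant works the same way); the dataset path and values are illustrative:

```python
from rfdetr import RFDETRMedium  # import path assumed; see Model Variants

model = RFDETRMedium()
model.train(
    dataset_dir="path/to/dataset",  # COCO or YOLO format, auto-detected
    output_dir="output",            # checkpoints and logs land here
    epochs=100,
    batch_size=4,
    grad_accum_steps=4,             # effective batch size: 4 x 4 = 16
)
```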
Core training
`dataset_dir`: Path to your dataset directory. RF-DETR auto-detects whether it is in COCO or YOLO format. See Dataset Formats.
`output_dir`: Directory where training artifacts (checkpoints, logs) are saved.
`epochs`: Number of full passes over the training dataset.
`resume`: Path to a saved checkpoint to continue training. Restores model weights, optimizer state, and scheduler state.
`seed`: Global random seed for reproducibility. `None` means no fixed seed is set.
Batch and memory
`batch_size`: Number of samples processed per iteration per GPU. Higher values require more GPU memory. Pass `"auto"` to let RF-DETR probe for the largest safe batch size.
`grad_accum_steps`: Accumulate gradients over this many mini-batches before an optimizer step. Use with `batch_size` to achieve a larger effective batch size without increasing memory.
Note: `gradient_checkpointing` is a model constructor parameter, not a training parameter. Pass it when instantiating the model: `RFDETRMedium(gradient_checkpointing=True)`. See Model Variants for all constructor options.
Understanding batch size
The effective batch size is `batch_size × grad_accum_steps` per device. Suggested settings by GPU:
| GPU | VRAM | batch_size | grad_accum_steps |
|---|---|---|---|
| A100 | 40–80 GB | 16 | 1 |
| RTX 4090 | 24 GB | 8 | 2 |
| RTX 3090 | 24 GB | 8 | 2 |
| T4 | 16 GB | 4 | 4 |
| RTX 3070 | 8 GB | 2 | 8 |
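Every row in the table above targets the same effective batch size; a quick sanity check of the relationship (plain arithmetic, not RF-DETR code):

```python
# Per-device effective batch size is the product of the two settings.
def effective_batch_size(batch_size: int, grad_accum_steps: int) -> int:
    return batch_size * grad_accum_steps

# All GPU rows above yield the same effective batch size of 16:
assert effective_batch_size(16, 1) == 16  # A100
assert effective_batch_size(8, 2) == 16   # RTX 4090 / 3090
assert effective_batch_size(4, 4) == 16   # T4
assert effective_batch_size(2, 8) == 16   # RTX 3070
```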
Learning rate
`lr`: Learning rate for most parts of the model (excluding the backbone encoder).
`lr_encoder`: Learning rate specifically for the backbone encoder. Set lower than `lr` to fine-tune the encoder more conservatively.
`lr_scheduler`: Learning rate scheduler type. Options: `"step"` (step decay at `lr_drop`) or `"cosine"` (cosine annealing).
`lr_min_factor`: Floor for the cosine scheduler, expressed as a fraction of the initial LR. Ignored when using `"step"`.
`warmup_epochs`: Number of epochs for linear learning rate warmup at the start of training.
Resolution
`resolution`: Input image resolution. Higher values can improve accuracy but require more memory. Must be divisible by 14. Defaults to the model-specific value.
| Resolution | Memory usage | Use case |
|---|---|---|
| 560 | Low | Small objects, limited GPU memory |
| 672 | Medium | Balanced (default for many models) |
| 784 | High | High accuracy requirements |
| 896 | Very high | Maximum quality (requires large GPU) |
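The divisible-by-14 constraint is easy to check before launching a run; all resolutions in the table above satisfy it:

```python
def is_valid_resolution(resolution: int) -> bool:
    # RF-DETR requires the input resolution to be a multiple of 14.
    return resolution % 14 == 0

assert all(is_valid_resolution(r) for r in (560, 672, 784, 896))  # table values
assert not is_valid_resolution(640)  # a common image size, but not a multiple of 14
```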
Regularization
`weight_decay`: L2 regularization coefficient. Helps prevent overfitting by penalizing large weights.
`drop_path`: Stochastic depth drop-path rate applied to the backbone. Higher values add more regularization.
EMA (exponential moving average)
`use_ema`: Enables Exponential Moving Average of weights. Produces a smoothed checkpoint that often improves final performance and generalization.
EMA maintains a moving average of the model weights throughout training. This smoothed version often generalizes better than the raw weights and is the default for `checkpoint_best_total.pth`.
Checkpoints
`checkpoint_interval`: Frequency (in epochs) at which periodic model checkpoints are saved. More frequent saves provide better coverage but consume more storage.
| File | Description |
|---|---|
| `checkpoint.pth` | Most recent checkpoint (for resuming) |
| `checkpoint_<N>.pth` | Periodic checkpoint at epoch N |
| `checkpoint_best_ema.pth` | Best validation performance (EMA weights) |
| `checkpoint_best_regular.pth` | Best validation performance (raw weights) |
| `checkpoint_best_total.pth` | Final best model for inference |
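To pick up an interrupted run, point `resume` at the most recent checkpoint. A sketch assuming the `RFDETRMedium` class from the `rfdetr` package (paths illustrative):

```python
from rfdetr import RFDETRMedium  # import path assumed

model = RFDETRMedium()
model.train(
    dataset_dir="path/to/dataset",
    output_dir="output",
    resume="output/checkpoint.pth",  # restores weights, optimizer, and scheduler state
)
```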
Early stopping
`early_stopping`: Enable early stopping based on validation mAP.
`early_stopping_patience`: Number of epochs without improvement before stopping training.
`early_stopping_min_delta`: Minimum change in mAP to qualify as an improvement.
`early_stopping_use_ema`: Whether to track improvements using EMA model metrics.
For example, with `early_stopping_patience=15` and `early_stopping_min_delta=0.005`, training stops once mAP fails to improve by at least 0.005 for 15 consecutive epochs.
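The stopping rule can be sketched in a few lines; this illustrates the logic described above, not RF-DETR's internal implementation:

```python
def should_stop(map_history, patience=10, min_delta=0.001):
    """Return True once mAP has not improved by at least min_delta for `patience` epochs."""
    best = float("-inf")
    epochs_without_improvement = 0
    for m in map_history:
        if m > best + min_delta:
            best = m
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return True
    return False

assert should_stop([0.50, 0.50, 0.50, 0.50], patience=3)      # plateau -> stop
assert not should_stop([0.50, 0.52, 0.54, 0.56], patience=3)  # improving -> continue
```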
Logging and evaluation
`tensorboard`: Enable TensorBoard logging. Requires `pip install "rfdetr[loggers]"`. If the package is not installed, training continues with a UserWarning and TensorBoard output is skipped.
`wandb`: Enable Weights & Biases logging. Requires `pip install "rfdetr[loggers]"`.
`mlflow`: Enable MLflow logging. Requires `pip install "rfdetr[loggers]"`.
`project`: Project name for W&B or MLflow logging.
`run`: Run name for W&B or MLflow logging. If not specified, an auto-generated name is used.
`eval_max_dets`: Maximum number of detections per image considered during COCO evaluation. Lower values speed up evaluation.
`eval_interval`: Run COCO evaluation every N epochs. Set to a higher value to reduce evaluation overhead during long training runs.
`log_per_class_metrics`: Log per-class AP metrics to the console and loggers. Disable to reduce log verbosity when there are many classes.
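Enabling trackers is a matter of flags on the training call. A sketch (the `RFDETRMedium` import path is assumed, and each backend requires `pip install "rfdetr[loggers]"`):

```python
from rfdetr import RFDETRMedium  # import path assumed

model = RFDETRMedium()
model.train(
    dataset_dir="path/to/dataset",
    tensorboard=True,
    wandb=True,
    project="my-detection-project",  # shared by W&B and MLflow
    run="baseline-672",              # optional; auto-generated if omitted
    eval_interval=5,                 # evaluate every 5 epochs to cut overhead
)
```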
`progress_bar`: Enable a progress bar during training. Accepts `"tqdm"`, `"rich"`, or `null` to disable. Also accepts legacy boolean values (`true` maps to `"tqdm"`).
Data loading
`num_workers`: Number of DataLoader worker processes for parallel data loading.
`pin_memory`: Pin host memory in the DataLoader for faster GPU transfers. `None` defers to PyTorch Lightning's default.
`persistent_workers`: Keep DataLoader worker processes alive between epochs. `None` defers to PyTorch Lightning's default.
`prefetch_factor`: Number of batches to prefetch per DataLoader worker. `None` uses PyTorch's built-in default.
Hardware and runtime
`accelerator`: PyTorch Lightning accelerator selection. `"auto"` picks GPU if available, then MPS, then CPU.
`fp16_eval`: Run evaluation passes in FP16 precision. Reduces memory usage but may lower numerical precision.
`compute_val_loss`: Compute and log the detection loss on the validation set each epoch.
`compute_test_loss`: Compute and log the detection loss during the final test run.
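Data-loading and runtime options combine the same way. A sketch (values are illustrative, not recommendations):

```python
from rfdetr import RFDETRMedium  # import path assumed

model = RFDETRMedium()
model.train(
    dataset_dir="path/to/dataset",
    num_workers=4,            # parallel data loading
    pin_memory=True,          # faster host-to-GPU transfers
    persistent_workers=True,  # keep workers alive between epochs
    accelerator="auto",       # GPU if available, then MPS, then CPU
    fp16_eval=True,           # cheaper eval passes at reduced precision
)
```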
Auto-batch configuration
These parameters control the automatic batch size detection when `batch_size="auto"`:
| Parameter | Default | Description |
|---|---|---|
| `auto_batch_target_effective` | 16 | Per-device effective batch size target before scaling by devices × num_nodes. |
| `auto_batch_max_targets_per_image` | 100 | Worst-case number of annotations per image used when probing for a safe batch size. |
| `auto_batch_ema_headroom` | 0.7 | Scale the safe batch size by this factor when `use_ema=True`, since EMA uses extra memory. Must be in (0, 1]. |
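A sketch of how these knobs interact, based on the descriptions above (the helper functions are hypothetical, and the exact rounding RF-DETR applies is an assumption):

```python
def scaled_target(target_effective=16, devices=1, num_nodes=1):
    # Per-device target scaled by the total device count across nodes.
    return target_effective * devices * num_nodes

def apply_ema_headroom(safe_batch, headroom=0.7, use_ema=True):
    # Shrink the probed safe batch size when EMA is enabled (rounded down here).
    return max(1, int(safe_batch * headroom)) if use_ema else safe_batch

assert scaled_target(16, devices=2) == 32      # two GPUs on one node
assert apply_ema_headroom(10) == 7             # 10 * 0.7, rounded down
assert apply_ema_headroom(10, use_ema=False) == 10
```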
Deprecated fields
The following fields exist on `TrainConfig` but are deprecated and will be removed in v1.9. Set them on `ModelConfig` instead:
- `group_detr`: query group count is an architecture decision; set on `ModelConfig`.
- `ia_bce_loss`: loss type is tied to the architecture family; set on `ModelConfig`.
- `segmentation_head`: architecture flag; set on `ModelConfig`.
- `num_select`: postprocessor count is an architecture decision; set on `ModelConfig`.
Complete reference table
| Parameter | Type | Default | Description |
|---|---|---|---|
| `dataset_dir` | str | Required | Path to COCO or YOLO formatted dataset. |
| `output_dir` | str | "output" | Directory for checkpoints, logs, and artifacts. |
| `epochs` | int | 100 | Number of full passes over the dataset. |
| `batch_size` | int \| "auto" | 4 | Samples per iteration. Balance with grad_accum_steps. |
| `grad_accum_steps` | int | 4 | Gradient accumulation steps for effective larger batch sizes. |
| `lr` | float | 1e-4 | Learning rate for the model (excluding encoder). |
| `lr_encoder` | float | 1.5e-4 | Learning rate for the backbone encoder. |
| `resolution` | int | Model-specific | Input image size (must be divisible by 14). |
| `weight_decay` | float | 1e-4 | L2 regularization coefficient. |
| `use_ema` | bool | True | Enable Exponential Moving Average of weights. |
| `gradient_checkpointing` | bool | False | Model constructor param; pass to `RFDETRMedium(gradient_checkpointing=True)`. |
| `checkpoint_interval` | int | 10 | Save checkpoint every N epochs. |
| `resume` | str | None | Path to checkpoint for resuming training. |
| `tensorboard` | bool | True | Enable TensorBoard logging. |
| `wandb` | bool | False | Enable Weights & Biases logging. |
| `mlflow` | bool | False | Enable MLflow logging. |
| `project` | str | None | W&B or MLflow project name. |
| `run` | str | None | W&B or MLflow run name. |
| `early_stopping` | bool | False | Enable early stopping. |
| `early_stopping_patience` | int | 10 | Epochs without improvement before stopping. |
| `early_stopping_min_delta` | float | 0.001 | Minimum mAP change to qualify as improvement. |
| `early_stopping_use_ema` | bool | False | Use EMA model for early stopping metrics. |
| `eval_max_dets` | int | 500 | Maximum detections per image for COCO evaluation. |
| `eval_interval` | int | 1 | Run COCO evaluation every N epochs. |
| `log_per_class_metrics` | bool | True | Log per-class AP metrics. |
| `progress_bar` | str \| None | None | Progress bar style: "tqdm", "rich", or None. |
| `num_workers` | int | 2 | DataLoader worker processes. |
| `accelerator` | str | "auto" | PyTorch Lightning accelerator. |
| `seed` | int | None | Random seed for reproducibility. |
| `lr_scheduler` | str | "step" | LR scheduler type: "step" or "cosine". |
| `lr_min_factor` | float | 0.0 | Cosine scheduler LR floor as a fraction of initial LR. |
| `warmup_epochs` | float | 0.0 | Linear warmup epochs at start of training. |
| `drop_path` | float | 0.0 | Stochastic depth drop-path rate for the backbone. |
| `compute_val_loss` | bool | True | Compute and log loss during validation. |
| `compute_test_loss` | bool | True | Compute and log loss during the test run. |
| `fp16_eval` | bool | False | Run evaluation in FP16 precision. |
| `pin_memory` | bool | None | Pin DataLoader memory. |
| `persistent_workers` | bool | None | Keep DataLoader workers alive between epochs. |
| `prefetch_factor` | int | None | Batches prefetched per worker. |
| `aug_config` | dict | None | Custom augmentation config. See Custom Augmentations. |