
HardwareConfig

Hardware-specific training configuration. Supports local (RTX 3060) and high-end GPU (A100/H100) environments with auto-detection and preset configurations.

Parameters

device
str
default:"auto"
Device specifier: "auto", "cuda", "cuda:0", or "cpu". When set to "auto", CUDA is selected automatically if available, otherwise CPU.
num_gpus
int
default:"1"
Number of GPUs for distributed training. Must be >= 1.
gpu_memory_gb
int
default:"12"
GPU memory in GB, used for auto-tuning batch sizes. Must be >= 1.
mixed_precision
Literal['bf16', 'fp16', 'fp32']
default:"bf16"
Mixed precision dtype for automatic mixed precision (AMP).
gradient_checkpointing
bool
default:"True"
Trades compute for memory: activations are recomputed during the backward pass instead of being stored.
is_distributed
bool
default:"False"
Whether running multi-GPU training via torchrun. Automatically set by from_env().
world_size
int
default:"1"
Total number of processes in distributed training. Set by the WORLD_SIZE environment variable.
local_rank
int
default:"0"
Local process rank, set by torchrun via LOCAL_RANK environment variable.
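The gpu_memory_gb field feeds batch-size auto-tuning. The actual tuning logic is not part of this reference; the following is a minimal sketch of the kind of heuristic that field enables (the suggest_batch_size helper and its constants are hypothetical, not the library's implementation):

```python
def suggest_batch_size(gpu_memory_gb: int, per_sample_gb: float = 0.7) -> int:
    """Hypothetical heuristic: reserve ~20% of VRAM for the model and
    optimizer state, then fit as many samples as the remainder allows."""
    usable_gb = gpu_memory_gb * 0.8
    return max(1, int(usable_gb // per_sample_gb))

# 12 GB (RTX 3060) vs. 80 GB (A100/H100)
print(suggest_batch_size(12))  # → 13
print(suggest_batch_size(80))  # → 91
```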

Methods

from_env

Create a HardwareConfig from environment variables set by torchrun or SLURM.
@classmethod
def from_env(cls) -> HardwareConfig
Automatically detects:
  • LOCAL_RANK and WORLD_SIZE from environment
  • Number of available GPUs
  • GPU memory size
  • Whether running distributed training
Example:
from modern_llm.config import HardwareConfig

# Auto-detect hardware from environment
config = HardwareConfig.from_env()
print(f"Device: {config.device}")
print(f"GPUs: {config.num_gpus}")
print(f"Memory: {config.gpu_memory_gb}GB")

get_torch_device

Return the torch.device for model/tensor placement.
def get_torch_device(self) -> torch.device
Example:
config = HardwareConfig(device="cuda:0")
device = config.get_torch_device()
model = model.to(device)

Example

from modern_llm.config import HardwareConfig

# Local RTX 3060 config
local_config = HardwareConfig(
    device="cuda",
    num_gpus=1,
    gpu_memory_gb=12,
    mixed_precision="bf16",
    gradient_checkpointing=True,
)

# A100 config
a100_config = HardwareConfig(
    device="cuda",
    num_gpus=1,
    gpu_memory_gb=80,
    mixed_precision="bf16",
    gradient_checkpointing=True,
)

# H100 config (no gradient checkpointing needed)
h100_config = HardwareConfig(
    device="cuda",
    num_gpus=1,
    gpu_memory_gb=80,
    mixed_precision="bf16",
    gradient_checkpointing=False,
)

# Auto-detect from environment
auto_config = HardwareConfig.from_env()

Validation rules

  • num_gpus must be >= 1
  • gpu_memory_gb must be >= 1
  • mixed_precision must be "bf16", "fp16", or "fp32"

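The rules above can be mirrored in a standalone check. This is illustrative only: the exception type and message format raised by HardwareConfig itself are assumptions here, not the library's code.

```python
def validate_hardware(num_gpus: int, gpu_memory_gb: int, mixed_precision: str) -> None:
    """Mirror of the documented HardwareConfig validation rules (illustrative)."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    if gpu_memory_gb < 1:
        raise ValueError("gpu_memory_gb must be >= 1")
    if mixed_precision not in ("bf16", "fp16", "fp32"):
        raise ValueError("mixed_precision must be 'bf16', 'fp16', or 'fp32'")

validate_hardware(1, 12, "bf16")      # passes silently
try:
    validate_hardware(0, 12, "bf16")  # num_gpus < 1
except ValueError as err:
    print(err)                        # num_gpus must be >= 1
```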
DataConfig

Data loading and corpus configuration for training.

Parameters

datasets
list[str]
default:"['wikitext-2-raw-v1']"
List of dataset names or paths to mix during training.
tokens_target
int
default:"50000000"
Target number of tokens for pretraining (50M default). Must be >= 1000.
max_epochs
int
default:"10"
Maximum number of epochs to iterate over the data. Must be >= 1.
shuffle_buffer
int
default:"10000"
Number of examples to buffer for shuffling during data loading.
num_workers
int
default:"4"
Number of DataLoader worker processes for parallel data loading.
prefetch_factor
int
default:"2"
Number of batches to prefetch per worker for better GPU utilization.
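How tokens_target and max_epochs interact can be shown with a small calculation: training stops at whichever limit is hit first. The 2M-tokens-per-epoch corpus size below is a made-up figure for illustration.

```python
import math

def epochs_needed(tokens_target: int, corpus_tokens: int, max_epochs: int) -> int:
    """Epochs required to reach tokens_target, capped at max_epochs."""
    return min(max_epochs, math.ceil(tokens_target / corpus_tokens))

# With an illustrative ~2M-token corpus, the 50M default would need
# 25 passes, but the default max_epochs=10 caps the run first.
print(epochs_needed(50_000_000, 2_000_000, max_epochs=10))  # → 10
print(epochs_needed(10_000_000, 2_000_000, max_epochs=10))  # → 5
```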

Example

from modern_llm.config import DataConfig

# Small dataset for testing
small_config = DataConfig(
    datasets=["wikitext-2-raw-v1"],
    tokens_target=10_000_000,
    max_epochs=3,
)

# Large dataset for production
large_config = DataConfig(
    datasets=[
        "wikitext-103-raw-v1",
        "openwebtext",
        "wikipedia",
    ],
    tokens_target=1_000_000_000,
    max_epochs=1,
    shuffle_buffer=50_000,
    num_workers=8,
)

Validation rules

  • datasets list cannot be empty
  • tokens_target must be >= 1000
  • max_epochs must be >= 1

Hardware presets

LOCAL_RTX3060

Preset for local RTX 3060 (12GB VRAM).
from modern_llm.config import LOCAL_RTX3060

config = LOCAL_RTX3060
Configuration:
  • device: "cuda"
  • num_gpus: 1
  • gpu_memory_gb: 12
  • mixed_precision: "bf16"
  • gradient_checkpointing: True

GPU_A100

Preset for A100 (80GB VRAM).
from modern_llm.config import GPU_A100

config = GPU_A100
Configuration:
  • device: "cuda"
  • num_gpus: 1
  • gpu_memory_gb: 80
  • mixed_precision: "bf16"
  • gradient_checkpointing: True

GPU_H100

Preset for H100 (80GB VRAM).
from modern_llm.config import GPU_H100

config = GPU_H100
Configuration:
  • device: "cuda"
  • num_gpus: 1
  • gpu_memory_gb: 80
  • mixed_precision: "bf16"
  • gradient_checkpointing: False (H100 has enough memory)

get_hardware_preset

Get a hardware preset by name.
def get_hardware_preset(name: str) -> HardwareConfig
Parameters:
  • name: One of "local", "rtx3060", "a100", "h100", or "auto"
Example:
from modern_llm.config import get_hardware_preset

# Get preset by name
config = get_hardware_preset("a100")

# Auto-detect from environment
auto_config = get_hardware_preset("auto")
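Internally, a name-to-preset lookup like this is all that is needed. This sketch is hypothetical: the real get_hardware_preset may differ in its aliases, error handling, and how "auto" delegates to HardwareConfig.from_env().

```python
# Hypothetical sketch of the name -> preset dispatch.
PRESETS = {
    # name: (device, gpu_memory_gb, gradient_checkpointing)
    "local": ("cuda", 12, True),
    "rtx3060": ("cuda", 12, True),
    "a100": ("cuda", 80, True),
    "h100": ("cuda", 80, False),
}

def lookup_preset(name: str):
    if name.lower() == "auto":
        return "detect-from-env"  # placeholder for HardwareConfig.from_env()
    try:
        return PRESETS[name.lower()]
    except KeyError:
        raise ValueError(f"unknown preset: {name!r}") from None

print(lookup_preset("a100"))  # → ('cuda', 80, True)
```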

Data presets

get_data_preset

Get a data scale preset by name.
def get_data_preset(name: str) -> DataConfig
Parameters:
  • name: One of "small", "medium", "large", or "xl"

Small preset

small = get_data_preset("small")
Configuration:
  • datasets: ["wikitext-2-raw-v1"]
  • tokens_target: 10M
  • max_epochs: 3

Medium preset

medium = get_data_preset("medium")
Configuration:
  • datasets: ["wikitext-2-raw-v1", "roneneldan/TinyStories"]
  • tokens_target: 100M
  • max_epochs: 5

Large preset

large = get_data_preset("large")
Configuration:
  • datasets: ["wikitext-2-raw-v1", "roneneldan/TinyStories", "openwebtext"]
  • tokens_target: 1B
  • max_epochs: 1

XL preset

xl = get_data_preset("xl")
Configuration:
  • datasets: ["wikitext-2-raw-v1", "roneneldan/TinyStories", "openwebtext", "bookcorpus"]
  • tokens_target: 5B
  • max_epochs: 1

Complete example

from modern_llm.config import (
    HardwareConfig,
    DataConfig,
    get_hardware_preset,
    get_data_preset,
)

# Manual configuration
hardware = HardwareConfig(
    device="cuda",
    num_gpus=1,
    gpu_memory_gb=80,
    mixed_precision="bf16",
    gradient_checkpointing=False,
)

data = DataConfig(
    datasets=["wikitext-103-raw-v1", "openwebtext"],
    tokens_target=500_000_000,
    max_epochs=2,
    num_workers=8,
)

# Using presets
hardware = get_hardware_preset("a100")
data = get_data_preset("large")

# Auto-detect from environment
hardware = get_hardware_preset("auto")

# Use in training
device = hardware.get_torch_device()
model = model.to(device)

if hardware.is_distributed:
    import torch.distributed as dist
    dist.init_process_group(backend="nccl")
    model = torch.nn.parallel.DistributedDataParallel(model)
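The mixed_precision setting also drives a simple memory fact worth keeping in mind: bf16 and fp16 store each value in 2 bytes versus 4 for fp32. A back-of-the-envelope for parameter memory (the 1B-parameter model size is illustrative, and this counts parameters only, not gradients, optimizer state, or activations):

```python
BYTES_PER_VALUE = {"bf16": 2, "fp16": 2, "fp32": 4}

def param_memory_gb(num_params: int, precision: str) -> float:
    """GiB needed to hold the parameters alone at the given precision."""
    return num_params * BYTES_PER_VALUE[precision] / 1024**3

# A 1B-parameter model (illustrative size):
print(round(param_memory_gb(1_000_000_000, "fp32"), 2))  # → 3.73
print(round(param_memory_gb(1_000_000_000, "bf16"), 2))  # → 1.86
```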
