The training API is experimental and subject to change.
Overview
Cog’s training API allows you to define a fine-tuning interface for models, enabling users to bring their own training data to create derivative fine-tuned models.
Configuration
Specify the training entry point in cog.yaml:
build:
python_version: "3.10"
train: "train.py:train"
Function-Based Training
The simplest way to define training is with a function:
from cog import Input, Path
import io
def train(
    train_data: Path = Input(description="HTTPS URL of training data"),
    learning_rate: float = Input(default=1e-4, ge=0),
    seed: int = Input(default=42)
) -> Path:
    """Fine-tune the model on the supplied data and return the weights file."""
    # fine_tune_model is the user's own training routine; its result is the
    # path of the produced weights, wrapped in cog.Path for upload.
    return Path(fine_tune_model(train_data, learning_rate, seed))
Run training:
cog train -i [email protected]
Class-Based Training
For multiple training runs with shared setup, use a class:
build:
python_version: "3.10"
train: "train.py:Trainer"
from cog import Input, Path
import torch
class Trainer:
    def setup(self) -> None:
        """One-time initialization: load the base model onto the best device."""
        self.base_model = torch.load("base_weights.pth")
        # Prefer GPU when available; fall back to CPU otherwise.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_model.to(self.device)

    def train(
        self,
        train_data: Path = Input(description="Training data archive"),
        epochs: int = Input(default=10, ge=1, le=100),
        learning_rate: float = Input(default=1e-4, ge=0)
    ) -> Path:
        """Fine-tune a copy of the base model and return the saved weights."""
        dataset = load_dataset(train_data)
        # Clone so repeated training runs all start from the pristine base model.
        tuned = self.base_model.clone()
        optimizer = torch.optim.Adam(tuned.parameters(), lr=learning_rate)
        for epoch in range(epochs):
            for batch in dataset:
                loss = train_step(tuned, batch, optimizer)
                print(f"Epoch {epoch}, Loss: {loss}")
        out_path = Path("/tmp/fine_tuned.pth")
        torch.save(tuned.state_dict(), out_path)
        return out_path
Use Input() to define training parameters with the same API as predictions:
from cog import Input, Path
def train(
    # Required parameter (no default)
    train_data: Path = Input(description="Training dataset"),
    # Optional with default
    learning_rate: float = Input(
        description="Learning rate for training",
        default=1e-4,
        ge=0.0,
        le=1.0
    ),
    # Integer with constraints
    epochs: int = Input(
        description="Number of training epochs",
        default=10,
        ge=1,
        le=100
    ),
    # String with choices
    optimizer: str = Input(
        description="Optimizer to use",
        default="adam",
        choices=["adam", "sgd", "adamw"]
    ),
    # Optional parameter
    seed: int = Input(
        description="Random seed",
        default=None
    )
) -> Path:
    """Demonstrates the full range of Input() options for training parameters."""
    ...
description: Human-readable description of the parameter
default: Default value. If not provided, the parameter is required. If explicitly None, the parameter is optional.
ge: Minimum value (greater than or equal) for numeric types
le: Maximum value (less than or equal) for numeric types
regex: Regular expression pattern for string validation
Training Output
Simple Return
Return a weights file directly:
from cog import Path
def train(train_data: Path) -> Path:
    """Fine-tune on the given data and return the weights file path."""
    return Path(fine_tune(train_data))
Structured Output
Return multiple outputs using a TrainingOutput object:
from cog import BaseModel, Input, Path
class TrainingOutput(BaseModel):
    # Fine-tuned weights file returned to the caller
    weights: Path
    # Arbitrary training metrics (e.g. loss, accuracy)
    metrics: dict
def train(
    train_data: Path = Input(description="Training data"),
    epochs: int = Input(default=10, ge=1)
) -> TrainingOutput:
    """Fine-tune and report both the weights file and summary metrics."""
    weights, metrics = fine_tune(train_data, epochs)
    # Collect just the metrics the consumer cares about.
    summary = {"final_loss": metrics["loss"], "accuracy": metrics["acc"]}
    return TrainingOutput(weights=Path(weights), metrics=summary)
The output class must be named TrainingOutput, not any other name.
Training functions support the same input types as predictions:
str - String
int - Integer
float - Floating point number
bool - Boolean
Path - File path (local or URL)
Secret - Sensitive string
list[T] - List of supported types
See Input/Output Types for details.
Supported Output Types
Training functions can return:
Path - Weights file (most common)
str, int, float, bool - Primitive types
TrainingOutput - Custom object with multiple fields
Testing Fine-Tuned Models
Test that your model works with fine-tuned weights using the COG_WEIGHTS environment variable:
cog predict -e COG_WEIGHTS=https://example.com/weights.tar -i prompt="a photo of TOK"
This simulates loading custom weights without running the full training process.
Implementing Weights Loading
Modify your predictor’s setup() to accept weights:
from cog import BasePredictor, Path
from typing import Optional, Union
import torch
class Predictor(BasePredictor):
    def setup(self, weights: Optional[Union[Path, str]] = None) -> None:
        """Load fine-tuned weights when provided, otherwise the base checkpoint."""
        # A falsy `weights` (None or empty) means no fine-tune: use base weights.
        source = weights if weights else "base_weights.pth"
        self.model = torch.load(source)
        self.model.eval()

    def predict(self, prompt: str) -> str:
        """Generate text from the loaded model."""
        return self.model.generate(prompt)
Example: Fine-Tuning a Text Model
from cog import BaseModel, Input, Path
import torch
import json
class TrainingOutput(BaseModel):
    # Fine-tuned weights file
    weights: Path
    # Loss from the last training batch
    final_loss: float
    # Total number of optimizer steps taken
    steps: int
class Trainer:
    def setup(self) -> None:
        """Load the base model and tokenizer once per container."""
        self.base_model = torch.load("base_model.pth")
        self.tokenizer = load_tokenizer()

    def train(
        self,
        train_data: Path = Input(description="JSONL file with training examples"),
        learning_rate: float = Input(default=5e-5, ge=0, le=1e-3),
        num_epochs: int = Input(default=3, ge=1, le=10),
        batch_size: int = Input(default=8, ge=1, le=32),
        max_steps: int = Input(default=None),
    ) -> TrainingOutput:
        """Fine-tune on JSONL examples, returning weights plus loss/step stats."""
        # One JSON object per line of the training file.
        with open(train_data) as f:
            examples = [json.loads(line) for line in f]

        # Train a copy so the base model stays untouched across runs.
        model = self.base_model.clone()
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        step = 0
        final_loss = 0.0
        stop = False  # set when max_steps is reached, to exit both loops
        for epoch in range(num_epochs):
            if stop:
                break
            for start in range(0, len(examples), batch_size):
                batch = examples[start:start + batch_size]
                loss = train_batch(model, batch, optimizer, self.tokenizer)
                final_loss = loss
                step += 1
                print(f"Step {step}, Loss: {loss:.4f}")
                if max_steps and step >= max_steps:
                    stop = True
                    break

        weights_path = Path("/tmp/fine_tuned.pth")
        torch.save(model.state_dict(), weights_path)
        return TrainingOutput(
            weights=weights_path,
            final_loss=final_loss,
            steps=step
        )
Run the training:
cog train \
-i [email protected] \
-i learning_rate=0.00005 \
-i num_epochs=5 \
-i batch_size=16
Real-World Examples