The training API is experimental and subject to change.

Overview

Cog’s training API allows you to define a fine-tuning interface for models, enabling users to bring their own training data to create derivative fine-tuned models.

Configuration

Specify the training entry point in cog.yaml:
cog.yaml
build:
  python_version: "3.10"
train: "train.py:train"

Function-Based Training

The simplest way to define training is with a function:
train.py
from cog import Input, Path

def train(
    train_data: Path = Input(description="HTTPS URL of training data"),
    learning_rate: float = Input(default=1e-4, ge=0),
    seed: int = Input(default=42)
) -> Path:
    # Train your model
    weights = fine_tune_model(train_data, learning_rate, seed)
    
    # Return weights file
    return Path(weights)
Run training:
cog train -i [email protected] -i learning_rate=0.0001

Class-Based Training

For multiple training runs with shared setup, use a class:
cog.yaml
build:
  python_version: "3.10"
train: "train.py:Trainer"
train.py
from cog import Input, Path
import copy
import torch

class Trainer:
    def setup(self) -> None:
        """Load base model (called once)"""
        self.base_model = torch.load("base_weights.pth")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_model.to(self.device)

    def train(
        self,
        train_data: Path = Input(description="Training data archive"),
        epochs: int = Input(default=10, ge=1, le=100),
        learning_rate: float = Input(default=1e-4, ge=0)
    ) -> Path:
        """Fine-tune the base model"""
        # Extract training data
        dataset = load_dataset(train_data)
        
        # Fine-tune a copy so the base model in memory stays untouched
        model = copy.deepcopy(self.base_model)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        
        for epoch in range(epochs):
            for batch in dataset:
                loss = train_step(model, batch, optimizer)
                print(f"Epoch {epoch}, Loss: {loss}")
        
        # Save weights
        weights_path = Path("/tmp/fine_tuned.pth")
        torch.save(model.state_dict(), weights_path)
        return weights_path

Input Parameters

Use Input() to define training parameters with the same API as predictions:
from cog import Input, Path

def train(
    # Required parameter (no default)
    train_data: Path = Input(description="Training dataset"),
    
    # Optional with default
    learning_rate: float = Input(
        description="Learning rate for training",
        default=1e-4,
        ge=0.0,
        le=1.0
    ),
    
    # Integer with constraints
    epochs: int = Input(
        description="Number of training epochs",
        default=10,
        ge=1,
        le=100
    ),
    
    # String with choices
    optimizer: str = Input(
        description="Optimizer to use",
        default="adam",
        choices=["adam", "sgd", "adamw"]
    ),
    
    # Optional parameter
    seed: int = Input(
        description="Random seed",
        default=None
    )
) -> Path:
    ...

Input() Parameters

  • description (str): Human-readable description of the parameter
  • default (Any): Default value. If not provided, the parameter is required. If explicitly None, the parameter is optional.
  • ge (int | float): Minimum value (greater than or equal) for numeric types
  • le (int | float): Maximum value (less than or equal) for numeric types
  • min_length (int): Minimum string length
  • max_length (int): Maximum string length
  • regex (str): Regular expression pattern for string validation
  • choices (list[str | int]): List of allowed values
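Constraints are enforced before train() runs, so out-of-range values are rejected up front. Conceptually the checks behave like this plain-Python sketch (an illustration only, not Cog's actual implementation):

```python
def validate(value, ge=None, le=None, choices=None):
    """Reject values that fall outside the declared constraints."""
    if ge is not None and value < ge:
        raise ValueError(f"{value} is less than the minimum {ge}")
    if le is not None and value > le:
        raise ValueError(f"{value} is greater than the maximum {le}")
    if choices is not None and value not in choices:
        raise ValueError(f"{value} is not one of {choices}")
    return value

validate(10, ge=1, le=100)                        # ok
validate("adam", choices=["adam", "sgd", "adamw"])  # ok
```

Because validation happens before your code runs, a bad `cog train -i epochs=500` fails immediately rather than partway through a long training job.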

Training Output

Simple Return

Return a weights file directly:
from cog import Path

def train(train_data: Path) -> Path:
    weights = fine_tune(train_data)
    return Path(weights)

Structured Output

Return multiple outputs using a TrainingOutput object:
from cog import BaseModel, Input, Path

class TrainingOutput(BaseModel):
    weights: Path
    metrics: dict
    
def train(
    train_data: Path = Input(description="Training data"),
    epochs: int = Input(default=10, ge=1)
) -> TrainingOutput:
    weights, metrics = fine_tune(train_data, epochs)
    
    return TrainingOutput(
        weights=Path(weights),
        metrics={"final_loss": metrics["loss"], "accuracy": metrics["acc"]}
    )
The output class must be named TrainingOutput.

Supported Input Types

Training functions support the same input types as predictions:
  • str - String
  • int - Integer
  • float - Floating point number
  • bool - Boolean
  • Path - File path (local or URL)
  • Secret - Sensitive string
  • list[T] - List of supported types
See Input/Output Types for details.

Supported Output Types

Training functions can return:
  • Path - Weights file (most common)
  • str, int, float, bool - Primitive types
  • Custom TrainingOutput object with multiple fields

Testing Fine-Tuned Models

Test that your model works with fine-tuned weights using the COG_WEIGHTS environment variable:
cog predict -e COG_WEIGHTS=https://example.com/weights.tar -i prompt="a photo of TOK"
This simulates loading custom weights without running the full training process.
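Cog resolves COG_WEIGHTS and hands the location to your predictor's setup(), as shown in the next section. If you want to experiment outside of Cog, the fallback behavior can be mimicked with a plain environment-variable lookup (`resolve_weights` is a hypothetical helper for illustration, not part of Cog):

```python
import os

def resolve_weights(default_path: str = "base_weights.pth") -> str:
    """Return the weights location: COG_WEIGHTS if set, else the base weights."""
    return os.environ.get("COG_WEIGHTS") or default_path
```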

Implementing Weights Loading

Modify your predictor’s setup() to accept weights:
from cog import BasePredictor, Path
from typing import Optional, Union
import torch

class Predictor(BasePredictor):
    def setup(self, weights: Optional[Union[Path, str]] = None) -> None:
        if weights:
            # Load fine-tuned weights
            self.model = torch.load(weights)
        else:
            # Load base weights
            self.model = torch.load("base_weights.pth")
        
        self.model.eval()

    def predict(self, prompt: str) -> str:
        return self.model.generate(prompt)

Example: Fine-Tuning a Text Model

train.py
from cog import BaseModel, Input, Path
import copy
import torch
import json

class TrainingOutput(BaseModel):
    weights: Path
    final_loss: float
    steps: int

class Trainer:
    def setup(self) -> None:
        """Load the base model"""
        self.base_model = torch.load("base_model.pth")
        self.tokenizer = load_tokenizer()
    
    def train(
        self,
        train_data: Path = Input(description="JSONL file with training examples"),
        learning_rate: float = Input(default=5e-5, ge=0, le=1e-3),
        num_epochs: int = Input(default=3, ge=1, le=10),
        batch_size: int = Input(default=8, ge=1, le=32),
        max_steps: int = Input(default=None),
    ) -> TrainingOutput:
        # Load training data
        with open(train_data) as f:
            examples = [json.loads(line) for line in f]
        
        # Prepare a copy of the base model for fine-tuning
        model = copy.deepcopy(self.base_model)
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        
        # Training loop
        step = 0
        final_loss = 0.0
        
        for epoch in range(num_epochs):
            for i in range(0, len(examples), batch_size):
                batch = examples[i:i+batch_size]
                loss = train_batch(model, batch, optimizer, self.tokenizer)
                final_loss = loss
                step += 1
                
                print(f"Step {step}, Loss: {loss:.4f}")
                
                if max_steps and step >= max_steps:
                    break
            
            if max_steps and step >= max_steps:
                break
        
        # Save fine-tuned weights
        weights_path = Path("/tmp/fine_tuned.pth")
        torch.save(model.state_dict(), weights_path)
        
        return TrainingOutput(
            weights=weights_path,
            final_loss=final_loss,
            steps=step
        )
Run the training:
cog train \
  -i [email protected] \
  -i learning_rate=0.00005 \
  -i num_epochs=5 \
  -i batch_size=16
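The trainer above expects train_data to be a JSONL file with one JSON object per line. The fields are whatever your train_batch consumes; the prompt/completion names below are illustrative:

```python
import json

# A tiny JSONL training file: one JSON object per line
examples = [
    {"prompt": "a photo of a dog", "completion": "a photo of TOK the dog"},
    {"prompt": "a photo of a cat", "completion": "a photo of TOK the cat"},
]
with open("train.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")

# Read it back the same way the trainer does
with open("train.jsonl") as f:
    loaded = [json.loads(line) for line in f]
```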
