
Overview

The BasePredictor class is the foundation for building Cog models. Subclass it to define how your model loads and runs predictions.
from cog import BasePredictor, Input, Path
import torch

class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.model = torch.load("weights.pth")

    def predict(
        self,
        image: Path = Input(description="Image to enlarge"),
        scale: float = Input(description="Factor to scale image by", default=1.5)
    ) -> Path:
        """Run a single prediction on the model"""
        output = self.model(image)
        return output

Class Reference

BasePredictor

class BasePredictor:
    def setup(self, weights: Optional[Union[Path, str]] = None) -> None:
        ...

    def predict(self, **kwargs: Any) -> Any:
        ...

    def record_metric(self, key: str, value: Any, mode: str = "replace") -> None:
        ...

    @property
    def scope(self) -> Any:
        ...
Base class for Cog predictors. Override the setup and predict methods to define your model’s behavior.

Methods

setup()

def setup(self, weights: Optional[Union[Path, str]] = None) -> None:
Prepare the model for predictions. This method is called once when the predictor is initialized.
weights (Optional[Union[Path, str]], optional): Path to model weights. Can be a local path or URL. Typically provided via the COG_WEIGHTS environment variable during training.
This method is optional. Use it for expensive one-off operations like:
  • Loading trained models
  • Downloading weights (e.g., using pget)
  • Instantiating data transformations
  • Allocating GPU memory
Example:
import torch
from cog import BasePredictor

class Predictor(BasePredictor):
    def setup(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = torch.load("weights.pth", map_location=self.device)
        self.model.eval()

Weights Storage Strategy

You have two options for managing model weights:
  1. Download in setup(): Smaller images, faster builds, but slower startup
  2. Include in image: Larger images, slower builds, but faster startup and better reproducibility
When including weights in the image, use cog build --separate-weights to store them in a separate layer for better caching.
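Option 1 might look like the sketch below. The URL, filename, and helper are illustrative only, not part of Cog's API; on Replicate, pget is commonly used instead of urllib for faster parallel downloads.

```python
import urllib.request
from pathlib import Path

# Hypothetical URL and destination, for illustration only.
WEIGHTS_URL = "https://example.com/weights.pth"
WEIGHTS_PATH = Path("weights.pth")

def ensure_weights(url: str = WEIGHTS_URL, dest: Path = WEIGHTS_PATH) -> Path:
    """Download weights on first startup; subsequent setups reuse the cached file."""
    if not dest.exists():
        urllib.request.urlretrieve(url, dest)
    return dest
```

Calling `ensure_weights()` at the top of `setup()` keeps the image small while avoiding repeated downloads on warm restarts.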

predict()

def predict(self, **kwargs: Any) -> Any:
Run a single prediction. This method is required.
**kwargs (Any): Prediction inputs. Each parameter must have a type annotation and can optionally use Input() for metadata.

Returns (Any): The prediction output. Can be a primitive type (str, int, float, bool), a Path object, a list, or a custom Output object.
Example with type annotations:
from cog import BasePredictor, Input, Path

class Predictor(BasePredictor):
    def predict(
        self,
        prompt: str = Input(description="Input prompt"),
        temperature: float = Input(default=0.7, ge=0.0, le=2.0),
        max_tokens: int = Input(default=100, ge=1)
    ) -> str:
        return self.model.generate(prompt, temperature, max_tokens)
Example returning multiple outputs:
from cog import BaseModel, BasePredictor, Path

class Output(BaseModel):
    image: Path
    seed: int
    prompt: str

class Predictor(BasePredictor):
    def predict(self, prompt: str) -> Output:
        seed = generate_seed()
        image = self.model.generate(prompt, seed)
        return Output(image=image, seed=seed, prompt=prompt)

record_metric()

def record_metric(self, key: str, value: Any, mode: str = "replace") -> None:
Record a prediction metric. Metrics are included in the prediction response.
key (str, required): Metric name. Use dot-separated keys (e.g., "timing.inference") to create nested objects.

value (bool | int | float | str | list | dict, required): Metric value. Setting to None deletes the metric.

mode (str, default "replace"): Accumulation mode:
  • "replace" - Overwrite any previous value (default)
  • "incr" - Add to existing numeric value
  • "append" - Append to an array
Example:
class Predictor(BasePredictor):
    def predict(self, prompt: str) -> str:
        self.record_metric("temperature", 0.7)
        self.record_metric("token_count", 1, mode="incr")
        self.record_metric("timing.preprocess", 0.12)
        
        result = self.model.generate(prompt)
        return result
See the Metrics API documentation for more details.
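Dot-separated keys expand into nested objects in the prediction response. The helper below is only an illustration of that expansion, not part of the cog API.

```python
def nest_metrics(flat: dict) -> dict:
    """Expand dot-separated keys into nested dicts,
    e.g. {"timing.preprocess": 0.12} -> {"timing": {"preprocess": 0.12}}."""
    nested: dict = {}
    for key, value in flat.items():
        parts = key.split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested
```

So the three record_metric calls in the example above would surface as a top-level "temperature", a top-level "token_count", and a "timing" object containing "preprocess".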

scope

@property
def scope(self) -> Any:
The current prediction scope. Provides access to the full scope API for advanced metric operations.

Example:
class Predictor(BasePredictor):
    def predict(self, prompt: str) -> str:
        # Direct dict-style access
        self.scope.metrics["token_count"] = 42
        
        # Delete a metric
        del self.scope.metrics["token_count"]
        
        # Record with mode
        self.scope.record_metric("steps", "preprocessing", mode="append")
        
        return self.model.generate(prompt)
Outside an active prediction, self.scope returns a no-op object that silently ignores all operations.

Async Predictors

You can define async predictors using async def for both setup() and predict():
from cog import BasePredictor

class Predictor(BasePredictor):
    async def setup(self) -> None:
        self.model = await load_model_async()

    async def predict(self, prompt: str) -> str:
        result = await self.model.generate_async(prompt)
        return result
Models with async predict() can run predictions concurrently, up to the limit specified by concurrency.max in cog.yaml.
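A hedged sketch of the relevant cog.yaml fragment follows; the value 32 is arbitrary, so check the Cog configuration reference for the limits appropriate to your model.

```yaml
# cog.yaml (fragment) - enable concurrent predictions for an async predictor
concurrency:
  max: 32
```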

Return Types

Simple Types

def predict(self, x: int) -> str:
    return f"Result: {x}"

def predict(self, x: int) -> int:
    return x * 2

def predict(self, x: int) -> float:
    return x * 1.5

def predict(self, x: int) -> bool:
    return x > 0

Path Objects

import tempfile
from cog import BasePredictor, Path

class Predictor(BasePredictor):
    def predict(self, image: Path) -> Path:
        output = process_image(image)
        
        # Save the result to a file in a fresh temporary directory
        output_path = Path(tempfile.mkdtemp()) / "output.png"
        output.save(output_path)
        return output_path

Lists

from cog import BasePredictor, Path

class Predictor(BasePredictor):
    def predict(self) -> list[Path]:
        predictions = ["foo", "bar", "baz"]
        output = []
        for i, prediction in enumerate(predictions):
            out_path = Path(f"/tmp/out-{i}.txt")
            out_path.write_text(prediction)
            output.append(out_path)
        return output

Custom Output Objects

from cog import BaseModel, BasePredictor, Path
from typing import Optional

class Output(BaseModel):
    image: Path
    seed: int
    nsfw_detected: Optional[bool]

class Predictor(BasePredictor):
    def predict(self, prompt: str) -> Output:
        seed = generate_seed()
        image = self.model.generate(prompt, seed)
        return Output(image=image, seed=seed, nsfw_detected=None)
The output class must be named Output; other names are not recognized.

Streaming Output

Stream output as it’s generated using iterators:

Sync Streaming

from cog import BasePredictor, ConcatenateIterator

class Predictor(BasePredictor):
    def predict(self, prompt: str) -> ConcatenateIterator[str]:
        for token in self.model.generate_tokens(prompt):
            yield token

Async Streaming

from cog import AsyncConcatenateIterator, BasePredictor

class Predictor(BasePredictor):
    async def predict(self, prompt: str) -> AsyncConcatenateIterator[str]:
        async for token in self.model.generate_tokens_async(prompt):
            yield token
ConcatenateIterator hints that output should be concatenated into a single string for display (useful on Replicate). For non-text streaming, use Iterator[T] or AsyncIterator[T].

Cancellation

When a prediction is canceled, Cog raises an exception to interrupt the predict() method:
Predictor type              Exception raised
Sync (def predict)          CancelationException
Async (async def predict)   asyncio.CancelledError
from cog import BasePredictor, CancelationException, Path

class Predictor(BasePredictor):
    def predict(self, image: Path) -> Path:
        try:
            return self.process(image)
        except CancelationException:
            self.cleanup()
            raise  # Always re-raise!
You must re-raise CancelationException after cleanup. Swallowing it will prevent the runtime from marking the prediction as canceled and may result in container termination.
CancelationException is a BaseException subclass (not Exception), so bare except Exception blocks won’t catch it. This matches the behavior of KeyboardInterrupt and asyncio.CancelledError.
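To see why a bare except Exception block cannot swallow cancellation, here is a minimal sketch using a stand-in class with the same BaseException parentage (cog itself is not imported here).

```python
class CancelationException(BaseException):
    """Stand-in mirroring cog's class: subclasses BaseException, not Exception."""

def safe_call() -> str:
    try:
        raise CancelationException("canceled")
    except Exception:
        # Never reached: CancelationException is not an Exception subclass.
        return "swallowed"
    except CancelationException:
        return "propagates past `except Exception`"
```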
