Overview
The BasePredictor class is the foundation for building Cog models. Subclass it to define how your model loads and runs predictions.
from cog import BasePredictor, Input, Path
import torch
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.model = torch.load("weights.pth")
def predict(
self,
image: Path = Input(description="Image to enlarge"),
scale: float = Input(description="Factor to scale image by", default=1.5)
) -> Path:
"""Run a single prediction on the model"""
output = self.model(image)
return output
Class Reference
BasePredictor
class BasePredictor:
def setup(self, weights: Optional[Union[Path, str]] = None) -> None:
...
def predict(self, **kwargs: Any) -> Any:
...
def record_metric(self, key: str, value: Any, mode: str = "replace") -> None:
...
@property
def scope(self) -> Any:
...
Base class for Cog predictors. Override the setup and predict methods to define your model’s behavior.
Methods
setup()
def setup(self, weights: Optional[Union[Path, str]] = None) -> None:
Prepare the model for predictions. This method is called once when the predictor is initialized.
weights
Optional[Union[Path, str]]
Optional path to model weights. Can be a local path or URL. Typically provided via the COG_WEIGHTS environment variable during training.
This method is optional. Use it for expensive one-off operations like:
- Loading trained models
- Downloading weights (e.g., using pget)
- Instantiating data transformations
- Allocating GPU memory
Example:
import torch
from cog import BasePredictor
class Predictor(BasePredictor):
def setup(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = torch.load("weights.pth", map_location=self.device)
self.model.eval()
Weights Storage Strategy
You have two options for managing model weights:
- Download in setup(): Smaller images, faster builds, but slower startup
- Include in image: Larger images, slower builds, but faster startup and better reproducibility
When including weights in the image, use cog build --separate-weights to store them in a separate layer for better caching.
predict()
def predict(self, **kwargs: Any) -> Any:
Run a single prediction. This method is required.
**kwargs
Prediction inputs. Each parameter must have a type annotation and can optionally use Input() for metadata.
Returns
The prediction output. Can be a primitive type (str, int, float, bool), a Path object, a list, or a custom Output object.
Example with type annotations:
from cog import BasePredictor, Input, Path
class Predictor(BasePredictor):
def predict(
self,
prompt: str = Input(description="Input prompt"),
temperature: float = Input(default=0.7, ge=0.0, le=2.0),
max_tokens: int = Input(default=100, ge=1)
) -> str:
return self.model.generate(prompt, temperature, max_tokens)
Example returning multiple outputs:
from cog import BaseModel, BasePredictor, Path
class Output(BaseModel):
image: Path
seed: int
prompt: str
class Predictor(BasePredictor):
def predict(self, prompt: str) -> Output:
seed = generate_seed()
image = self.model.generate(prompt, seed)
return Output(image=image, seed=seed, prompt=prompt)
record_metric()
def record_metric(self, key: str, value: Any, mode: str = "replace") -> None:
Record a prediction metric. Metrics are included in the prediction response.
key
str
required
Metric name. Use dot-separated keys (e.g., "timing.inference") to create nested objects.
value
bool | int | float | str | list | dict
required
Metric value. Setting to None deletes the metric.
mode
str
default: "replace"
Accumulation mode:
"replace" - Overwrite any previous value (default)
"incr" - Add to existing numeric value
"append" - Append to an array
Example:
class Predictor(BasePredictor):
def predict(self, prompt: str) -> str:
self.record_metric("temperature", 0.7)
self.record_metric("token_count", 1, mode="incr")
self.record_metric("timing.preprocess", 0.12)
result = self.model.generate(prompt)
return result
See the Metrics API documentation for more details.
scope
@property
def scope(self) -> Any:
The current prediction scope. Provides access to the full scope API for advanced metric operations.
Example:
class Predictor(BasePredictor):
def predict(self, prompt: str) -> str:
# Direct dict-style access
self.scope.metrics["token_count"] = 42
# Delete a metric
del self.scope.metrics["token_count"]
# Record with mode
self.scope.record_metric("steps", "preprocessing", mode="append")
return self.model.generate(prompt)
Outside an active prediction, self.scope returns a no-op object that silently ignores all operations.
Async Predictors
You can define async predictors using async def for both setup() and predict():
from cog import BasePredictor
class Predictor(BasePredictor):
async def setup(self) -> None:
self.model = await load_model_async()
async def predict(self, prompt: str) -> str:
result = await self.model.generate_async(prompt)
return result
Models with async predict() can run predictions concurrently, up to the limit specified by concurrency.max in cog.yaml.
Return Types
Simple Types
def predict(self, x: int) -> str:
return f"Result: {x}"
def predict(self, x: int) -> int:
return x * 2
def predict(self, x: int) -> float:
return x * 1.5
def predict(self, x: int) -> bool:
return x > 0
Path Objects
import tempfile
from cog import BasePredictor, Path
class Predictor(BasePredictor):
def predict(self, image: Path) -> Path:
output = process_image(image)
# Create temporary file (automatically deleted after return)
output_path = Path(tempfile.mkdtemp()) / "output.png"
output.save(output_path)
return output_path
Lists
from cog import BasePredictor, Path
class Predictor(BasePredictor):
def predict(self) -> list[Path]:
predictions = ["foo", "bar", "baz"]
output = []
for i, prediction in enumerate(predictions):
out_path = Path(f"/tmp/out-{i}.txt")
out_path.write_text(prediction)
output.append(out_path)
return output
Custom Output Objects
from cog import BaseModel, BasePredictor, Path
from typing import Optional
class Output(BaseModel):
image: Path
seed: int
nsfw_detected: Optional[bool]
class Predictor(BasePredictor):
def predict(self, prompt: str) -> Output:
seed = generate_seed()
image = self.model.generate(prompt, seed)
return Output(image=image, seed=seed, nsfw_detected=None)
The output class must be named Output, not any other name.
Streaming Output
Stream output as it’s generated using iterators:
Sync Streaming
from cog import BasePredictor, ConcatenateIterator
class Predictor(BasePredictor):
def predict(self, prompt: str) -> ConcatenateIterator[str]:
for token in self.model.generate_tokens(prompt):
yield token
Async Streaming
from cog import AsyncConcatenateIterator, BasePredictor
class Predictor(BasePredictor):
async def predict(self, prompt: str) -> AsyncConcatenateIterator[str]:
async for token in self.model.generate_tokens_async(prompt):
yield token
ConcatenateIterator hints that output should be concatenated into a single string for display (useful on Replicate). For non-text streaming, use Iterator[T] or AsyncIterator[T].
Cancellation
When a prediction is canceled, Cog raises an exception to interrupt the predict() method:
| Predictor type | Exception raised |
|---|---|
| Sync (`def predict`) | `CancelationException` |
| Async (`async def predict`) | `asyncio.CancelledError` |
from cog import BasePredictor, CancelationException, Path
class Predictor(BasePredictor):
def predict(self, image: Path) -> Path:
try:
return self.process(image)
except CancelationException:
self.cleanup()
raise # Always re-raise!
You must re-raise CancelationException after cleanup. Swallowing it will prevent the runtime from marking the prediction as canceled and may result in container termination.
CancelationException is a BaseException subclass (not Exception), so bare except Exception blocks won’t catch it. This matches the behavior of KeyboardInterrupt and asyncio.CancelledError.