
Overview

The BaseModel class provides a common interface for all model implementations in the UC Intel Final platform. It defines the core methods that every model must implement and provides utility methods for model inspection.

Class Definition

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
import torch.nn as nn

class BaseModel(ABC):
    """Abstract base class for model implementations"""
Location: app/models/base.py:11

Constructor

__init__(config: Dict[str, Any])

Initialize model with configuration dictionary.
config (Dict[str, Any], required)
Model configuration dictionary containing model-specific parameters. At minimum, it should include:
  • num_classes: Number of output classes
  • model_type: Type of model (e.g., "CNN", "Transfer", "Transformer")
  • architecture: Architecture name or description
Example:
config = {
    "num_classes": 10,
    "model_type": "CNN",
    "architecture": "Custom"
}
model = CustomModel(config)
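The constructor body itself is not shown on this page; based on the attributes documented below (config stored as-is, model left as None until build() is called), it presumably looks like the following minimal sketch. This is a torch-free stand-in, not the actual source in app/models/base.py:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

class BaseModel(ABC):
    """Sketch of the abstract base class (stand-in for illustration)."""

    def __init__(self, config: Dict[str, Any]) -> None:
        # Store the raw configuration; subclasses read it in build()
        self.config = config
        # The built model stays None until build() is called
        self.model: Optional[object] = None

    @abstractmethod
    def build(self):
        """Build and return the underlying model."""

    @abstractmethod
    def get_parameters_count(self):
        """Return (total_params, trainable_params)."""
```

A subclass only needs to call super().__init__(config) to get both attributes initialized.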

Abstract Methods

These methods must be implemented by all subclasses.
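Because BaseModel derives from ABC and decorates these methods with @abstractmethod, Python enforces the contract at instantiation time: a subclass that skips either method cannot even be constructed. A minimal illustration, using a stand-in skeleton rather than the real class:

```python
from abc import ABC, abstractmethod

class BaseModel(ABC):
    @abstractmethod
    def build(self): ...

    @abstractmethod
    def get_parameters_count(self): ...

class Incomplete(BaseModel):
    # Implements build() but forgets get_parameters_count()
    def build(self):
        return None

try:
    Incomplete()
except TypeError as err:
    # Python refuses to instantiate: get_parameters_count is still abstract
    print(err)
```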

build() -> nn.Module

Build and return the PyTorch model.
Returns:
  model (nn.Module): PyTorch neural network module ready for training or inference
Implementation example:
def build(self) -> nn.Module:
    if not self.validate_config():
        raise ValueError("Invalid model configuration")
    
    model = nn.Sequential(
        nn.Conv2d(3, 64, 3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64 * 222 * 222, self.config["num_classes"])
    )
    
    self.model = model
    return model
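The 64 * 222 * 222 flatten size in this example is not arbitrary: with a 3×3 kernel, stride 1, and no padding, Conv2d shrinks each spatial dimension by 2, so the example implicitly assumes 224×224 inputs. The arithmetic can be checked directly:

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Standard Conv2d output-size formula
    return (size + 2 * padding - kernel) // stride + 1

side = conv2d_out(224, kernel=3)   # 224 -> 222
flat_features = 64 * side * side   # channels * H * W after Flatten
print(side, flat_features)         # 222 3154176
```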

get_parameters_count() -> Tuple[int, int]

Get total and trainable parameter counts.
Returns:
  total_params (int): Total number of parameters in the model
  trainable_params (int): Number of trainable parameters (where requires_grad=True)
Implementation example:
def get_parameters_count(self) -> Tuple[int, int]:
    if self.model is None:
        self.model = self.build()
    
    total_params = sum(p.numel() for p in self.model.parameters())
    trainable_params = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )
    
    return total_params, trainable_params

Instance Methods

get_model_summary() -> Dict[str, Any]

Get comprehensive model summary statistics.
Returns:
  summary (Dict[str, Any]): Dictionary containing:
    • total_parameters: Total parameter count
    • trainable_parameters: Trainable parameter count
    • model_type: Model type from config
    • architecture: Architecture name from config
    • num_classes: Number of output classes
Example:
model = CustomModel(config)
summary = model.get_model_summary()

print(f"Total parameters: {summary['total_parameters']:,}")
print(f"Trainable parameters: {summary['trainable_parameters']:,}")
print(f"Model type: {summary['model_type']}")
Output:
{
    "total_parameters": 25557032,
    "trainable_parameters": 25557032,
    "model_type": "CNN",
    "architecture": "Custom",
    "num_classes": 10
}
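get_model_summary() is a concrete method on the base class. Given the fields above, it presumably combines get_parameters_count() with config lookups, along the lines of this sketch (the parameter counts are hard-coded stand-ins here so the combination logic is visible without torch; this is not the platform's actual source):

```python
from typing import Any, Dict, Tuple

class SummaryExample:
    """Stand-in with fixed parameter counts, to show the combination logic."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config

    def get_parameters_count(self) -> Tuple[int, int]:
        # Stand-in values; the real method counts self.model's parameters
        return 25_557_032, 25_557_032

    def get_model_summary(self) -> Dict[str, Any]:
        # Merge parameter counts with config-derived metadata
        total, trainable = self.get_parameters_count()
        return {
            "total_parameters": total,
            "trainable_parameters": trainable,
            "model_type": self.config.get("model_type"),
            "architecture": self.config.get("architecture"),
            "num_classes": self.config.get("num_classes"),
        }
```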

validate_config() -> bool

Validate model configuration before building.
Returns:
  is_valid (bool): True if the configuration is valid, False otherwise
Validation criteria:
  • num_classes must be present in config
  • num_classes must be greater than 0
Example:
config = {"num_classes": 10}
model = CustomModel(config)

if model.validate_config():
    built_model = model.build()
else:
    raise ValueError("Invalid configuration")
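Put together, the two criteria suggest a base implementation along these lines. This is a sketch matching the documented behavior, written as a free function over the config for brevity (the real method reads self.config); the isinstance guard is an extra assumption, not stated above:

```python
from typing import Any, Dict

def validate_config(config: Dict[str, Any]) -> bool:
    # Criterion 1: num_classes must be present in config
    if "num_classes" not in config:
        return False
    # Criterion 2: num_classes must be greater than 0
    num_classes = config["num_classes"]
    return isinstance(num_classes, int) and num_classes > 0
```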

Attributes

config (Dict[str, Any]): Model configuration dictionary passed during initialization
model (nn.Module | None): Built PyTorch model instance; None until build() is called

Implementation Guide

When creating a new model class, inherit from BaseModel and implement the required abstract methods:
from models.base import BaseModel
import torch.nn as nn

class MyCustomModel(BaseModel):
    def __init__(self, config: dict):
        super().__init__(config)
        # Add custom initialization
    
    def build(self) -> nn.Module:
        # Implement model building logic
        if not self.validate_config():
            raise ValueError("Invalid configuration")
        
        # Build your model
        model = nn.Sequential(...)
        self.model = model
        return model
    
    def get_parameters_count(self) -> tuple[int, int]:
        if self.model is None:
            self.model = self.build()
        
        total = sum(p.numel() for p in self.model.parameters())
        trainable = sum(
            p.numel() for p in self.model.parameters() 
            if p.requires_grad
        )
        return total, trainable
    
    def validate_config(self) -> bool:
        # Add custom validation on top of base validation
        if not super().validate_config():
            return False
        
        # Add your custom checks
        return True
