
Overview

The AIModelHandler class manages AI model operations, including initialization, text generation, and summarization. It supports multiple models (Llama-3.1-Nemotron-70B-Instruct for generation and facebook/bart-large for summarization) with automatic GPU detection and optimization.

Class: AIModelHandler

class AIModelHandler:
    """Handles AI model operations including initialization and text generation."""

Attributes

  • models: Dict[str, Any] - Dictionary storing loaded model instances
  • tokenizers: Dict[str, Any] - Dictionary storing tokenizer instances
  • model_configs: Dict[str, Dict] - Configuration for each supported model
  • default_model: str - Default model key (from DEFAULT_MODEL env var, defaults to 'llama')
  • enable_learning: bool - Continuous learning flag (from ENABLE_CONTINUOUS_LEARNING env var)

Methods

__init__()

Initialize the AI model handler with model configurations.
def __init__(self)
Behavior: Sets up the model handler with configurations for the two supported models.

Llama Configuration:
'llama': {
    'name': 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
    'type': 'causal',
    'task': 'text-generation'
}
BART Configuration:
'bart': {
    'name': 'facebook/bart-large',
    'type': 'conditional',
    'task': 'summarization'
}
Environment Variables:
  • DEFAULT_MODEL: Model key to use by default (default: 'llama')
  • ENABLE_CONTINUOUS_LEARNING: Enable learning features (default: 'true')
Location: ai_handler.py:27-47
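The constructor described above can be sketched as follows. The attribute and configuration names come from this page; the exact statement order and any logging inside the real ai_handler.py:27-47 are assumptions.

```python
import os
from typing import Any, Dict


class AIModelHandler:
    """Sketch of the constructor documented above (not the full class)."""

    def __init__(self) -> None:
        # Containers for lazily loaded models and tokenizers.
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}
        # Static configuration for each supported model.
        self.model_configs: Dict[str, Dict] = {
            'llama': {
                'name': 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
                'type': 'causal',
                'task': 'text-generation',
            },
            'bart': {
                'name': 'facebook/bart-large',
                'type': 'conditional',
                'task': 'summarization',
            },
        }
        # Environment-driven settings with the documented defaults.
        self.default_model = os.getenv('DEFAULT_MODEL', 'llama')
        self.enable_learning = (
            os.getenv('ENABLE_CONTINUOUS_LEARNING', 'true').lower() == 'true'
        )
```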

initialize_models()

Initialize AI models asynchronously.
async def initialize_models(self) -> bool
Returns: bool - True if initialization succeeded, False otherwise.
Behavior:
  1. GPU Detection
    • Checks CUDA availability
    • Logs GPU device name and available memory
  2. BART Model Initialization
    • Loads facebook/bart-large tokenizer
    • Loads BART model with automatic device mapping
    • Uses float16 precision on GPU, float32 on CPU
  3. Llama Model Initialization
    • Loads nvidia/Llama-3.1-Nemotron-70B-Instruct-HF tokenizer
    • Loads Llama model with automatic device mapping
    • Uses float16 precision on GPU, float32 on CPU
Model Loading Parameters:
  • device_map='auto': Automatic device placement when GPU is available
  • torch_dtype: float16 for GPU (memory efficient), float32 for CPU
  • local_files_only=False: Downloads models from Hugging Face if not cached
Error Handling:
  • Returns False if any model fails to load
  • Logs detailed error messages for debugging
  • Continues initialization if one model fails (when possible)
Location: ai_handler.py:49-107
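The loading parameters above depend only on whether CUDA is available, which can be captured in a small helper. This is a sketch: the helper name is an assumption, and the real handler passes torch.float16/torch.float32 objects rather than the dtype names used here.

```python
from typing import Optional, Tuple


def model_loading_params(cuda_available: bool) -> Tuple[Optional[str], str]:
    """Return the (device_map, dtype name) pair described above.

    These values would then be forwarded to a from_pretrained call, e.g.
        BartForConditionalGeneration.from_pretrained(
            'facebook/bart-large',
            device_map=device_map,
            torch_dtype=dtype,
            local_files_only=False,
        )
    """
    if cuda_available:
        return 'auto', 'float16'   # automatic placement, half precision
    return None, 'float32'         # CPU fallback, full precision
```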

generate_response()

Generate a response using the specified model.
async def generate_response(
    self,
    text: str,
    model_key: Optional[str] = None,
    max_length: int = 300,
    temperature: float = 0.2,
    top_p: float = 0.4,
    max_attempts: int = 5
) -> str
Parameters:
  • text (str, required): Input text to generate a response from
  • model_key (Optional[str], default None): Key of the model to use. If None, uses default_model ('llama')
  • max_length (int, default 300): Maximum length of the generated text in tokens
  • temperature (float, default 0.2): Sampling temperature. Higher values (e.g., 1.0) make output more creative/random; lower values (e.g., 0.2) make it more focused/deterministic
  • top_p (float, default 0.4): Nucleus sampling parameter. Only tokens with cumulative probability up to top_p are considered
  • max_attempts (int, default 5): Maximum number of attempts to generate a valid response

Returns: str - The generated response text, or an error message if generation fails.
Behavior:
  1. Model Selection
    • Uses specified model_key or falls back to default_model
    • Validates that the model exists
  2. Input Processing
    • Tokenizes input text
    • Truncates to 512 tokens maximum
    • Adds padding for batch processing
    • Moves tensors to model’s device (CPU/GPU)
  3. Generation Loop
    • Attempts generation up to max_attempts times
    • Uses nucleus sampling with specified temperature and top_p
    • Validates response quality (>3 words, different from input)
    • Breaks on first valid response
  4. Response Validation
    • Checks response length (must be >3 words)
    • Ensures response differs from input
    • Returns fallback message if all attempts fail
Generation Parameters:
  • do_sample=True: Enables sampling for diverse outputs
  • num_return_sequences=1: Generates one response
  • Uses model’s pad_token_id and eos_token_id for proper sequence handling
Error Handling:
  • Raises ValueError if model_key is invalid
  • Returns user-friendly error message on exceptions
  • Logs all errors for debugging
Location: ai_handler.py:109-182
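The validation and retry steps above can be sketched as two small helpers. The function names, the fallback message text, and the generate_once stand-in for the actual model call are assumptions; only the validity rules (more than three words, different from the input) and the retry bound come from this page.

```python
def is_valid_response(response: str, prompt: str) -> bool:
    """Response-quality check from the generation loop: the reply must be
    longer than three words and must not merely echo the input."""
    return len(response.split()) > 3 and response.strip() != prompt.strip()


async def generate_with_retries(generate_once, prompt: str,
                                max_attempts: int = 5) -> str:
    """Retry loop sketch: call a generation coroutine up to max_attempts
    times and return the first valid response, else a fallback message."""
    for _ in range(max_attempts):
        response = await generate_once(prompt)
        if is_valid_response(response, prompt):
            return response  # break on first valid response
    return "Sorry, I couldn't generate a useful response."
```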

summarize_text()

Summarize text using BART model.
async def summarize_text(
    self,
    text: str,
    max_length: int = 130,
    min_length: int = 30
) -> str
Parameters:
  • text (str, required): Text to summarize
  • max_length (int, default 130): Maximum length of the summary in tokens
  • min_length (int, default 30): Minimum length of the summary in tokens

Returns: str - The summarized text, or an error message if summarization fails.
Behavior:
  1. Input Processing
    • Tokenizes input text using BART tokenizer
    • Truncates to 1024 tokens maximum
    • Adds padding for batch processing
    • Moves tensors to BART model’s device
  2. Summary Generation
    • Uses beam search with 4 beams for better quality
    • Applies length penalty of 2.0 to encourage conciseness
    • Early stopping when all beams reach EOS token
    • Decodes output tokens to text
  3. Post-processing
    • Removes special tokens from output
    • Strips whitespace from final summary
Summarization Parameters:
  • num_beams=4: Beam search width for quality
  • length_penalty=2.0: Encourages shorter summaries
  • early_stopping=True: Stops when EOS is reached
Error Handling:
  • Returns user-friendly error message on exceptions
  • Logs all errors for debugging
Location: ai_handler.py:184-227
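The beam-search settings above map directly onto keyword arguments for the model's generate() call. A minimal sketch, assuming the helper name and dict shape (the real code at ai_handler.py:184-227 likely passes these inline):

```python
def summarization_kwargs(max_length: int = 130, min_length: int = 30) -> dict:
    """Generation settings for BART summarization as documented above,
    in the form they would be passed to model.generate(**kwargs)."""
    return {
        'max_length': max_length,
        'min_length': min_length,
        'num_beams': 4,          # beam search width for quality
        'length_penalty': 2.0,   # encourages shorter summaries
        'early_stopping': True,  # stop when all beams reach EOS
    }
```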

get_available_models()

Get list of available model keys.
def get_available_models(self) -> List[str]
Returns: List[str] - List of model keys: ['llama', 'bart']
Behavior:
  • Returns the keys from model_configs dictionary
  • Useful for validating model selection
Location: ai_handler.py:229-231

get_model_info()

Get information about a specific model.
def get_model_info(self, model_key: str) -> Optional[Dict]
Parameters:
  • model_key (str, required): Key of the model to get information about ('llama' or 'bart')

Returns: Optional[Dict] - Dictionary containing the model configuration (name, type, task), or None if model_key is invalid.
Example Return Value:
{
    'name': 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
    'type': 'causal',
    'task': 'text-generation'
}
Location: ai_handler.py:233-235
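The two introspection methods amount to simple lookups into model_configs. A standalone mimic (for illustration only; the real methods read self.model_configs on the handler instance):

```python
# Same configuration shape as documented for AIModelHandler.model_configs.
MODEL_CONFIGS = {
    'llama': {'name': 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
              'type': 'causal', 'task': 'text-generation'},
    'bart': {'name': 'facebook/bart-large',
             'type': 'conditional', 'task': 'summarization'},
}


def get_available_models() -> list:
    """List the supported model keys, e.g. for validating a selection."""
    return list(MODEL_CONFIGS.keys())


def get_model_info(model_key: str):
    """Return a model's configuration dict, or None for an unknown key."""
    return MODEL_CONFIGS.get(model_key)
```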

Singleton Instance

The module exports a singleton instance:
ai_handler = AIModelHandler()
This ensures only one instance manages all models, preventing duplicate model loading.

Dependencies

import logging
import os
from typing import Dict, Optional, List, Any
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    pipeline
)

Environment Variables

  • DEFAULT_MODEL (optional): Model key to use by default (default: 'llama')
  • ENABLE_CONTINUOUS_LEARNING (optional): Enable continuous learning features (default: 'true')

GPU Support

The handler automatically detects and uses GPU when available:
  • Uses CUDA if torch.cuda.is_available() returns True
  • Automatically maps models to available devices
  • Uses float16 precision on GPU for memory efficiency
  • Falls back to CPU with float32 precision when GPU is unavailable
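The detection step above can be sketched as a small function. The helper name is an assumption, and the import guard (so the check degrades gracefully when torch is not installed) is an addition not described on this page; the documented logging of device name and memory is omitted.

```python
import importlib.util


def detect_device() -> str:
    """Return 'cuda' when torch reports an available GPU, else 'cpu'."""
    if importlib.util.find_spec('torch') is not None:
        import torch
        if torch.cuda.is_available():
            return 'cuda'
    return 'cpu'
```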
