Skip to main content

Overview

LiteLLM allows you to add custom LLM providers by implementing the CustomLLM base class. This enables you to integrate proprietary models, internal APIs, or specialized LLM services that aren’t natively supported.

When to Use Custom Providers

  • Integrate internal/proprietary LLM APIs
  • Add support for new LLM providers before native support
  • Wrap existing APIs with custom logic (rate limiting, caching, etc.)
  • Implement specialized model endpoints

Quick Start

Step 1

Create Custom Provider Class

Inherit from CustomLLM and implement required methods:
from typing import Any, Callable, Optional, Union

import httpx

from litellm.llms.custom_llm import CustomLLM, CustomLLMError
from litellm.utils import ModelResponse

class MyCustomLLM(CustomLLM):
    """Minimal custom provider: routes completion calls to an internal API."""

    def __init__(self):
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,  # was a mutable default ({}) — shared across calls
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Any] = None,
    ) -> ModelResponse:
        """Handle a synchronous completion request.

        Returns:
            The supplied ``model_response`` with its first choice's message
            content populated from the custom API's reply.
        """
        # Implement your completion logic
        response_text = self._call_your_api(
            model=model,
            messages=messages,
            api_key=api_key,
            **optional_params
        )

        # Format as ModelResponse
        model_response.choices[0].message.content = response_text
        return model_response

    def _call_your_api(self, model, messages, api_key, **kwargs):
        # Your API call logic here (placeholder response for the example)
        return "Response from custom API"
Step 2

Register Your Provider

Register the custom provider with LiteLLM:
import litellm

# Register your custom provider
litellm.custom_provider_map = [
    {
        "provider": "my-custom-llm",       # prefix used in model="my-custom-llm/<model>"
        "custom_handler": MyCustomLLM()    # instance that receives the routed calls
    }
]
Step 3

Use Your Provider

Call your custom provider like any other:
response = litellm.completion(
    model="my-custom-llm/my-model",  # "<provider>/<model>" routes to the registered handler
    messages=[{"role": "user", "content": "Hello!"}],
    api_key="your-api-key",
    api_base="https://your-api.com/v1"
)

print(response.choices[0].message.content)

Implementation Guide

Required Methods

def completion(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers: Optional[dict] = None,  # was a mutable default ({}) — shared across calls
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[any] = None,
) -> ModelResponse:
    """
    Handle synchronous completion requests.

    Args:
        model: Model name (without the provider prefix).
        messages: OpenAI-format chat messages.
        api_base: Base URL of the backing API.
        model_response: Pre-built response object to populate and return.
        optional_params: Provider-specific parameters forwarded to the API.

    Returns:
        ModelResponse with populated choices.
    """
    # Transform messages to your API format
    api_messages = self._transform_messages(messages)

    # Call your API
    response = self._make_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": api_messages,
            **optional_params
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout
    )

    # Populate model_response
    model_response.choices[0].message.content = response["text"]
    model_response.model = model

    # Set usage if available; .get() tolerates a partial usage payload
    usage = response.get("usage")
    if usage:
        model_response.usage.prompt_tokens = usage.get("prompt_tokens", 0)
        model_response.usage.completion_tokens = usage.get("completion_tokens", 0)
        model_response.usage.total_tokens = usage.get("total_tokens", 0)

    return model_response
async def acompletion(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers: Optional[dict] = None,  # was a mutable default ({}) — shared across calls
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[any] = None,
) -> ModelResponse:
    """
    Handle asynchronous completion requests.

    Same contract as ``completion``, but awaits an async HTTP client so
    the event loop is not blocked while waiting on the backing API.
    """
    # Use async HTTP client
    response = await self._make_async_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": messages,
            **optional_params
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout
    )

    model_response.choices[0].message.content = response["text"]
    return model_response
from litellm.types.utils import GenericStreamingChunk
from typing import Iterator

def streaming(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers: Optional[dict] = None,  # was a mutable default ({}) — shared across calls
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[any] = None,
) -> Iterator[GenericStreamingChunk]:
    """
    Handle synchronous streaming requests.

    Yields:
        One GenericStreamingChunk per non-empty text chunk received
        from the backing API.
    """
    stream = self._make_streaming_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": messages,
            "stream": True,
            **optional_params
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout
    )

    for chunk in stream:
        # Skip keep-alive / empty chunks; only forward real text
        if chunk.get("text"):
            yield GenericStreamingChunk(
                text=chunk["text"],
                is_finished=chunk.get("is_finished", False),
                finish_reason=chunk.get("finish_reason")
            )
from typing import AsyncIterator

async def astreaming(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    acompletion=None,
    litellm_params=None,
    logger_fn=None,
    headers: Optional[dict] = None,  # was a mutable default ({}) — shared across calls
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    client: Optional[any] = None,
) -> AsyncIterator[GenericStreamingChunk]:
    """
    Handle asynchronous streaming requests.

    Yields:
        One GenericStreamingChunk per non-empty text chunk received
        from the backing API, without blocking the event loop.
    """
    stream = self._make_async_streaming_request(
        endpoint=f"{api_base}/completions",
        payload={
            "model": model,
            "messages": messages,
            "stream": True,
            **optional_params
        },
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout
    )

    async for chunk in stream:
        # Skip keep-alive / empty chunks; only forward real text
        if chunk.get("text"):
            yield GenericStreamingChunk(
                text=chunk["text"],
                is_finished=chunk.get("is_finished", False),
                finish_reason=chunk.get("finish_reason")
            )

Optional Methods

def embedding(
    self,
    model: str,
    input: list,
    model_response: EmbeddingResponse,
    print_verbose: callable,
    logging_obj: any,
    optional_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    litellm_params=None,
) -> EmbeddingResponse:
    """
    Handle embedding requests.

    Returns:
        The supplied ``model_response`` with ``data`` set to the
        embeddings returned by the backing API.
    """
    response = self._make_request(
        endpoint=f"{api_base}/embeddings",
        payload={
            "model": model,
            "input": input,
            **optional_params
        },
        api_key=api_key,
        timeout=timeout  # was accepted by this method but silently dropped
    )

    model_response.data = response["embeddings"]
    return model_response

Complete Example

from litellm.llms.custom_llm import CustomLLM, CustomLLMError
from litellm.utils import ModelResponse
from litellm.types.utils import GenericStreamingChunk
from typing import Optional, Union, Iterator, AsyncIterator
import httpx
import json

class MyInternalLLM(CustomLLM):
    """Custom LiteLLM provider for a company-internal LLM API."""

    def __init__(self):
        super().__init__()
        self.base_url = "https://internal-llm.company.com/api/v1"

    def _make_request(self, endpoint: str, payload: dict, api_key: str, timeout=None):
        """Helper for HTTP requests.

        Raises:
            CustomLLMError: on any HTTP or transport failure.
        """
        try:
            response = httpx.post(
                endpoint,
                json=payload,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                timeout=timeout or 60.0
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            # Chain the original exception (__cause__) for debuggability
            raise CustomLLMError(
                status_code=e.response.status_code,
                message=f"API request failed: {e.response.text}"
            ) from e
        except Exception as e:
            raise CustomLLMError(
                status_code=500,
                message=f"Request error: {str(e)}"
            ) from e

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        **kwargs
    ) -> ModelResponse:
        """Synchronous completion: call the internal API and populate model_response."""
        print_verbose(f"MyInternalLLM: completion call for {model}")

        # Make API request
        response = self._make_request(
            endpoint=f"{api_base or self.base_url}/chat/completions",
            payload={
                "model": model,
                "messages": messages,
                "temperature": optional_params.get("temperature", 0.7),
                "max_tokens": optional_params.get("max_tokens", 1000),
            },
            api_key=api_key,
            timeout=kwargs.get("timeout")
        )

        # Populate response
        model_response.choices[0].message.content = response["choices"][0]["message"]["content"]
        model_response.model = model

        # Set usage
        if "usage" in response:
            model_response.usage.prompt_tokens = response["usage"].get("prompt_tokens", 0)
            model_response.usage.completion_tokens = response["usage"].get("completion_tokens", 0)
            model_response.usage.total_tokens = response["usage"].get("total_tokens", 0)

        return model_response

    async def acompletion(self, **kwargs) -> ModelResponse:
        """Async completion"""
        # For simplicity, wrap sync version
        # In production, use async HTTP client
        return self.completion(**kwargs)

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        api_key: str,
        **kwargs
    ) -> Iterator[GenericStreamingChunk]:
        """Streaming completion over OpenAI-style server-sent events."""
        with httpx.stream(
            "POST",
            f"{api_base or self.base_url}/chat/completions",
            json={
                "model": model,
                "messages": messages,
                "stream": True,
            },
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=kwargs.get("timeout") or 60.0,  # bound the stream like _make_request
        ) as response:
            for line in response.iter_lines():
                if not line.startswith("data: "):
                    continue
                data_str = line[6:].strip()
                # OpenAI-style SSE streams terminate with a "data: [DONE]"
                # sentinel that is not JSON — json.loads on it would raise.
                if data_str == "[DONE]":
                    break
                data = json.loads(data_str)
                if data.get("choices"):
                    delta = data["choices"][0].get("delta", {})
                    if "content" in delta:
                        yield GenericStreamingChunk(
                            text=delta["content"],
                            is_finished=data["choices"][0].get("finish_reason") is not None
                        )

    async def astreaming(self, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Async streaming"""
        # Wrap sync streaming for simplicity.
        # NOTE: this blocks the event loop between chunks — use an async
        # HTTP client in production.
        for chunk in self.streaming(**kwargs):
            yield chunk

# Register provider
import litellm

litellm.custom_provider_map = [
    {
        "provider": "my-internal-llm",       # prefix used in model="my-internal-llm/<model>"
        "custom_handler": MyInternalLLM()    # instance that serves the routed calls
    }
]

# Use it
response = litellm.completion(
    model="my-internal-llm/gpt-custom",
    messages=[{"role": "user", "content": "Hello!"}],
    api_key="internal-key-123"
)

print(response.choices[0].message.content)

Error Handling

Always use CustomLLMError for exceptions:
from litellm.llms.custom_llm import CustomLLMError

class MyCustomLLM(CustomLLM):
    """Example of mapping transport failures onto CustomLLMError."""

    def completion(self, **kwargs):
        """Run a completion, translating any failure into CustomLLMError."""
        try:
            response = self._make_api_call(**kwargs)
            return self._format_response(response)
        except CustomLLMError:
            # Already mapped — re-raise as-is instead of letting the
            # catch-all below re-wrap it as a generic 500.
            raise
        except httpx.HTTPStatusError as e:
            # Map HTTP errors; chain the original for debuggability
            raise CustomLLMError(
                status_code=e.response.status_code,
                message=f"API error: {e.response.text}"
            ) from e
        except httpx.TimeoutException as e:
            # Handle timeouts
            raise CustomLLMError(
                status_code=408,
                message="Request timeout"
            ) from e
        except Exception as e:
            # Catch-all for unexpected errors
            raise CustomLLMError(
                status_code=500,
                message=f"Internal error: {str(e)}"
            ) from e

Using with LiteLLM Proxy

Deploy custom providers via the proxy:
config.yaml
model_list:
  - model_name: my-custom-model
    litellm_params:
      model: my-custom-llm/custom-model-v1
      api_key: os.environ/CUSTOM_API_KEY
      api_base: https://internal-api.company.com/v1
Load custom provider in proxy startup:
# In your proxy startup script
from my_custom_provider import MyCustomLLM
import litellm

litellm.custom_provider_map = [
    {
        "provider": "my-custom-llm",
        "custom_handler": MyCustomLLM()
    }
]

Best Practices

Leverage LiteLLM’s HTTP handlers for connection pooling:
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler

class MyCustomLLM(CustomLLM):
    # Reuses LiteLLM's pooled HTTP handlers instead of opening a new
    # connection per request.
    def __init__(self):
        super().__init__()
        self.http_handler = HTTPHandler()
        self.async_handler = AsyncHTTPHandler()
    
    def completion(self, client=None, **kwargs):
        # Use provided client or create one
        http_client = client or self.http_handler
        # NOTE(review): `url` and `payload` are undefined placeholders in this
        # snippet — a real handler must build them from kwargs (api_base,
        # model, messages) before this call.
        response = http_client.post(url, json=payload)
        return response.json()
Provide accurate usage statistics:
def completion(self, model, messages, **kwargs) -> ModelResponse:
    """Completion that reports accurate token usage.

    Token counts are computed with the tokenizer LiteLLM passes in as
    ``kwargs["encoding"]``; the pre-built response object arrives as
    ``kwargs["model_response"]`` (LiteLLM supplies both on every call).
    """
    # The original example referenced `model_response` without ever
    # binding it (NameError); LiteLLM passes it in via kwargs.
    model_response = kwargs["model_response"]

    # Use encoding for token counting
    encoding = kwargs.get("encoding")

    # Count prompt tokens
    # NOTE(review): assumes every message "content" is a plain string —
    # multimodal (list-valued) content would need flattening first.
    prompt_text = " ".join([m["content"] for m in messages])
    prompt_tokens = len(encoding.encode(prompt_text))

    # Make request
    response = self._make_request(**kwargs)

    # Count completion tokens
    completion_text = response["text"]
    completion_tokens = len(encoding.encode(completion_text))

    # Set usage
    model_response.usage.prompt_tokens = prompt_tokens
    model_response.usage.completion_tokens = completion_tokens
    model_response.usage.total_tokens = prompt_tokens + completion_tokens

    return model_response
Implement exponential backoff:
import time

def _make_request_with_retry(self, endpoint, payload, api_key, max_retries=3):
    """Call ``_make_request`` with exponential backoff on HTTP 429.

    Retries up to ``max_retries`` attempts; any non-429 error, or a 429 on
    the final attempt, propagates to the caller.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
    """
    if max_retries < 1:
        # The original fell off the loop and silently returned None here.
        raise ValueError("max_retries must be >= 1")
    for attempt in range(max_retries):
        try:
            return self._make_request(endpoint, payload, api_key)
        except CustomLLMError as e:
            if e.status_code == 429 and attempt < max_retries - 1:
                # Exponential backoff: 1s, 2s, 4s, ...
                wait_time = 2 ** attempt
                time.sleep(wait_time)
                continue
            raise
Transform tool calls appropriately:
def completion(self, messages, optional_params, **kwargs):
    """Completion with OpenAI-style tool calling.

    Tools arrive in ``optional_params["tools"]`` in OpenAI format, are
    translated to the backend's format for the request, and any tool calls
    in the reply are translated back before returning.
    """
    # The original example referenced `model_response` without ever
    # binding it (NameError); LiteLLM passes it in via kwargs.
    model_response = kwargs["model_response"]

    # Extract tools from optional_params
    tools = optional_params.get("tools", [])

    # Transform to your API format
    api_tools = self._transform_tools(tools)

    # Include in request
    response = self._make_request(
        payload={
            "messages": messages,
            "tools": api_tools
        }
    )

    # Transform tool calls back to OpenAI format
    if response.get("tool_calls"):
        model_response.choices[0].message.tool_calls = \
            self._transform_tool_calls(response["tool_calls"])

    return model_response

Testing Your Provider

import pytest
from litellm import completion

def test_custom_provider_completion():
    # End-to-end sync call through litellm routed to the custom handler.
    response = completion(
        model="my-custom-llm/test-model",
        messages=[{"role": "user", "content": "Hello!"}],
        api_key="test-key"
    )
    
    assert response.choices[0].message.content
    # NOTE(review): depending on the handler, response.model may come back
    # without the "my-custom-llm/" prefix — confirm against your provider.
    assert response.model == "my-custom-llm/test-model"
    assert response.usage.total_tokens > 0

@pytest.mark.asyncio
async def test_custom_provider_async():
    """Async call must go through litellm.acompletion.

    `litellm.completion` is synchronous — the original `await completion(...)`
    would fail, since a plain return value is not awaitable.
    """
    from litellm import acompletion

    response = await acompletion(
        model="my-custom-llm/test-model",
        messages=[{"role": "user", "content": "Hello!"}],
        api_key="test-key"
    )
    
    assert response.choices[0].message.content

def test_custom_provider_streaming():
    # stream=True makes litellm return an iterator of chunks instead of a
    # single response object.
    stream = completion(
        model="my-custom-llm/test-model",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
        api_key="test-key"
    )
    
    chunks = list(stream)
    assert len(chunks) > 0
    # At least one chunk should carry delta content (the final chunk may not).
    assert any(chunk.choices[0].delta.content for chunk in chunks)

Reference

Source Files

  • Base class: litellm/llms/custom_llm.py:47
  • HTTP handlers: litellm/llms/custom_httpx/http_handler.py
  • Error handling: litellm/llms/custom_llm.py:34

Build docs developers (and LLMs) love